Mirror of https://github.com/juanfont/headscale.git
This helps prevent messages from being sent with the wrong update type and payload combination, and it is shorter/neater.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
package notifier

import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/puzpuzpuz/xsync/v3"
	"github.com/rs/zerolog/log"
	"github.com/sasha-s/go-deadlock"
	"tailscale.com/envknob"
	"tailscale.com/tailcfg"
	"tailscale.com/util/set"
)

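// debugDeadlock and debugDeadlockTimeout toggle go-deadlock's detection of
// stuck mutexes via the HEADSCALE_DEBUG_DEADLOCK and
// HEADSCALE_DEBUG_DEADLOCK_TIMEOUT environment variables.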
var (
	debugDeadlock        = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
	debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)

func init() {
	deadlock.Opts.Disable = !debugDeadlock
	if debugDeadlock {
		deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
		deadlock.Opts.PrintAllCurrentGoroutines = true
	}
}

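// Notifier keeps track of the update channel and connection state for every
// node and fans out state updates to them, batching peer-change updates
// through its batcher.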
type Notifier struct {
	l         deadlock.Mutex
	nodes     map[types.NodeID]chan<- types.StateUpdate
	connected *xsync.MapOf[types.NodeID, bool]
	b         *batcher
	cfg       *types.Config
	closed    bool
}

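// NewNotifier returns a Notifier with its batcher already running. Callers
// are expected to call Close when the notifier is no longer needed, e.g.:
//
//	n := NewNotifier(cfg)
//	defer n.Close()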
func NewNotifier(cfg *types.Config) *Notifier {
	n := &Notifier{
		nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
		connected: xsync.NewMapOf[types.NodeID, bool](),
		cfg:       cfg,
		closed:    false,
	}
	b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
	n.b = b

	go b.doWork()
	return n
}

// Close stops the batcher and closes all channels.
func (n *Notifier) Close() {
	notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "close").Dec()

	n.closed = true
	n.b.close()

	for _, c := range n.nodes {
		close(c)
	}
}

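// tracef logs a trace-level message annotated with the node ID and the
// number of currently open update channels.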
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
	log.Trace().
		Uint64("node.id", nID.Uint64()).
		Int("open_chans", len(n.nodes)).Msgf(msg, args...)
}

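// AddNode registers the update channel for a node and marks it as connected.
// If the node already has a channel, the old one is closed and replaced.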
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
	notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	// If a channel exists, it means the node has opened a new
	// connection. Close the old channel and replace it.
	if curr, ok := n.nodes[nodeID]; ok {
		n.tracef(nodeID, "channel present, closing and replacing")
		close(curr)
	}

	n.nodes[nodeID] = c
	n.connected.Store(nodeID, true)

	n.tracef(nodeID, "added new channel")
	notifierNodeUpdateChans.Inc()
}

// RemoveNode removes a node and a given channel from the notifier.
// It checks that the channel is the same as the one currently registered
// and ignores the removal if it is not.
// RemoveNode reports whether the node/channel was removed.
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
	notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())

	if n.closed {
		return true
	}

	if len(n.nodes) == 0 {
		return true
	}

	// If the channel exists but does not belong
	// to the caller, ignore it.
	if curr, ok := n.nodes[nodeID]; ok {
		if curr != c {
			n.tracef(nodeID, "channel has been replaced, not removing")
			return false
		}
	}

	delete(n.nodes, nodeID)
	n.connected.Store(nodeID, false)

	n.tracef(nodeID, "removed channel")
	notifierNodeUpdateChans.Dec()

	return true
}

// IsConnected reports whether a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()

	if val, ok := n.connected.Load(nodeID); ok {
		return val
	}
	return false
}

// IsLikelyConnected reports whether a node is connected to headscale and has
// a poll session open, but it does not take the lock, so the result may be
// stale.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
	if val, ok := n.connected.Load(nodeID); ok {
		return val
	}
	return false
}

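// LikelyConnectedMap returns the notifier's map of node connection state.
// The map is shared and not guarded by the notifier's lock, so it may change
// while being read.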
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
	return n.connected
}

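// NotifyAll sends a state update to all connected nodes.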
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
	n.NotifyWithIgnore(ctx, update)
}

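// NotifyWithIgnore hands a state update to the batcher, which either
// accumulates it or sends it to all nodes immediately, depending on the
// update type.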
func (n *Notifier) NotifyWithIgnore(
	ctx context.Context,
	update types.StateUpdate,
	ignoreNodeIDs ...types.NodeID,
) {
	if n.closed {
		return
	}

	notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
	n.b.addOrPassthrough(update)
}

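// NotifyByNodeID sends a state update to a single node, giving up if the
// passed context is cancelled before the update can be delivered.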
func (n *Notifier) NotifyByNodeID(
	ctx context.Context,
	update types.StateUpdate,
	nodeID types.NodeID,
) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
	notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	if c, ok := n.nodes[nodeID]; ok {
		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Uint64("node.id", nodeID.Uint64()).
				Any("origin", types.NotifyOriginKey.Value(ctx)).
				Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
				Msgf("update not sent, context cancelled")
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
			}

			return
		case c <- update:
			n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
			}
		}
	}
}

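// sendAll delivers the update to every registered channel, bounding each
// send with the configured NotifierSendTimeout so a stuck receiver cannot
// block updates to the remaining nodes.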
func (n *Notifier) sendAll(update types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
	notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	for id, c := range n.nodes {
		// Whenever an update is sent to all nodes, there is a chance that the node
		// has disconnected and the goroutine that was supposed to consume the update
		// has shut down the channel and is waiting for the lock held here in RemoveNode.
		// This means that there is potential for a deadlock which would stop all updates
		// going out to clients. This timeout prevents that from happening by moving on to the
		// next node if the context is cancelled. After sendAll releases the lock, the add/remove
		// call will succeed and the update will go to the correct nodes on the next call.
		ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
		defer cancel()
		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Uint64("node.id", id.Uint64()).
				Msgf("update not sent, context cancelled")
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
			}

			return
		case c <- update:
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
			}
		}
	}
}

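// String renders the open channels and connection map for debugging.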
func (n *Notifier) String() string {
	notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "string").Dec()

	var b strings.Builder
	fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))

	var keys []types.NodeID
	n.connected.Range(func(key types.NodeID, value bool) bool {
		keys = append(keys, key)
		return true
	})
	sort.Slice(keys, func(i, j int) bool {
		return keys[i] < keys[j]
	})

	for _, key := range keys {
		fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
	}

	b.WriteString("\n")
	fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))

	for _, key := range keys {
		val, _ := n.connected.Load(key)
		fmt.Fprintf(&b, "\t%d: %t\n", key, val)
	}

	return b.String()
}

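// batcher accumulates peer-change and patch updates and flushes them to the
// notifier's nodes on a fixed interval.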
type batcher struct {
	tick *time.Ticker

	mu sync.Mutex

	cancelCh chan struct{}

	changedNodeIDs set.Slice[types.NodeID]
	nodesChanged   bool
	patches        map[types.NodeID]tailcfg.PeerChange
	patchesChanged bool

	n *Notifier
}

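// newBatcher returns a batcher that flushes accumulated updates to n every
// batchTime. The caller is responsible for starting doWork and calling close.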
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
	return &batcher{
		tick:     time.NewTicker(batchTime),
		cancelCh: make(chan struct{}),
		patches:  make(map[types.NodeID]tailcfg.PeerChange),
		n:        n,
	}
}

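// close signals doWork to stop.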
func (b *batcher) close() {
	b.cancelCh <- struct{}{}
}

// addOrPassthrough adds the update to the batcher; if it is not a
// type that is currently batched, it is sent immediately.
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()

	switch update.Type {
	case types.StatePeerChanged:
		b.changedNodeIDs.Add(update.ChangeNodes...)
		b.nodesChanged = true
		notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))

	case types.StatePeerChangedPatch:
		for _, newPatch := range update.ChangePatches {
			if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
				overwritePatch(&curr, newPatch)
				b.patches[types.NodeID(newPatch.NodeID)] = curr
			} else {
				b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
			}
		}
		b.patchesChanged = true
		notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))

	default:
		b.n.sendAll(update)
	}
}

// flush sends all accumulated node changes and patches to all
// nodes in the notifier.
func (b *batcher) flush() {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()

	if b.nodesChanged || b.patchesChanged {
		var patches []*tailcfg.PeerChange
		// If a node is getting a full update from a changed-node
		// update, then the patch can be dropped.
		for nodeID, patch := range b.patches {
			if b.changedNodeIDs.Contains(nodeID) {
				delete(b.patches, nodeID)
			} else {
				patches = append(patches, &patch)
			}
		}

		changedNodes := b.changedNodeIDs.Slice().AsSlice()
		sort.Slice(changedNodes, func(i, j int) bool {
			return changedNodes[i] < changedNodes[j]
		})

		if b.changedNodeIDs.Slice().Len() > 0 {
			update := types.UpdatePeerChanged(changedNodes...)

			b.n.sendAll(update)
		}

		if len(patches) > 0 {
			patchUpdate := types.UpdatePeerPatch(patches...)

			b.n.sendAll(patchUpdate)
		}

		b.changedNodeIDs = set.Slice[types.NodeID]{}
		notifierBatcherChanges.WithLabelValues().Set(0)
		b.nodesChanged = false
		b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
		notifierBatcherPatches.WithLabelValues().Set(0)
		b.patchesChanged = false
	}
}

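// doWork flushes the batcher on every tick until close is called.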
func (b *batcher) doWork() {
	for {
		select {
		case <-b.cancelCh:
			return
		case <-b.tick.C:
			b.flush()
		}
	}
}

// overwritePatch takes the current patch and a newer patch
// and overwrites any field that has changed.
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
	if newPatch.DERPRegion != 0 {
		currPatch.DERPRegion = newPatch.DERPRegion
	}

	if newPatch.Cap != 0 {
		currPatch.Cap = newPatch.Cap
	}

	if newPatch.CapMap != nil {
		currPatch.CapMap = newPatch.CapMap
	}

	if newPatch.Endpoints != nil {
		currPatch.Endpoints = newPatch.Endpoints
	}

	if newPatch.Key != nil {
		currPatch.Key = newPatch.Key
	}

	if newPatch.KeySignature != nil {
		currPatch.KeySignature = newPatch.KeySignature
	}

	if newPatch.DiscoKey != nil {
		currPatch.DiscoKey = newPatch.DiscoKey
	}

	if newPatch.Online != nil {
		currPatch.Online = newPatch.Online
	}

	if newPatch.LastSeen != nil {
		currPatch.LastSeen = newPatch.LastSeen
	}

	if newPatch.KeyExpiry != nil {
		currPatch.KeyExpiry = newPatch.KeyExpiry
	}
}