mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	batch updates in notifier (#1905)
This commit is contained in:
		
							parent
							
								
									fef8261339
								
							
						
					
					
						commit
						cb0b495ea9
					
				@ -137,7 +137,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 | 
				
			|||||||
		noisePrivateKey:    noisePrivateKey,
 | 
							noisePrivateKey:    noisePrivateKey,
 | 
				
			||||||
		registrationCache:  registrationCache,
 | 
							registrationCache:  registrationCache,
 | 
				
			||||||
		pollNetMapStreamWG: sync.WaitGroup{},
 | 
							pollNetMapStreamWG: sync.WaitGroup{},
 | 
				
			||||||
		nodeNotifier:       notifier.NewNotifier(),
 | 
							nodeNotifier:       notifier.NewNotifier(cfg),
 | 
				
			||||||
		mapSessions:        make(map[types.NodeID]*mapSession),
 | 
							mapSessions:        make(map[types.NodeID]*mapSession),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,12 @@ var (
 | 
				
			|||||||
		Namespace: prometheusNamespace,
 | 
							Namespace: prometheusNamespace,
 | 
				
			||||||
		Name:      "notifier_update_sent_total",
 | 
							Name:      "notifier_update_sent_total",
 | 
				
			||||||
		Help:      "total count of update sent on nodes channel",
 | 
							Help:      "total count of update sent on nodes channel",
 | 
				
			||||||
	}, []string{"status", "type"})
 | 
						}, []string{"status", "type", "trigger"})
 | 
				
			||||||
 | 
						notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
 | 
				
			||||||
 | 
							Namespace: prometheusNamespace,
 | 
				
			||||||
 | 
							Name:      "notifier_update_received_total",
 | 
				
			||||||
 | 
							Help:      "total count of updates received by notifier",
 | 
				
			||||||
 | 
						}, []string{"type", "trigger"})
 | 
				
			||||||
	notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
 | 
						notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
 | 
				
			||||||
		Namespace: prometheusNamespace,
 | 
							Namespace: prometheusNamespace,
 | 
				
			||||||
		Name:      "notifier_open_channels_total",
 | 
							Name:      "notifier_open_channels_total",
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,7 @@ package notifier
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"slices"
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -11,19 +11,27 @@ import (
 | 
				
			|||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
	"github.com/puzpuzpuz/xsync/v3"
 | 
						"github.com/puzpuzpuz/xsync/v3"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
						"tailscale.com/util/set"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Notifier struct {
 | 
					type Notifier struct {
 | 
				
			||||||
	l         sync.RWMutex
 | 
						l         sync.RWMutex
 | 
				
			||||||
	nodes     map[types.NodeID]chan<- types.StateUpdate
 | 
						nodes     map[types.NodeID]chan<- types.StateUpdate
 | 
				
			||||||
	connected *xsync.MapOf[types.NodeID, bool]
 | 
						connected *xsync.MapOf[types.NodeID, bool]
 | 
				
			||||||
 | 
						b         *batcher
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewNotifier() *Notifier {
 | 
					func NewNotifier(cfg *types.Config) *Notifier {
 | 
				
			||||||
	return &Notifier{
 | 
						n := &Notifier{
 | 
				
			||||||
		nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
 | 
							nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
 | 
				
			||||||
		connected: xsync.NewMapOf[types.NodeID, bool](),
 | 
							connected: xsync.NewMapOf[types.NodeID, bool](),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
 | 
				
			||||||
 | 
						n.b = b
 | 
				
			||||||
 | 
						// TODO(kradalby): clean this up
 | 
				
			||||||
 | 
						go b.doWork()
 | 
				
			||||||
 | 
						return n
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 | 
					func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 | 
				
			||||||
@ -108,13 +116,8 @@ func (n *Notifier) NotifyWithIgnore(
 | 
				
			|||||||
	update types.StateUpdate,
 | 
						update types.StateUpdate,
 | 
				
			||||||
	ignoreNodeIDs ...types.NodeID,
 | 
						ignoreNodeIDs ...types.NodeID,
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
	for nodeID := range n.nodes {
 | 
						notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
 | 
				
			||||||
		if slices.Contains(ignoreNodeIDs, nodeID) {
 | 
						n.b.addOrPassthrough(update)
 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		n.NotifyByNodeID(ctx, update, nodeID)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Notifier) NotifyByNodeID(
 | 
					func (n *Notifier) NotifyByNodeID(
 | 
				
			||||||
@ -139,10 +142,10 @@ func (n *Notifier) NotifyByNodeID(
 | 
				
			|||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Err(ctx.Err()).
 | 
									Err(ctx.Err()).
 | 
				
			||||||
				Uint64("node.id", nodeID.Uint64()).
 | 
									Uint64("node.id", nodeID.Uint64()).
 | 
				
			||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", types.NotifyOriginKey.Value(ctx)).
 | 
				
			||||||
				Any("origin-hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
 | 
				
			||||||
				Msgf("update not sent, context cancelled")
 | 
									Msgf("update not sent, context cancelled")
 | 
				
			||||||
			notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
 | 
								notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		case c <- update:
 | 
							case c <- update:
 | 
				
			||||||
@ -151,11 +154,23 @@ func (n *Notifier) NotifyByNodeID(
 | 
				
			|||||||
				Any("origin", ctx.Value("origin")).
 | 
									Any("origin", ctx.Value("origin")).
 | 
				
			||||||
				Any("origin-hostname", ctx.Value("hostname")).
 | 
									Any("origin-hostname", ctx.Value("hostname")).
 | 
				
			||||||
				Msgf("update successfully sent on chan")
 | 
									Msgf("update successfully sent on chan")
 | 
				
			||||||
			notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
 | 
								notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (n *Notifier) sendAll(update types.StateUpdate) {
 | 
				
			||||||
 | 
						start := time.Now()
 | 
				
			||||||
 | 
						n.l.RLock()
 | 
				
			||||||
 | 
						defer n.l.RUnlock()
 | 
				
			||||||
 | 
						notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, c := range n.nodes {
 | 
				
			||||||
 | 
							c <- update
 | 
				
			||||||
 | 
							notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Notifier) String() string {
 | 
					func (n *Notifier) String() string {
 | 
				
			||||||
	n.l.RLock()
 | 
						n.l.RLock()
 | 
				
			||||||
	defer n.l.RUnlock()
 | 
						defer n.l.RUnlock()
 | 
				
			||||||
@ -177,3 +192,166 @@ func (n *Notifier) String() string {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return b.String()
 | 
						return b.String()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type batcher struct {
 | 
				
			||||||
 | 
						tick *time.Ticker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mu sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cancelCh chan struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						changedNodeIDs set.Slice[types.NodeID]
 | 
				
			||||||
 | 
						nodesChanged   bool
 | 
				
			||||||
 | 
						patches        map[types.NodeID]tailcfg.PeerChange
 | 
				
			||||||
 | 
						patchesChanged bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n *Notifier
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
 | 
				
			||||||
 | 
						return &batcher{
 | 
				
			||||||
 | 
							tick:     time.NewTicker(batchTime),
 | 
				
			||||||
 | 
							cancelCh: make(chan struct{}),
 | 
				
			||||||
 | 
							patches:  make(map[types.NodeID]tailcfg.PeerChange),
 | 
				
			||||||
 | 
							n:        n,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (b *batcher) close() {
 | 
				
			||||||
 | 
						b.cancelCh <- struct{}{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// addOrPassthrough adds the update to the batcher, if it is not a
 | 
				
			||||||
 | 
					// type that is currently batched, it will be sent immediately.
 | 
				
			||||||
 | 
					func (b *batcher) addOrPassthrough(update types.StateUpdate) {
 | 
				
			||||||
 | 
						b.mu.Lock()
 | 
				
			||||||
 | 
						defer b.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch update.Type {
 | 
				
			||||||
 | 
						case types.StatePeerChanged:
 | 
				
			||||||
 | 
							b.changedNodeIDs.Add(update.ChangeNodes...)
 | 
				
			||||||
 | 
							b.nodesChanged = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case types.StatePeerChangedPatch:
 | 
				
			||||||
 | 
							for _, newPatch := range update.ChangePatches {
 | 
				
			||||||
 | 
								if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
 | 
				
			||||||
 | 
									overwritePatch(&curr, newPatch)
 | 
				
			||||||
 | 
									b.patches[types.NodeID(newPatch.NodeID)] = curr
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							b.patchesChanged = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							b.n.sendAll(update)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// flush sends all the accumulated patches to all
 | 
				
			||||||
 | 
					// nodes in the notifier.
 | 
				
			||||||
 | 
					func (b *batcher) flush() {
 | 
				
			||||||
 | 
						b.mu.Lock()
 | 
				
			||||||
 | 
						defer b.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if b.nodesChanged || b.patchesChanged {
 | 
				
			||||||
 | 
							var patches []*tailcfg.PeerChange
 | 
				
			||||||
 | 
							// If a node is getting a full update from a change
 | 
				
			||||||
 | 
							// node update, then the patch can be dropped.
 | 
				
			||||||
 | 
							for nodeID, patch := range b.patches {
 | 
				
			||||||
 | 
								if b.changedNodeIDs.Contains(nodeID) {
 | 
				
			||||||
 | 
									delete(b.patches, nodeID)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									patches = append(patches, &patch)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							changedNodes := b.changedNodeIDs.Slice().AsSlice()
 | 
				
			||||||
 | 
							sort.Slice(changedNodes, func(i, j int) bool {
 | 
				
			||||||
 | 
								return changedNodes[i] < changedNodes[j]
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if b.changedNodeIDs.Slice().Len() > 0 {
 | 
				
			||||||
 | 
								update := types.StateUpdate{
 | 
				
			||||||
 | 
									Type:        types.StatePeerChanged,
 | 
				
			||||||
 | 
									ChangeNodes: changedNodes,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								b.n.sendAll(update)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(patches) > 0 {
 | 
				
			||||||
 | 
								patchUpdate := types.StateUpdate{
 | 
				
			||||||
 | 
									Type:          types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
									ChangePatches: patches,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								b.n.sendAll(patchUpdate)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							b.changedNodeIDs = set.Slice[types.NodeID]{}
 | 
				
			||||||
 | 
							b.nodesChanged = false
 | 
				
			||||||
 | 
							b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
 | 
				
			||||||
 | 
							b.patchesChanged = false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (b *batcher) doWork() {
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-b.cancelCh:
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							case <-b.tick.C:
 | 
				
			||||||
 | 
								b.flush()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// overwritePatch takes the current patch and a newer patch
 | 
				
			||||||
 | 
					// and override any field that has changed
 | 
				
			||||||
 | 
					func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
 | 
				
			||||||
 | 
						if newPatch.DERPRegion != 0 {
 | 
				
			||||||
 | 
							currPatch.DERPRegion = newPatch.DERPRegion
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.Cap != 0 {
 | 
				
			||||||
 | 
							currPatch.Cap = newPatch.Cap
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.CapMap != nil {
 | 
				
			||||||
 | 
							currPatch.CapMap = newPatch.CapMap
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.Endpoints != nil {
 | 
				
			||||||
 | 
							currPatch.Endpoints = newPatch.Endpoints
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.Key != nil {
 | 
				
			||||||
 | 
							currPatch.Key = newPatch.Key
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.KeySignature != nil {
 | 
				
			||||||
 | 
							currPatch.KeySignature = newPatch.KeySignature
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.DiscoKey != nil {
 | 
				
			||||||
 | 
							currPatch.DiscoKey = newPatch.DiscoKey
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.Online != nil {
 | 
				
			||||||
 | 
							currPatch.Online = newPatch.Online
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.LastSeen != nil {
 | 
				
			||||||
 | 
							currPatch.LastSeen = newPatch.LastSeen
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.KeyExpiry != nil {
 | 
				
			||||||
 | 
							currPatch.KeyExpiry = newPatch.KeyExpiry
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if newPatch.Capabilities != nil {
 | 
				
			||||||
 | 
							currPatch.Capabilities = newPatch.Capabilities
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										249
									
								
								hscontrol/notifier/notifier_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										249
									
								
								hscontrol/notifier/notifier_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,249 @@
 | 
				
			|||||||
 | 
					package notifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"net/netip"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/util"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBatcher(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name    string
 | 
				
			||||||
 | 
							updates []types.StateUpdate
 | 
				
			||||||
 | 
							want    []types.StateUpdate
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "full-passthrough",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StateFullUpdate,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StateFullUpdate,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "derp-passthrough",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StateDERPUpdated,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StateDERPUpdated,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "single-node-update",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChanged,
 | 
				
			||||||
 | 
										ChangeNodes: []types.NodeID{
 | 
				
			||||||
 | 
											2,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChanged,
 | 
				
			||||||
 | 
										ChangeNodes: []types.NodeID{
 | 
				
			||||||
 | 
											2,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "merge-node-update",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChanged,
 | 
				
			||||||
 | 
										ChangeNodes: []types.NodeID{
 | 
				
			||||||
 | 
											2, 4,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChanged,
 | 
				
			||||||
 | 
										ChangeNodes: []types.NodeID{
 | 
				
			||||||
 | 
											2, 3,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChanged,
 | 
				
			||||||
 | 
										ChangeNodes: []types.NodeID{
 | 
				
			||||||
 | 
											2, 3, 4,
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "single-patch-update",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     2,
 | 
				
			||||||
 | 
												DERPRegion: 5,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     2,
 | 
				
			||||||
 | 
												DERPRegion: 5,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "merge-patch-to-same-node-update",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     2,
 | 
				
			||||||
 | 
												DERPRegion: 5,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     2,
 | 
				
			||||||
 | 
												DERPRegion: 6,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     2,
 | 
				
			||||||
 | 
												DERPRegion: 6,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "merge-patch-to-multiple-node-update",
 | 
				
			||||||
 | 
								updates: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID: 3,
 | 
				
			||||||
 | 
												Endpoints: []netip.AddrPort{
 | 
				
			||||||
 | 
													netip.MustParseAddrPort("1.1.1.1:9090"),
 | 
				
			||||||
 | 
												},
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID: 3,
 | 
				
			||||||
 | 
												Endpoints: []netip.AddrPort{
 | 
				
			||||||
 | 
													netip.MustParseAddrPort("1.1.1.1:9090"),
 | 
				
			||||||
 | 
													netip.MustParseAddrPort("2.2.2.2:8080"),
 | 
				
			||||||
 | 
												},
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     4,
 | 
				
			||||||
 | 
												DERPRegion: 6,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID: 4,
 | 
				
			||||||
 | 
												Cap:    tailcfg.CapabilityVersion(54),
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								want: []types.StateUpdate{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Type: types.StatePeerChangedPatch,
 | 
				
			||||||
 | 
										ChangePatches: []*tailcfg.PeerChange{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID: 3,
 | 
				
			||||||
 | 
												Endpoints: []netip.AddrPort{
 | 
				
			||||||
 | 
													netip.MustParseAddrPort("1.1.1.1:9090"),
 | 
				
			||||||
 | 
													netip.MustParseAddrPort("2.2.2.2:8080"),
 | 
				
			||||||
 | 
												},
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												NodeID:     4,
 | 
				
			||||||
 | 
												DERPRegion: 6,
 | 
				
			||||||
 | 
												Cap:        tailcfg.CapabilityVersion(54),
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								n := NewNotifier(&types.Config{
 | 
				
			||||||
 | 
									Tuning: types.Tuning{
 | 
				
			||||||
 | 
										// We will call flush manually for the tests,
 | 
				
			||||||
 | 
										// so do not run the worker.
 | 
				
			||||||
 | 
										BatchChangeDelay: time.Hour,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								ch := make(chan types.StateUpdate, 30)
 | 
				
			||||||
 | 
								defer close(ch)
 | 
				
			||||||
 | 
								n.AddNode(1, ch)
 | 
				
			||||||
 | 
								defer n.RemoveNode(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, u := range tt.updates {
 | 
				
			||||||
 | 
									n.NotifyAll(context.Background(), u)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								n.b.flush()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								var got []types.StateUpdate
 | 
				
			||||||
 | 
								for len(ch) > 0 {
 | 
				
			||||||
 | 
									out := <-ch
 | 
				
			||||||
 | 
									got = append(got, out)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 | 
				
			||||||
 | 
									t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -66,10 +66,16 @@ func (h *Headscale) newMapSession(
 | 
				
			|||||||
) *mapSession {
 | 
					) *mapSession {
 | 
				
			||||||
	warnf, infof, tracef, errf := logPollFunc(req, node)
 | 
						warnf, infof, tracef, errf := logPollFunc(req, node)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var updateChan chan types.StateUpdate
 | 
				
			||||||
 | 
						if req.Stream {
 | 
				
			||||||
		// Use a buffered channel in case a node is not fully ready
 | 
							// Use a buffered channel in case a node is not fully ready
 | 
				
			||||||
		// to receive a message to make sure we dont block the entire
 | 
							// to receive a message to make sure we dont block the entire
 | 
				
			||||||
		// notifier.
 | 
							// notifier.
 | 
				
			||||||
	updateChan := make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
 | 
							updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
 | 
				
			||||||
 | 
							updateChan <- types.StateUpdate{
 | 
				
			||||||
 | 
								Type: types.StateFullUpdate,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &mapSession{
 | 
						return &mapSession{
 | 
				
			||||||
		h:      h,
 | 
							h:      h,
 | 
				
			||||||
@ -218,33 +224,26 @@ func (m *mapSession) serve() {
 | 
				
			|||||||
	ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
 | 
						ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): Make this available through a tuning envvar
 | 
					 | 
				
			||||||
	wait := time.Second
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Add a circuit breaker, if the loop is not interrupted
 | 
					 | 
				
			||||||
	// inbetween listening for the channels, some updates
 | 
					 | 
				
			||||||
	// might get stale and stucked in the "changed" map
 | 
					 | 
				
			||||||
	// defined below.
 | 
					 | 
				
			||||||
	blockBreaker := time.NewTicker(wait)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// true means changed, false means removed
 | 
					 | 
				
			||||||
	var changed map[types.NodeID]bool
 | 
					 | 
				
			||||||
	var patches []*tailcfg.PeerChange
 | 
					 | 
				
			||||||
	var derp bool
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Set full to true to immediatly send a full mapresponse
 | 
					 | 
				
			||||||
	full := true
 | 
					 | 
				
			||||||
	prev := time.Now()
 | 
					 | 
				
			||||||
	lastMessage := ""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Loop through updates and continuously send them to the
 | 
						// Loop through updates and continuously send them to the
 | 
				
			||||||
	// client.
 | 
						// client.
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		// If a full update has been requested or there are patches, then send it immediately
 | 
							// consume channels with update, keep alives or "batch" blocking signals
 | 
				
			||||||
		// otherwise wait for the "batching" of changes or patches
 | 
							select {
 | 
				
			||||||
		if full || patches != nil || (changed != nil && time.Since(prev) > wait) {
 | 
							case <-m.cancelCh:
 | 
				
			||||||
 | 
								m.tracef("poll cancelled received")
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							case <-ctx.Done():
 | 
				
			||||||
 | 
								m.tracef("poll context done")
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Consume all updates sent to node
 | 
				
			||||||
 | 
							case update := <-m.ch:
 | 
				
			||||||
 | 
								m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
 | 
				
			||||||
 | 
								mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var data []byte
 | 
								var data []byte
 | 
				
			||||||
			var err error
 | 
								var err error
 | 
				
			||||||
 | 
								var lastMessage string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Ensure the node object is updated, for example, there
 | 
								// Ensure the node object is updated, for example, there
 | 
				
			||||||
			// might have been a hostinfo update in a sidechannel
 | 
								// might have been a hostinfo update in a sidechannel
 | 
				
			||||||
@ -256,62 +255,43 @@ func (m *mapSession) serve() {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// If there are patches _and_ fully changed nodes, filter the
 | 
					 | 
				
			||||||
			// patches and remove all patches that are present for the full
 | 
					 | 
				
			||||||
			// changes updates. This allows us to send them as part of the
 | 
					 | 
				
			||||||
			// PeerChange update, but only for nodes that are not fully changed.
 | 
					 | 
				
			||||||
			// The fully changed nodes will be updated from the database and
 | 
					 | 
				
			||||||
			// have all the updates needed.
 | 
					 | 
				
			||||||
			// This means that the patches left are for nodes that has no
 | 
					 | 
				
			||||||
			// updates that requires a full update.
 | 
					 | 
				
			||||||
			// Patches are not suppose to be mixed in, but can be.
 | 
					 | 
				
			||||||
			//
 | 
					 | 
				
			||||||
			// From tailcfg docs:
 | 
					 | 
				
			||||||
			// These are applied after Peers* above, but in practice the
 | 
					 | 
				
			||||||
			// control server should only send these on their own, without
 | 
					 | 
				
			||||||
			//
 | 
					 | 
				
			||||||
			// Currently, there is no effort to merge patch updates, they
 | 
					 | 
				
			||||||
			// are all sent, and the client will apply them in order.
 | 
					 | 
				
			||||||
			// TODO(kradalby): Merge Patches for the same IDs to send less
 | 
					 | 
				
			||||||
			// data and give the client less work.
 | 
					 | 
				
			||||||
			if patches != nil && changed != nil {
 | 
					 | 
				
			||||||
				var filteredPatches []*tailcfg.PeerChange
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				for _, patch := range patches {
 | 
					 | 
				
			||||||
					if _, ok := changed[types.NodeID(patch.NodeID)]; !ok {
 | 
					 | 
				
			||||||
						filteredPatches = append(filteredPatches, patch)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				patches = filteredPatches
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			updateType := "full"
 | 
								updateType := "full"
 | 
				
			||||||
			// When deciding what update to send, the following is considered,
 | 
								switch update.Type {
 | 
				
			||||||
			// Full is a superset of all updates, when a full update is requested,
 | 
								case types.StateFullUpdate:
 | 
				
			||||||
			// send only that and move on, all other updates will be present in
 | 
					 | 
				
			||||||
			// a full map response.
 | 
					 | 
				
			||||||
			//
 | 
					 | 
				
			||||||
			// If a map of changed nodes exists, prefer sending that as it will
 | 
					 | 
				
			||||||
			// contain all the updates for the node, including patches, as it
 | 
					 | 
				
			||||||
			// is fetched freshly from the database when building the response.
 | 
					 | 
				
			||||||
			//
 | 
					 | 
				
			||||||
			// If there is full changes registered, but we have patches for individual
 | 
					 | 
				
			||||||
			// nodes, send them.
 | 
					 | 
				
			||||||
			//
 | 
					 | 
				
			||||||
			// Finally, if a DERP map is the only request, send that alone.
 | 
					 | 
				
			||||||
			if full {
 | 
					 | 
				
			||||||
				m.tracef("Sending Full MapResponse")
 | 
									m.tracef("Sending Full MapResponse")
 | 
				
			||||||
				data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 | 
									data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 | 
				
			||||||
			} else if changed != nil {
 | 
								case types.StatePeerChanged:
 | 
				
			||||||
 | 
									changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									for _, nodeID := range update.ChangeNodes {
 | 
				
			||||||
 | 
										changed[nodeID] = true
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									lastMessage = update.Message
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
				data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage)
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
				
			||||||
				updateType = "change"
 | 
									updateType = "change"
 | 
				
			||||||
			} else if patches != nil {
 | 
					
 | 
				
			||||||
 | 
								case types.StatePeerChangedPatch:
 | 
				
			||||||
				m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
 | 
									m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
 | 
				
			||||||
				data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy)
 | 
									data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
 | 
				
			||||||
				updateType = "patch"
 | 
									updateType = "patch"
 | 
				
			||||||
			} else if derp {
 | 
								case types.StatePeerRemoved:
 | 
				
			||||||
 | 
									changed := make(map[types.NodeID]bool, len(update.Removed))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									for _, nodeID := range update.Removed {
 | 
				
			||||||
 | 
										changed[nodeID] = false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
				
			||||||
 | 
									updateType = "remove"
 | 
				
			||||||
 | 
								case types.StateSelfUpdate:
 | 
				
			||||||
 | 
									lastMessage = update.Message
 | 
				
			||||||
 | 
									m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 | 
				
			||||||
 | 
									// create the map so an empty (self) update is sent
 | 
				
			||||||
 | 
									data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
 | 
				
			||||||
 | 
									updateType = "remove"
 | 
				
			||||||
 | 
								case types.StateDERPUpdated:
 | 
				
			||||||
				m.tracef("Sending DERPUpdate MapResponse")
 | 
									m.tracef("Sending DERPUpdate MapResponse")
 | 
				
			||||||
				data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
 | 
									data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
 | 
				
			||||||
				updateType = "derp"
 | 
									updateType = "derp"
 | 
				
			||||||
@ -348,68 +328,6 @@ func (m *mapSession) serve() {
 | 
				
			|||||||
				m.tracef("update sent")
 | 
									m.tracef("update sent")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// reset
 | 
					 | 
				
			||||||
			changed = nil
 | 
					 | 
				
			||||||
			patches = nil
 | 
					 | 
				
			||||||
			lastMessage = ""
 | 
					 | 
				
			||||||
			full = false
 | 
					 | 
				
			||||||
			derp = false
 | 
					 | 
				
			||||||
			prev = time.Now()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// consume channels with update, keep alives or "batch" blocking signals
 | 
					 | 
				
			||||||
		select {
 | 
					 | 
				
			||||||
		case <-m.cancelCh:
 | 
					 | 
				
			||||||
			m.tracef("poll cancelled received")
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		case <-ctx.Done():
 | 
					 | 
				
			||||||
			m.tracef("poll context done")
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// Avoid infinite block that would potentially leave
 | 
					 | 
				
			||||||
		// some updates in the changed map.
 | 
					 | 
				
			||||||
		case <-blockBreaker.C:
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Consume all updates sent to node
 | 
					 | 
				
			||||||
		case update := <-m.ch:
 | 
					 | 
				
			||||||
			m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
 | 
					 | 
				
			||||||
			mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			switch update.Type {
 | 
					 | 
				
			||||||
			case types.StateFullUpdate:
 | 
					 | 
				
			||||||
				full = true
 | 
					 | 
				
			||||||
			case types.StatePeerChanged:
 | 
					 | 
				
			||||||
				if changed == nil {
 | 
					 | 
				
			||||||
					changed = make(map[types.NodeID]bool)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				for _, nodeID := range update.ChangeNodes {
 | 
					 | 
				
			||||||
					changed[nodeID] = true
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				lastMessage = update.Message
 | 
					 | 
				
			||||||
			case types.StatePeerChangedPatch:
 | 
					 | 
				
			||||||
				patches = append(patches, update.ChangePatches...)
 | 
					 | 
				
			||||||
			case types.StatePeerRemoved:
 | 
					 | 
				
			||||||
				if changed == nil {
 | 
					 | 
				
			||||||
					changed = make(map[types.NodeID]bool)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				for _, nodeID := range update.Removed {
 | 
					 | 
				
			||||||
					changed[nodeID] = false
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			case types.StateSelfUpdate:
 | 
					 | 
				
			||||||
				// create the map so an empty (self) update is sent
 | 
					 | 
				
			||||||
				if changed == nil {
 | 
					 | 
				
			||||||
					changed = make(map[types.NodeID]bool)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				lastMessage = update.Message
 | 
					 | 
				
			||||||
			case types.StateDERPUpdated:
 | 
					 | 
				
			||||||
				derp = true
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		case <-m.keepAliveTicker.C:
 | 
							case <-m.keepAliveTicker.C:
 | 
				
			||||||
			data, err := m.mapper.KeepAliveResponse(m.req, m.node)
 | 
								data, err := m.mapper.KeepAliveResponse(m.req, m.node)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
						"tailscale.com/util/ctxkey"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@ -183,10 +184,14 @@ func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						NotifyOriginKey   = ctxkey.New("notify.origin", "")
 | 
				
			||||||
 | 
						NotifyHostnameKey = ctxkey.New("notify.hostname", "")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
 | 
					func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
 | 
				
			||||||
	ctx2, _ := context.WithTimeout(
 | 
						ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
 | 
				
			||||||
		context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
 | 
						ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
 | 
				
			||||||
		3*time.Second,
 | 
						ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	return ctx2
 | 
						return ctx2
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -11,6 +11,7 @@ import (
 | 
				
			|||||||
	"encoding/pem"
 | 
						"encoding/pem"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"math/big"
 | 
						"math/big"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
@ -396,6 +397,14 @@ func (t *HeadscaleInContainer) Shutdown() error {
 | 
				
			|||||||
		)
 | 
							)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = t.SaveMetrics("/tmp/control/metrics.txt")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf(
 | 
				
			||||||
 | 
								"Failed to metrics from control: %s",
 | 
				
			||||||
 | 
								err,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Send a interrupt signal to the "headscale" process inside the container
 | 
						// Send a interrupt signal to the "headscale" process inside the container
 | 
				
			||||||
	// allowing it to shut down gracefully and flush the profile to disk.
 | 
						// allowing it to shut down gracefully and flush the profile to disk.
 | 
				
			||||||
	// The container will live for a bit longer due to the sleep at the end.
 | 
						// The container will live for a bit longer due to the sleep at the end.
 | 
				
			||||||
@ -448,6 +457,25 @@ func (t *HeadscaleInContainer) SaveLog(path string) error {
 | 
				
			|||||||
	return dockertestutil.SaveLog(t.pool, t.container, path)
 | 
						return dockertestutil.SaveLog(t.pool, t.container, path)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
 | 
				
			||||||
 | 
						resp, err := http.Get(fmt.Sprintf("http://%s:9090/metrics", t.hostname))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("getting metrics: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
						out, err := os.Create(savePath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("creating file for metrics: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer out.Close()
 | 
				
			||||||
 | 
						_, err = io.Copy(out, resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("copy response to file: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
 | 
					func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
 | 
				
			||||||
	tarFile, err := t.FetchPath("/tmp/profile")
 | 
						tarFile, err := t.FetchPath("/tmp/profile")
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	scenario, err := NewScenario(dockertestMaxWait())
 | 
						scenario, err := NewScenario(dockertestMaxWait())
 | 
				
			||||||
	assertNoErrf(t, "failed to create scenario: %s", err)
 | 
						assertNoErrf(t, "failed to create scenario: %s", err)
 | 
				
			||||||
	// defer scenario.Shutdown()
 | 
						defer scenario.Shutdown()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	spec := map[string]int{
 | 
						spec := map[string]int{
 | 
				
			||||||
		user: 3,
 | 
							user: 3,
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user