mirror of https://github.com/juanfont/headscale.git synced 2025-06-05 01:20:21 +02:00

mapper-batcher-experiment

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-05-26 14:48:40 +02:00
parent b8044c29dd
commit 72f473d0c9
8 changed files with 708 additions and 654 deletions

hscontrol/app.go

@@ -95,7 +95,7 @@ type Headscale struct {
extraRecordMan *dns.ExtraRecordsMan
primaryRoutes *routes.PrimaryRoutes
mapper *mapper.Mapper
mapBatcher *mapper.Batcher
nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
@@ -135,7 +135,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
primaryRoutes: routes.New(),
}
@@ -582,7 +581,15 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan, h.primaryRoutes)
mapp := mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
h.mapBatcher = mapper.NewBatcher(mapp)
h.nodeNotifier = notifier.NewNotifier(h.cfg, h.mapBatcher)
// TODO(kradalby): I don't like this. Right now it's done to access online status.
mapp.SetBatcher(h.mapBatcher)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server

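The wiring above is order-sensitive and worth spelling out: the Batcher renders responses through the Mapper, the Notifier only enqueues work on the Batcher, and the Mapper reads online status back from the Batcher, which is why SetBatcher has to run as a second phase after construction. A minimal sketch of the cycle, using only names from this diff:

mapp := mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
h.mapBatcher = mapper.NewBatcher(mapp)                     // Batcher -> Mapper (renders map responses)
h.nodeNotifier = notifier.NewNotifier(h.cfg, h.mapBatcher) // Notifier -> Batcher (enqueues work)
mapp.SetBatcher(h.mapBatcher)                              // Mapper -> Batcher (reads online status)
h.mapBatcher.Start()                                       // spawn the doWork goroutine
defer h.mapBatcher.Close()                                 // stop it when Serve returns
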
hscontrol/mapper/batcher.go Normal file (279 lines)

@@ -0,0 +1,279 @@
package mapper
import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)
func init() {
deadlock.Opts.Disable = !debugDeadlock
if debugDeadlock {
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
deadlock.Opts.PrintAllCurrentGoroutines = true
}
}
type ChangeWork struct {
NodeID *types.NodeID
Update types.StateUpdate
}
type nodeConn struct {
c chan<- []byte
compress string
version tailcfg.CapabilityVersion
}
type Batcher struct {
mu deadlock.RWMutex
mapper *Mapper
// connected is a map of NodeID to the time the node last closed a
// connection. This is used to track which nodes are currently connected.
// If the value is nil, the node is connected.
// If the value is not nil, the node is disconnected.
connected map[types.NodeID]*time.Time
// nodes is a map of NodeID to a channel that is used to send generated
// mapResp to a client.
nodes map[types.NodeID]nodeConn
// TODO: we will probably have more workers, but for now,
// this should serve for the experiment.
cancelCh chan struct{}
workCh chan *ChangeWork
}
func NewBatcher(mapper *Mapper) *Batcher {
return &Batcher{
mapper: mapper,
cancelCh: make(chan struct{}),
// TODO: No limit for now, this needs to be changed
workCh: make(chan *ChangeWork, (1<<16)-1),
nodes: make(map[types.NodeID]nodeConn),
connected: make(map[types.NodeID]*time.Time),
}
}
func (b *Batcher) Close() {
b.cancelCh <- struct{}{}
}
func (b *Batcher) Start() {
go b.doWork()
}
func (b *Batcher) AddNode(id types.NodeID, c chan<- []byte, compress string, version tailcfg.CapabilityVersion) {
b.mu.Lock()
defer b.mu.Unlock()
// If a channel exists, it means the node has opened a new
// connection. Close the old channel and replace it.
if curr, ok := b.nodes[id]; ok {
// Close the old channel in a goroutine to avoid deadlocks
// if/when someone is waiting to send on this channel.
go func(nc nodeConn) {
close(nc.c)
}(curr)
}
b.nodes[id] = nodeConn{
c: c,
compress: compress,
version: version,
}
b.connected[id] = nil // nil means connected
b.AddWork(&ChangeWork{
NodeID: &id,
Update: types.UpdateFull(),
})
}
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- []byte) bool {
b.mu.Lock()
defer b.mu.Unlock()
if curr, ok := b.nodes[id]; ok {
if curr.c != c {
return false
}
}
delete(b.nodes, id)
b.connected[id] = ptr.To(time.Now())
return true
}
func (b *Batcher) AddWork(work *ChangeWork) {
log.Trace().Msgf("adding work: %v", work.Update)
b.workCh <- work
}
func (b *Batcher) IsConnected(id types.NodeID) bool {
b.mu.RLock()
defer b.mu.RUnlock()
// If the value is nil, it means the node is connected
if b.connected[id] == nil {
return true
}
// If the value is not nil, it means the node is disconnected
return false
}
func (b *Batcher) IsLikelyConnected(id types.NodeID) bool {
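// NOTE: deliberately lockless. This can be called from paths that already
// hold b.mu for reading (processWork -> resp -> mapper -> IsLikelyConnected),
// so taking the read lock here could deadlock against a waiting writer;
// the result is best-effort either way.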
return b.isLikelyConnectedLocked(id)
}
func (b *Batcher) isLikelyConnectedLocked(id types.NodeID) bool {
// If the value is nil, it means the node is connected
if b.connected[id] == nil {
return true
}
// If the value is not nil, the node has disconnected, but we still
// report it as likely connected if the disconnect was recent
// (within 10 seconds).
if time.Since(*b.connected[id]) < 10*time.Second {
return true
}
return false
}
func (b *Batcher) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
b.mu.RLock()
defer b.mu.RUnlock()
ret := xsync.NewMapOf[types.NodeID, bool]()
for id := range b.connected {
ret.Store(id, b.isLikelyConnectedLocked(id))
}
return ret
}
func (b *Batcher) doWork() {
for {
select {
case <-b.cancelCh:
return
case work := <-b.workCh:
b.processWork(work)
}
}
}
// processWork is the current bottleneck where all the updates get picked up
// one by one and processed. This will have to change; it needs to go as fast as
// possible and just pass the work on to the nodes. Currently it won't block,
// because the work channel is very large, but it might not be able to keep up.
// One alternative is a worker per node, but that would mean a lot of
// goroutines hanging around. Another is a worker pool that picks up work,
// processes it, and passes it on to the nodes; that might be complicated by
// ordering requirements. (A worker-per-node variant is sketched after this file.)
func (b *Batcher) processWork(work *ChangeWork) {
b.mu.RLock()
defer b.mu.RUnlock()
log.Trace().Msgf("processing work: %v", work)
if work.NodeID != nil {
id := *work.NodeID
node, ok := b.nodes[id]
if !ok {
log.Trace().Msgf("node %d not found in batcher, skipping work: %v", id, work.Update)
return
}
resp, err := b.resp(id, &node, work)
if err != nil {
log.Debug().Msgf("creating mapResp for %d: %s", id, err)
return
}
node.c <- resp
return
}
for id, node := range b.nodes {
resp, err := b.resp(id, &node, work)
if err != nil {
log.Debug().Msgf("creating mapResp for %d: %s", id, err)
continue
}
node.c <- resp
}
}
// resp is the logic that used to reside in the poller, but is now moved
// to process before sending to the node. The idea is that we do not want to
// be blocked on the send channel to the individual node, but rather
// process all the work and then send the responses to the nodes.
// TODO(kradalby): This is a temporary solution, as we explore this
// approach, we will likely need to refactor this further.
func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte, error) {
var data []byte
var err error
// TODO(kradalby): This should not be necessary; the mapper only
// uses compress and version, and these can either be moved out
// or passed directly. The MapRequest isn't needed.
req := tailcfg.MapRequest{
Compress: nc.compress,
Version: nc.version,
}
// TODO(kradalby): We don't want to use the db here. We should
// just have the node available, or at least quickly accessible
// from the new fancy mem state we want.
node, err := b.mapper.db.GetNodeByID(id)
if err != nil {
return nil, err
}
switch work.Update.Type {
case types.StateFullUpdate:
data, err = b.mapper.FullMapResponse(req, node)
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
for _, nodeID := range work.Update.ChangeNodes {
changed[nodeID] = true
}
data, err = b.mapper.PeerChangedResponse(req, node, changed, work.Update.ChangePatches)
case types.StatePeerChangedPatch:
data, err = b.mapper.PeerChangedPatchResponse(req, node, work.Update.ChangePatches)
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(work.Update.Removed))
for _, nodeID := range work.Update.Removed {
changed[nodeID] = false
}
data, err = b.mapper.PeerChangedResponse(req, node, changed, work.Update.ChangePatches)
case types.StateSelfUpdate:
data, err = b.mapper.PeerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
// case types.StateDERPUpdated:
// data, err = b.mapper.DERPMapResponse(req, node, b.mapper.DERPMap)
}
return data, err
}
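
To make the intended call pattern concrete, here is a hypothetical poll-session sketch against the API above. nodeID, capVer, and w are stand-ins for session state, and "zstd" is assumed as the MapRequest compression value; none of this appears verbatim in the commit.

ch := make(chan []byte, 30)
b.AddNode(nodeID, ch, "zstd", capVer) // registers the conn and queues a full map
defer b.RemoveNode(nodeID, ch)

// Targeted work: only this node gets a rendered response.
b.AddWork(&ChangeWork{
	NodeID: &nodeID,
	Update: types.StateUpdate{
		Type:        types.StatePeerChanged,
		ChangeNodes: []types.NodeID{2},
	},
})

// Broadcast work: a nil NodeID fans out to every tracked node.
b.AddWork(&ChangeWork{Update: types.UpdateFull()})

for data := range ch {
	w.Write(data) // pre-rendered MapResponse bytes, ready to stream
}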

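The processWork comment above weighs a worker per node against a shared pool. For reference, a worker-per-node variant is a small change; this sketch is an assumption, not part of the commit. Each node gets its own queue, so a slow node cannot stall the shared channel and per-node ordering is preserved, at the cost of one goroutine (and an arbitrarily sized buffer) per connection.

type nodeWorker struct {
	work chan *ChangeWork
}

func (b *Batcher) startNodeWorker(id types.NodeID, nc nodeConn) *nodeWorker {
	w := &nodeWorker{work: make(chan *ChangeWork, 256)}
	go func() {
		for cw := range w.work {
			resp, err := b.resp(id, &nc, cw)
			if err != nil {
				log.Debug().Msgf("creating mapResp for %d: %s", id, err)
				continue
			}
			nc.c <- resp
		}
	}()
	return w
}
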
hscontrol/mapper/mapper.go

@@ -17,7 +17,6 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
@@ -56,9 +55,9 @@ type Mapper struct {
db *db.HSDatabase
cfg *types.Config
derpMap *tailcfg.DERPMap
notif *notifier.Notifier
polMan policy.PolicyManager
primary *routes.PrimaryRoutes
batcher *Batcher
uid string
created time.Time
@@ -74,7 +73,6 @@ func NewMapper(
db *db.HSDatabase,
cfg *types.Config,
derpMap *tailcfg.DERPMap,
notif *notifier.Notifier,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
) *Mapper {
@@ -84,7 +82,6 @@ func NewMapper(
db: db,
cfg: cfg,
derpMap: derpMap,
notif: notif,
polMan: polMan,
primary: primary,
@@ -94,6 +91,10 @@ func NewMapper(
}
}
func (m *Mapper) SetBatcher(batcher *Batcher) {
m.batcher = batcher
}
func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
}
@@ -502,8 +503,10 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
return nil, err
}
// Online status is now sourced from the batcher rather than the notifier,
// avoiding a circular dependency between the mapper and the notifier.
for _, peer := range peers {
online := m.notif.IsLikelyConnected(peer.ID)
online := m.batcher.IsLikelyConnected(peer.ID)
peer.IsOnline = &online
}
@@ -518,8 +521,10 @@ func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return nil, err
}
// Online status is now sourced from the batcher rather than the notifier,
// avoiding a circular dependency between the mapper and the notifier.
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
online := m.batcher.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
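
The comments above, together with SetBatcher, exist because Mapper and Batcher hold references to each other. One hypothetical way to narrow the coupling (not part of this commit) is to have the Mapper depend on a small interface rather than the concrete *Batcher; the two-phase setup remains, but the mapper side only sees the one capability it uses:

// connStater is the only capability the mapper actually needs.
type connStater interface {
	IsLikelyConnected(types.NodeID) bool
}

// The Mapper field would become `conn connStater`; *Batcher satisfies the
// interface unchanged, and ListPeers/ListNodes call m.conn.IsLikelyConnected.
func (m *Mapper) SetConnStater(cs connStater) { m.conn = cs }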

hscontrol/mapper/mapper_test.go

@@ -420,7 +420,6 @@ func Test_fullMapResponse(t *testing.T) {
nil,
tt.cfg,
tt.derpMap,
nil,
polMan,
primary,
)

hscontrol/notifier/notifier.go

@@ -2,12 +2,12 @@ package notifier
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
@@ -31,20 +31,18 @@ func init() {
}
type Notifier struct {
l deadlock.Mutex
nodes map[types.NodeID]chan<- types.StateUpdate
connected *xsync.MapOf[types.NodeID, bool]
b *batcher
cfg *types.Config
closed bool
l deadlock.Mutex
b *batcher
cfg *types.Config
closed bool
mbatcher *mapper.Batcher
}
func NewNotifier(cfg *types.Config) *Notifier {
func NewNotifier(cfg *types.Config, mbatch *mapper.Batcher) *Notifier {
n := &Notifier{
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: xsync.NewMapOf[types.NodeID, bool](),
cfg: cfg,
closed: false,
mbatcher: mbatch,
cfg: cfg,
closed: false,
}
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
n.b = b
@@ -62,14 +60,6 @@ func (n *Notifier) Close() {
n.closed = true
n.b.close()
// Close channels safely using the helper method
for nodeID, c := range n.nodes {
n.safeCloseChannel(nodeID, c)
}
// Clear node map after closing channels
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel, recovering from the panic if it is already closed.
@@ -87,104 +77,24 @@ func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpd
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
log.Trace().
Uint64("node.id", nID.Uint64()).
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
}
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
if n.closed {
return
}
// If a channel exists, it means the node has opened a new
// connection. Close the old channel and replace it.
if curr, ok := n.nodes[nodeID]; ok {
n.tracef(nodeID, "channel present, closing and replacing")
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
// if/when someone is waiting to send on this channel
go func(ch chan<- types.StateUpdate) {
n.safeCloseChannel(nodeID, ch)
}(curr)
}
n.nodes[nodeID] = c
n.connected.Store(nodeID, true)
n.tracef(nodeID, "added new channel")
notifierNodeUpdateChans.Inc()
}
// RemoveNode removes a node and a given channel from the notifier.
// It checks that the channel is the same as currently being updated
// and ignores the removal if it is not.
// RemoveNode reports if the node/chan was removed.
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
if n.closed {
return true
}
if len(n.nodes) == 0 {
return true
}
// If the channel exist, but it does not belong
// to the caller, ignore.
if curr, ok := n.nodes[nodeID]; ok {
if curr != c {
n.tracef(nodeID, "channel has been replaced, not removing")
return false
}
}
delete(n.nodes, nodeID)
n.connected.Store(nodeID, false)
n.tracef(nodeID, "removed channel")
notifierNodeUpdateChans.Dec()
return true
Uint64("node.id", nID.Uint64()).Msgf(msg, args...)
}
// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
return n.mbatcher.IsConnected(nodeID)
}
// IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesn't lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
return n.mbatcher.IsLikelyConnected(nodeID)
}
// LikelyConnectedMap returns a thread safe map of connected nodes
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
return n.mbatcher.LikelyConnectedMap()
}
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
@@ -209,87 +119,16 @@ func (n *Notifier) NotifyByNodeID(
update types.StateUpdate,
nodeID types.NodeID,
) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
if n.closed {
return
}
if c, ok := n.nodes[nodeID]; ok {
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", types.NotifyOriginKey.Value(ctx)).
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
Msgf("update not sent, context cancelled")
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
}
return
case c <- update:
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
}
}
}
n.mbatcher.AddWork(&mapper.ChangeWork{
NodeID: &nodeID,
Update: update,
})
}
func (n *Notifier) sendAll(update types.StateUpdate) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
if n.closed {
return
}
for id, c := range n.nodes {
// Whenever an update is sent to all nodes, there is a chance that the node
// has disconnected and the goroutine that was supposed to consume the update
// has shut down the channel and is waiting for the lock held here in RemoveNode.
// This means that there is potential for a deadlock which would stop all updates
// going out to clients. This timeout prevents that from happening by moving on to the
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
// call will succeed and the update will go to the correct nodes on the next call.
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
defer cancel()
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", id.Uint64()).
Msgf("update not sent, context cancelled")
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
}
return
case c <- update:
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
}
}
}
n.mbatcher.AddWork(&mapper.ChangeWork{
Update: update,
})
}
func (n *Notifier) String() string {
@@ -299,29 +138,12 @@ func (n *Notifier) String() string {
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
var b strings.Builder
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
var keys []types.NodeID
n.connected.Range(func(key types.NodeID, value bool) bool {
keys = append(keys, key)
return true
})
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
for _, key := range keys {
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
}
b.WriteString("\n")
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
for _, key := range keys {
val, _ := n.connected.Load(key)
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
}
return b.String()
}
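
After this change the notifier is a thin shim: the lock-and-channel machinery, send timeouts, and per-update metrics are gone, and both notify paths reduce to enqueueing work on the batcher. A sketch of the equivalence, using only names from this diff:

// NotifyByNodeID and sendAll both reduce to AddWork; the batcher's doWork
// loop renders the MapResponse and writes the bytes to the node's channel.
n.mbatcher.AddWork(&mapper.ChangeWork{NodeID: &nodeID, Update: update}) // one node
n.mbatcher.AddWork(&mapper.ChangeWork{Update: update})                  // all nodes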

hscontrol/notifier/notifier_test.go

@@ -1,343 +1,343 @@
package notifier
import (
"context"
"fmt"
"math/rand"
"net/netip"
"sort"
"sync"
"testing"
"time"
// import (
// "context"
// "fmt"
// "math/rand"
// "net/netip"
// "sort"
// "sync"
// "testing"
// "time"
"slices"
// "slices"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
// "github.com/google/go-cmp/cmp"
// "github.com/juanfont/headscale/hscontrol/types"
// "github.com/juanfont/headscale/hscontrol/util"
// "tailscale.com/tailcfg"
// )
func TestBatcher(t *testing.T) {
tests := []struct {
name string
updates []types.StateUpdate
want []types.StateUpdate
}{
{
name: "full-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
want: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
},
{
name: "derp-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
want: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
},
{
name: "single-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
},
{
name: "merge-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 4,
},
},
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3, 4,
},
},
},
},
{
name: "single-patch-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
},
{
name: "merge-patch-to-same-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
},
{
name: "merge-patch-to-multiple-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
DERPRegion: 6,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
{
NodeID: 4,
DERPRegion: 6,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
},
}
// func TestBatcher(t *testing.T) {
// tests := []struct {
// name string
// updates []types.StateUpdate
// want []types.StateUpdate
// }{
// {
// name: "full-passthrough",
// updates: []types.StateUpdate{
// {
// Type: types.StateFullUpdate,
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StateFullUpdate,
// },
// },
// },
// {
// name: "derp-passthrough",
// updates: []types.StateUpdate{
// {
// Type: types.StateDERPUpdated,
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StateDERPUpdated,
// },
// },
// },
// {
// name: "single-node-update",
// updates: []types.StateUpdate{
// {
// Type: types.StatePeerChanged,
// ChangeNodes: []types.NodeID{
// 2,
// },
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StatePeerChanged,
// ChangeNodes: []types.NodeID{
// 2,
// },
// },
// },
// },
// {
// name: "merge-node-update",
// updates: []types.StateUpdate{
// {
// Type: types.StatePeerChanged,
// ChangeNodes: []types.NodeID{
// 2, 4,
// },
// },
// {
// Type: types.StatePeerChanged,
// ChangeNodes: []types.NodeID{
// 2, 3,
// },
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StatePeerChanged,
// ChangeNodes: []types.NodeID{
// 2, 3, 4,
// },
// },
// },
// },
// {
// name: "single-patch-update",
// updates: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 2,
// DERPRegion: 5,
// },
// },
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 2,
// DERPRegion: 5,
// },
// },
// },
// },
// },
// {
// name: "merge-patch-to-same-node-update",
// updates: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 2,
// DERPRegion: 5,
// },
// },
// },
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 2,
// DERPRegion: 6,
// },
// },
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 2,
// DERPRegion: 6,
// },
// },
// },
// },
// },
// {
// name: "merge-patch-to-multiple-node-update",
// updates: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 3,
// Endpoints: []netip.AddrPort{
// netip.MustParseAddrPort("1.1.1.1:9090"),
// },
// },
// },
// },
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 3,
// Endpoints: []netip.AddrPort{
// netip.MustParseAddrPort("1.1.1.1:9090"),
// netip.MustParseAddrPort("2.2.2.2:8080"),
// },
// },
// },
// },
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 4,
// DERPRegion: 6,
// },
// },
// },
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 4,
// Cap: tailcfg.CapabilityVersion(54),
// },
// },
// },
// },
// want: []types.StateUpdate{
// {
// Type: types.StatePeerChangedPatch,
// ChangePatches: []*tailcfg.PeerChange{
// {
// NodeID: 3,
// Endpoints: []netip.AddrPort{
// netip.MustParseAddrPort("1.1.1.1:9090"),
// netip.MustParseAddrPort("2.2.2.2:8080"),
// },
// },
// {
// NodeID: 4,
// DERPRegion: 6,
// Cap: tailcfg.CapabilityVersion(54),
// },
// },
// },
// },
// },
// }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := NewNotifier(&types.Config{
Tuning: types.Tuning{
// We will call flush manually for the tests,
// so do not run the worker.
BatchChangeDelay: time.Hour,
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// n := NewNotifier(&types.Config{
// Tuning: types.Tuning{
// // We will call flush manually for the tests,
// // so do not run the worker.
// BatchChangeDelay: time.Hour,
// Since we do not load the config, we won't get the
// default, so set it manually so we dont time out
// and have flakes.
NotifierSendTimeout: time.Second,
},
})
// // Since we do not load the config, we won't get the
// // default, so set it manually so we dont time out
// // and have flakes.
// NotifierSendTimeout: time.Second,
// },
// })
ch := make(chan types.StateUpdate, 30)
defer close(ch)
n.AddNode(1, ch)
defer n.RemoveNode(1, ch)
// ch := make(chan types.StateUpdate, 30)
// defer close(ch)
// n.AddNode(1, ch)
// defer n.RemoveNode(1, ch)
for _, u := range tt.updates {
n.NotifyAll(context.Background(), u)
}
// for _, u := range tt.updates {
// n.NotifyAll(context.Background(), u)
// }
n.b.flush()
// n.b.flush()
var got []types.StateUpdate
for len(ch) > 0 {
out := <-ch
got = append(got, out)
}
// var got []types.StateUpdate
// for len(ch) > 0 {
// out := <-ch
// got = append(got, out)
// }
// Make the inner order stable for comparison.
for _, u := range got {
slices.Sort(u.ChangeNodes)
sort.Slice(u.ChangePatches, func(i, j int) bool {
return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
})
}
// // Make the inner order stable for comparison.
// for _, u := range got {
// slices.Sort(u.ChangeNodes)
// sort.Slice(u.ChangePatches, func(i, j int) bool {
// return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
// })
// }
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
}
})
}
}
// if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
// t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
Tuning: types.Tuning{
NotifierSendTimeout: 1 * time.Second,
BatchChangeDelay: 1 * time.Second,
NodeMapSessionBufferedChanSize: 30,
},
}
// // TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// // Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// // close a channel that was already closed, which can happen when a node changes
// // network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
// func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// // mock config for the notifier
// cfg := &types.Config{
// Tuning: types.Tuning{
// NotifierSendTimeout: 1 * time.Second,
// BatchChangeDelay: 1 * time.Second,
// NodeMapSessionBufferedChanSize: 30,
// },
// }
notifier := NewNotifier(cfg)
defer notifier.Close()
// notifier := NewNotifier(cfg)
// defer notifier.Close()
nodeID := types.NodeID(1)
updateChan := make(chan types.StateUpdate, 10)
// nodeID := types.NodeID(1)
// updateChan := make(chan types.StateUpdate, 10)
var wg sync.WaitGroup
// var wg sync.WaitGroup
// Number of goroutines to spawn for concurrent access
concurrentAccessors := 100
iterations := 100
// // Number of goroutines to spawn for concurrent access
// concurrentAccessors := 100
// iterations := 100
// Add node to notifier
notifier.AddNode(nodeID, updateChan)
// // Add node to notifier
// notifier.AddNode(nodeID, updateChan)
// Track errors
errChan := make(chan string, concurrentAccessors*iterations)
// // Track errors
// errChan := make(chan string, concurrentAccessors*iterations)
// Start goroutines to cause a race
wg.Add(concurrentAccessors)
for i := range concurrentAccessors {
go func(routineID int) {
defer wg.Done()
// // Start goroutines to cause a race
// wg.Add(concurrentAccessors)
// for i := range concurrentAccessors {
// go func(routineID int) {
// defer wg.Done()
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}
// for range iterations {
// // Simulate race by having some goroutines check IsLikelyConnected
// // while others add/remove the node
// if routineID%3 == 0 {
// // This goroutine checks connection status
// isConnected := notifier.IsLikelyConnected(nodeID)
// if isConnected != true && isConnected != false {
// errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
// }
// } else if routineID%3 == 1 {
// // This goroutine removes the node
// notifier.RemoveNode(nodeID, updateChan)
// } else {
// // This goroutine adds the node back
// notifier.AddNode(nodeID, updateChan)
// }
// Small random delay to increase chance of races
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
}
}(i)
}
// // Small random delay to increase chance of races
// time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
// }
// }(i)
// }
wg.Wait()
close(errChan)
// wg.Wait()
// close(errChan)
// Collate errors
var errors []string
for err := range errChan {
errors = append(errors, err)
}
// // Collate errors
// var errors []string
// for err := range errChan {
// errors = append(errors, err)
// }
if len(errors) > 0 {
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
}
}
// if len(errors) > 0 {
// t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
// }
// }

hscontrol/poll.go

@@ -2,11 +2,9 @@ package hscontrol
import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"net/netip"
"slices"
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
@@ -36,7 +34,7 @@ type mapSession struct {
cancelChMu deadlock.Mutex
ch chan types.StateUpdate
ch chan []byte
cancelCh chan struct{}
cancelChOpen bool
@@ -60,14 +58,16 @@
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
var updateChan chan types.StateUpdate
if req.Stream {
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.UpdateFull()
}
// TODO(kradalby): This needs to happen in the batcher now, give a full
// map on start.
// var updateChan chan []byte
// if req.Stream {
// // Use a buffered channel in case a node is not fully ready
// // to receive a message to make sure we dont block the entire
// // notifier.
// updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
// updateChan <- types.UpdateFull()
// }
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
@@ -78,9 +78,8 @@ func (h *Headscale) newMapSession(
w: w,
node: node,
capVer: req.Version,
mapper: h.mapper,
ch: updateChan,
ch: make(chan []byte, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
cancelCh: make(chan struct{}),
cancelChOpen: true,
@@ -200,7 +199,7 @@ func (m *mapSession) serveLongPoll() {
// in principle, it will be removed, but if the client rapidly
// reconnects, the channel might belong to another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
// Failover the node's routes if any.
m.h.updateNodeOnlineStatus(false, m.node)
@@ -239,7 +238,7 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID, m.ch)
m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.req.Compress, m.req.Version)
go m.h.updateNodeOnlineStatus(true, m.node)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@@ -261,133 +260,56 @@ func (m *mapSession) serveLongPoll() {
// Consume updates sent to node
case update, ok := <-m.ch:
m.tracef("received update from channel, ok: %t, len: %d", ok, len(update))
if !ok {
m.tracef("update channel closed, streaming session is likely being replaced")
return
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID) {
m.tracef("node removed, closing stream")
return
}
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
var data []byte
var err error
var lastMessage string
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.db.GetNodeByID(m.node.ID)
startWrite := time.Now()
_, err := m.w.Write(update)
if err != nil {
m.errf(err, "Could not get machine from db")
m.errf(err, "could not write the map response, for mapSession: %p", m)
return
}
updateType := "full"
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
for _, nodeID := range update.ChangeNodes {
changed[nodeID] = true
}
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
for _, nodeID := range update.Removed {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
updateType = "derp"
}
if err != nil {
m.errf(err, "Could not get the create map update")
return
}
// Only send update if there is change
if data != nil {
startWrite := time.Now()
_, err = m.w.Write(data)
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
return
}
err = rc.Flush()
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
m.resetKeepAlive()
}
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
_, err = m.w.Write(data)
if err != nil {
m.errf(err, "Cannot write keep alive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
err = rc.Flush()
if err != nil {
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
m.tracef("update sent")
m.resetKeepAlive()
// TODO(kradalby): This needs to be rethought now that we do not have a mapper;
// maybe keepalive can be a static function? I do not think it needs state.
// case <-m.keepAliveTicker.C:
// data, err := m.mapper.KeepAliveResponse(m.req, m.node)
// if err != nil {
// m.errf(err, "Error generating the keep alive msg")
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
// _, err = m.w.Write(data)
// if err != nil {
// m.errf(err, "Cannot write keep alive message")
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
// err = rc.Flush()
// if err != nil {
// m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
// if debugHighCardinalityMetrics {
// mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
// }
// mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
}
}
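
The commented-out keepalive block above asks whether keepalive can become a static function. It plausibly can: a keepalive is a MapResponse with only KeepAlive set, and the framing used elsewhere in the codebase is a 4-byte little-endian length prefix over optionally zstd-compressed JSON. A hypothetical sketch under those assumptions (keepAliveBytes is not part of this commit; a production version would reuse the encoder):

import (
	"encoding/binary"
	"encoding/json"

	"github.com/klauspost/compress/zstd"
	"tailscale.com/tailcfg"
)

// keepAliveBytes renders a stateless keepalive frame: JSON-encode an empty
// keepalive MapResponse, optionally zstd-compress it, then length-prefix it.
func keepAliveBytes(compress string) ([]byte, error) {
	body, err := json.Marshal(&tailcfg.MapResponse{KeepAlive: true})
	if err != nil {
		return nil, err
	}
	if compress == "zstd" {
		enc, err := zstd.NewWriter(nil)
		if err != nil {
			return nil, err
		}
		body = enc.EncodeAll(body, nil)
		enc.Close()
	}
	out := make([]byte, 4, 4+len(body))
	binary.LittleEndian.PutUint32(out, uint32(len(body)))
	return append(out, body...), nil
}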

hscontrol/types/common.go

@@ -196,3 +196,23 @@ type RegisterNode struct {
Node Node
Registered chan *Node
}
type Change struct {
NodeChange NodeChange
UserChange UserChange
}
type NodeChangeWhat string
const (
NodeChangeCameOnline NodeChangeWhat = "node-online"
)
type NodeChange struct {
ID NodeID
What NodeChangeWhat
}
type UserChange struct {
ID UserID
}
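
Nothing in this commit consumes the new Change types yet; they read as groundwork for the batcher experiment. A minimal illustration of the intended shape (hypothetical usage, same package):

change := Change{
	NodeChange: NodeChange{
		ID:   NodeID(1),
		What: NodeChangeCameOnline,
	},
}
_ = change // presumably fed into the batcher as future ChangeWork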