From 0aa2c52718534345caec722cee76a9edea9eb512 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 10 Sep 2025 15:34:16 +0200 Subject: [PATCH] app: fix sigint hanging When the node notifier was replaced with batcher, we removed its closing, but forgot to add the batchers so it was never stopping node connections and waiting forever. Fixes #2751 Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 21 ++++++++-------- hscontrol/mapper/batcher_lockfree.go | 37 +++++++++++++++++++++++----- hscontrol/mapper/batcher_test.go | 18 ++++++++++---- hscontrol/poll.go | 4 +-- 4 files changed, 57 insertions(+), 23 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 6f669d4a..885066a0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -100,7 +100,7 @@ type Headscale struct { authProvider AuthProvider mapBatcher mapper.Batcher - pollNetMapStreamWG sync.WaitGroup + clientStreamsOpen sync.WaitGroup } var ( @@ -129,10 +129,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } app := Headscale{ - cfg: cfg, - noisePrivateKey: noisePrivateKey, - pollNetMapStreamWG: sync.WaitGroup{}, - state: s, + cfg: cfg, + noisePrivateKey: noisePrivateKey, + clientStreamsOpen: sync.WaitGroup{}, + state: s, } // Initialize ephemeral garbage collector @@ -813,10 +813,11 @@ func (h *Headscale) Serve() error { log.Error().Err(err).Msg("failed to shutdown http") } - info("closing node notifier") + info("closing batcher") + h.mapBatcher.Close() info("waiting for netmap stream to close") - h.pollNetMapStreamWG.Wait() + h.clientStreamsOpen.Wait() info("shutting down grpc server (socket)") grpcSocket.GracefulStop() @@ -842,11 +843,11 @@ func (h *Headscale) Serve() error { info("closing socket listener") socketListener.Close() - // Close db connections - info("closing database connection") + // Close state connections + info("closing state and database") err = h.state.Close() if err != nil { - log.Error().Err(err).Msg("failed to close db") + log.Error().Err(err).Msg("failed to close state") } log.Info(). diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index aaa58f2f..b403fd14 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -153,7 +153,7 @@ func (b *LockFreeBatcher) Start() { func (b *LockFreeBatcher) Close() { if b.cancel != nil { b.cancel() - b.cancel = nil // Prevent multiple calls + b.cancel = nil } // Only close workCh once @@ -163,10 +163,15 @@ func (b *LockFreeBatcher) Close() { default: close(b.workCh) } + + // Close the underlying channels supplying the data to the clients. + b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { + conn.close() + return true + }) } func (b *LockFreeBatcher) doWork() { - for i := range b.workers { go b.worker(i + 1) } @@ -184,17 +189,18 @@ func (b *LockFreeBatcher) doWork() { // Clean up nodes that have been offline for too long b.cleanupOfflineNodes() case <-b.ctx.Done(): + log.Info().Msg("batcher context done, stopping to feed workers") return } } } func (b *LockFreeBatcher) worker(workerID int) { - for { select { case w, ok := <-b.workCh: if !ok { + log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID) return } @@ -213,7 +219,7 @@ func (b *LockFreeBatcher) worker(workerID int) { if result.err != nil { b.workErrors.Add(1) log.Error().Err(result.err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.nodeID.Uint64()). Str("change", w.c.Change.String()). Msg("failed to generate map response for synchronous work") @@ -223,7 +229,7 @@ func (b *LockFreeBatcher) worker(workerID int) { b.workErrors.Add(1) log.Error().Err(result.err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.nodeID.Uint64()). Msg("node not found for synchronous work") } @@ -248,13 +254,14 @@ func (b *LockFreeBatcher) worker(workerID int) { if err != nil { b.workErrors.Add(1) log.Error().Err(err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.c.NodeID.Uint64()). Str("change", w.c.Change.String()). Msg("failed to apply change") } } case <-b.ctx.Done(): + log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker") return } } @@ -336,6 +343,7 @@ func (b *LockFreeBatcher) processBatchedChanges() { } // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. +// TODO(kradalby): reevaluate if we want to keep this. func (b *LockFreeBatcher) cleanupOfflineNodes() { cleanupThreshold := 15 * time.Minute now := time.Now() @@ -477,6 +485,15 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC } } +func (mc *multiChannelNodeConn) close() { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for _, conn := range mc.connections { + close(conn.c) + } +} + // addConnection adds a new connection. func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { mutexWaitStart := time.Now() @@ -530,6 +547,10 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int { // send broadcasts data to all active connections for the node. func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + mc.mutex.Lock() defer mc.mutex.Unlock() @@ -597,6 +618,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // send sends data to a single connection entry with timeout-based stale connection detection. func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + // Use a short timeout to detect stale connections where the client isn't reading the channel. // This is critical for detecting Docker containers that are forcefully terminated // but still have channels that appear open. diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index efc96f98..74277c6c 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -1361,7 +1361,11 @@ func TestBatcherConcurrentClients(t *testing.T) { go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { for { select { - case data := <-channel: + case data, ok := <-channel: + if !ok { + // Channel was closed, exit gracefully + return + } if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1419,24 +1423,28 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() - churningChannels[nodeID] = ch - churningChannelsMutex.Unlock() + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { for { select { - case data := <-ch: + case data, ok := <-ch: + if !ok { + // Channel was closed, exit gracefully + return + } if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, 1, ) // Use 1 as update size since we have MapResponse } - case <-time.After(20 * time.Millisecond): + case <-time.After(500 * time.Millisecond): + // Longer timeout to prevent premature exit during heavy load return } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index cfe89b1a..c0d6e6b3 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -186,8 +186,8 @@ func (m *mapSession) serveLongPoll() { }() // Set up the client stream - m.h.pollNetMapStreamWG.Add(1) - defer m.h.pollNetMapStreamWG.Done() + m.h.clientStreamsOpen.Add(1) + defer m.h.clientStreamsOpen.Done() ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel()