diff --git a/hscontrol/app.go b/hscontrol/app.go index 6f669d4a..885066a0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -100,7 +100,7 @@ type Headscale struct { authProvider AuthProvider mapBatcher mapper.Batcher - pollNetMapStreamWG sync.WaitGroup + clientStreamsOpen sync.WaitGroup } var ( @@ -129,10 +129,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } app := Headscale{ - cfg: cfg, - noisePrivateKey: noisePrivateKey, - pollNetMapStreamWG: sync.WaitGroup{}, - state: s, + cfg: cfg, + noisePrivateKey: noisePrivateKey, + clientStreamsOpen: sync.WaitGroup{}, + state: s, } // Initialize ephemeral garbage collector @@ -813,10 +813,11 @@ func (h *Headscale) Serve() error { log.Error().Err(err).Msg("failed to shutdown http") } - info("closing node notifier") + info("closing batcher") + h.mapBatcher.Close() info("waiting for netmap stream to close") - h.pollNetMapStreamWG.Wait() + h.clientStreamsOpen.Wait() info("shutting down grpc server (socket)") grpcSocket.GracefulStop() @@ -842,11 +843,11 @@ func (h *Headscale) Serve() error { info("closing socket listener") socketListener.Close() - // Close db connections - info("closing database connection") + // Close state connections + info("closing state and database") err = h.state.Close() if err != nil { - log.Error().Err(err).Msg("failed to close db") + log.Error().Err(err).Msg("failed to close state") } log.Info(). diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index aaa58f2f..b403fd14 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -153,7 +153,7 @@ func (b *LockFreeBatcher) Start() { func (b *LockFreeBatcher) Close() { if b.cancel != nil { b.cancel() - b.cancel = nil // Prevent multiple calls + b.cancel = nil } // Only close workCh once @@ -163,10 +163,15 @@ func (b *LockFreeBatcher) Close() { default: close(b.workCh) } + + // Close the underlying channels supplying the data to the clients. 
+ b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { + conn.close() + return true + }) } func (b *LockFreeBatcher) doWork() { - for i := range b.workers { go b.worker(i + 1) } @@ -184,17 +189,18 @@ func (b *LockFreeBatcher) doWork() { // Clean up nodes that have been offline for too long b.cleanupOfflineNodes() case <-b.ctx.Done(): + log.Info().Msg("batcher context done, stopping to feed workers") return } } } func (b *LockFreeBatcher) worker(workerID int) { - for { select { case w, ok := <-b.workCh: if !ok { + log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID) return } @@ -213,7 +219,7 @@ func (b *LockFreeBatcher) worker(workerID int) { if result.err != nil { b.workErrors.Add(1) log.Error().Err(result.err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.nodeID.Uint64()). Str("change", w.c.Change.String()). Msg("failed to generate map response for synchronous work") @@ -223,7 +229,7 @@ func (b *LockFreeBatcher) worker(workerID int) { b.workErrors.Add(1) log.Error().Err(result.err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.nodeID.Uint64()). Msg("node not found for synchronous work") } @@ -248,13 +254,14 @@ func (b *LockFreeBatcher) worker(workerID int) { if err != nil { b.workErrors.Add(1) log.Error().Err(err). - Int("workerID", workerID). + Int("worker.id", workerID). Uint64("node.id", w.c.NodeID.Uint64()). Str("change", w.c.Change.String()). Msg("failed to apply change") } } case <-b.ctx.Done(): + log.Debug().Int("worker.id", workerID).Msg("batcher context is done, exiting worker") return } } @@ -336,6 +343,7 @@ func (b *LockFreeBatcher) processBatchedChanges() { } // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. +// TODO(kradalby): reevaluate if we want to keep this. 
func (b *LockFreeBatcher) cleanupOfflineNodes() { cleanupThreshold := 15 * time.Minute now := time.Now() @@ -477,6 +485,15 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC } } +func (mc *multiChannelNodeConn) close() { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for _, conn := range mc.connections { + close(conn.c) + } +} + // addConnection adds a new connection. func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { mutexWaitStart := time.Now() @@ -530,6 +547,10 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int { // send broadcasts data to all active connections for the node. func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + mc.mutex.Lock() defer mc.mutex.Unlock() @@ -597,6 +618,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // send sends data to a single connection entry with timeout-based stale connection detection. func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + // Use a short timeout to detect stale connections where the client isn't reading the channel. // This is critical for detecting Docker containers that are forcefully terminated // but still have channels that appear open. 
diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index efc96f98..74277c6c 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -1361,7 +1361,11 @@ func TestBatcherConcurrentClients(t *testing.T) { go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { for { select { - case data := <-channel: + case data, ok := <-channel: + if !ok { + // Channel was closed, exit gracefully + return + } if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1419,24 +1423,28 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() - churningChannels[nodeID] = ch - churningChannelsMutex.Unlock() + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { for { select { - case data := <-ch: + case data, ok := <-ch: + if !ok { + // Channel was closed, exit gracefully + return + } if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, 1, ) // Use 1 as update size since we have MapResponse } - case <-time.After(20 * time.Millisecond): + case <-time.After(500 * time.Millisecond): + // Longer timeout to prevent premature exit during heavy load return } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index cfe89b1a..c0d6e6b3 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -186,8 +186,8 @@ func (m *mapSession) serveLongPoll() { }() // Set up the client stream - m.h.pollNetMapStreamWG.Add(1) - defer m.h.pollNetMapStreamWG.Done() + m.h.clientStreamsOpen.Add(1) + defer m.h.clientStreamsOpen.Done() ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel()