1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-16 17:50:44 +02:00

app: fix sigint hanging

When the node notifier was replaced with batcher, we removed
its closing, but forgot to add the batchers so it was never
stopping node connections and waiting forever.

Fixes #2751

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-09-10 15:34:16 +02:00 committed by Kristoffer Dalby
parent 01c1f6f82a
commit d41fb4d540
4 changed files with 57 additions and 23 deletions

View File

@ -100,7 +100,7 @@ type Headscale struct {
authProvider AuthProvider
mapBatcher mapper.Batcher
pollNetMapStreamWG sync.WaitGroup
clientStreamsOpen sync.WaitGroup
}
var (
@ -129,10 +129,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
app := Headscale{
cfg: cfg,
noisePrivateKey: noisePrivateKey,
pollNetMapStreamWG: sync.WaitGroup{},
state: s,
cfg: cfg,
noisePrivateKey: noisePrivateKey,
clientStreamsOpen: sync.WaitGroup{},
state: s,
}
// Initialize ephemeral garbage collector
@ -813,10 +813,11 @@ func (h *Headscale) Serve() error {
log.Error().Err(err).Msg("failed to shutdown http")
}
info("closing node notifier")
info("closing batcher")
h.mapBatcher.Close()
info("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
h.clientStreamsOpen.Wait()
info("shutting down grpc server (socket)")
grpcSocket.GracefulStop()
@ -842,11 +843,11 @@ func (h *Headscale) Serve() error {
info("closing socket listener")
socketListener.Close()
// Close db connections
info("closing database connection")
// Close state connections
info("closing state and database")
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close db")
log.Error().Err(err).Msg("failed to close state")
}
log.Info().

View File

@ -153,7 +153,7 @@ func (b *LockFreeBatcher) Start() {
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
b.cancel = nil // Prevent multiple calls
b.cancel = nil
}
// Only close workCh once
@ -163,10 +163,15 @@ func (b *LockFreeBatcher) Close() {
default:
close(b.workCh)
}
// Close the underlying channels supplying the data to the clients.
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
conn.close()
return true
})
}
func (b *LockFreeBatcher) doWork() {
for i := range b.workers {
go b.worker(i + 1)
}
@ -184,17 +189,18 @@ func (b *LockFreeBatcher) doWork() {
// Clean up nodes that have been offline for too long
b.cleanupOfflineNodes()
case <-b.ctx.Done():
log.Info().Msg("batcher context done, stopping to feed workers")
return
}
}
}
func (b *LockFreeBatcher) worker(workerID int) {
for {
select {
case w, ok := <-b.workCh:
if !ok {
log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID)
return
}
@ -213,7 +219,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Int("worker.id", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to generate map response for synchronous work")
@ -223,7 +229,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Int("worker.id", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Msg("node not found for synchronous work")
}
@ -248,13 +254,14 @@ func (b *LockFreeBatcher) worker(workerID int) {
if err != nil {
b.workErrors.Add(1)
log.Error().Err(err).
Int("workerID", workerID).
Int("worker.id", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to apply change")
}
}
case <-b.ctx.Done():
log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
return
}
}
@ -336,6 +343,7 @@ func (b *LockFreeBatcher) processBatchedChanges() {
}
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
// TODO(kradalby): reevaluate if we want to keep this.
func (b *LockFreeBatcher) cleanupOfflineNodes() {
cleanupThreshold := 15 * time.Minute
now := time.Now()
@ -477,6 +485,15 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
}
}
func (mc *multiChannelNodeConn) close() {
mc.mutex.Lock()
defer mc.mutex.Unlock()
for _, conn := range mc.connections {
close(conn.c)
}
}
// addConnection adds a new connection.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
mutexWaitStart := time.Now()
@ -530,6 +547,10 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
// send broadcasts data to all active connections for the node.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
if data == nil {
return nil
}
mc.mutex.Lock()
defer mc.mutex.Unlock()
@ -597,6 +618,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
// send sends data to a single connection entry with timeout-based stale connection detection.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
if data == nil {
return nil
}
// Use a short timeout to detect stale connections where the client isn't reading the channel.
// This is critical for detecting Docker containers that are forcefully terminated
// but still have channels that appear open.

View File

@ -1361,7 +1361,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
for {
select {
case data := <-channel:
case data, ok := <-channel:
if !ok {
// Channel was closed, exit gracefully
return
}
if valid, reason := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@ -1419,24 +1423,28 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking
go func() {
for {
select {
case data := <-ch:
case data, ok := <-ch:
if !ok {
// Channel was closed, exit gracefully
return
}
if valid, _ := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
1,
) // Use 1 as update size since we have MapResponse
}
case <-time.After(20 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
// Longer timeout to prevent premature exit during heavy load
return
}
}

View File

@ -186,8 +186,8 @@ func (m *mapSession) serveLongPoll() {
}()
// Set up the client stream
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
m.h.clientStreamsOpen.Add(1)
defer m.h.clientStreamsOpen.Done()
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel()