
app: fix sigint hanging

When the node notifier was replaced with the batcher, we removed the
notifier's shutdown call but never added one for the batcher, so node
connections were never closed and shutdown waited for them forever.

Fixes #2751

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-09-10 15:34:16 +02:00
parent 01c1f6f82a
commit 0aa2c52718
4 changed files with 57 additions and 23 deletions
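
In short: every long-poll client stream holds a slot in the clientStreamsOpen WaitGroup and only releases it once its map-response channel is closed, so a shutdown that never closes the batcher's per-node channels blocks forever in Wait(). A stripped-down sketch of that pattern follows (illustrative only, not headscale's actual code; clientStreamsOpen and the Close-before-Wait ordering mirror the diff, the rest is stand-in):

// Minimal sketch of the hang this commit fixes. Illustrative only:
// "updates" stands in for a per-node chan *tailcfg.MapResponse,
// and closing it stands in for h.mapBatcher.Close().
package main

import (
	"fmt"
	"sync"
)

func main() {
	var clientStreamsOpen sync.WaitGroup
	updates := make(chan string, 8)

	clientStreamsOpen.Add(1)
	go func() {
		defer clientStreamsOpen.Done()
		// The long-poll loop only returns once its channel is closed.
		for range updates {
		}
	}()

	// Shutdown order matters: close the producer side first...
	close(updates)
	// ...otherwise this Wait() would block forever on SIGINT.
	clientStreamsOpen.Wait()
	fmt.Println("shutdown complete")
}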

View File

@@ -100,7 +100,7 @@ type Headscale struct {
     authProvider AuthProvider
     mapBatcher   mapper.Batcher

-    pollNetMapStreamWG sync.WaitGroup
+    clientStreamsOpen sync.WaitGroup
 }

 var (
@@ -131,7 +131,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
     app := Headscale{
         cfg:             cfg,
         noisePrivateKey: noisePrivateKey,
-        pollNetMapStreamWG: sync.WaitGroup{},
+        clientStreamsOpen:  sync.WaitGroup{},
         state:           s,
     }

@@ -813,10 +813,11 @@ func (h *Headscale) Serve() error {
         log.Error().Err(err).Msg("failed to shutdown http")
     }

-    info("closing node notifier")
+    info("closing batcher")
+    h.mapBatcher.Close()

     info("waiting for netmap stream to close")
-    h.pollNetMapStreamWG.Wait()
+    h.clientStreamsOpen.Wait()

     info("shutting down grpc server (socket)")
     grpcSocket.GracefulStop()
@@ -842,11 +843,11 @@ func (h *Headscale) Serve() error {
     info("closing socket listener")
     socketListener.Close()

-    // Close db connections
-    info("closing database connection")
+    // Close state connections
+    info("closing state and database")
     err = h.state.Close()
     if err != nil {
-        log.Error().Err(err).Msg("failed to close db")
+        log.Error().Err(err).Msg("failed to close state")
     }

     log.Info().

View File

@@ -153,7 +153,7 @@ func (b *LockFreeBatcher) Start() {
 func (b *LockFreeBatcher) Close() {
     if b.cancel != nil {
         b.cancel()
-        b.cancel = nil // Prevent multiple calls
+        b.cancel = nil
     }

     // Only close workCh once
@@ -163,10 +163,15 @@ func (b *LockFreeBatcher) Close() {
     default:
         close(b.workCh)
     }
+
+    // Close the underlying channels supplying the data to the clients.
+    b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
+        conn.close()
+        return true
+    })
 }

 func (b *LockFreeBatcher) doWork() {
     for i := range b.workers {
         go b.worker(i + 1)
     }
@@ -184,17 +189,18 @@ func (b *LockFreeBatcher) doWork() {
             // Clean up nodes that have been offline for too long
             b.cleanupOfflineNodes()
         case <-b.ctx.Done():
+            log.Info().Msg("batcher context done, stopping to feed workers")
             return
         }
     }
 }

 func (b *LockFreeBatcher) worker(workerID int) {
     for {
         select {
         case w, ok := <-b.workCh:
             if !ok {
+                log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID)
                 return
             }
@@ -213,7 +219,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
                 if result.err != nil {
                     b.workErrors.Add(1)
                     log.Error().Err(result.err).
-                        Int("workerID", workerID).
+                        Int("worker.id", workerID).
                         Uint64("node.id", w.nodeID.Uint64()).
                         Str("change", w.c.Change.String()).
                         Msg("failed to generate map response for synchronous work")
@@ -223,7 +229,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
                     b.workErrors.Add(1)
                     log.Error().Err(result.err).
-                        Int("workerID", workerID).
+                        Int("worker.id", workerID).
                         Uint64("node.id", w.nodeID.Uint64()).
                         Msg("node not found for synchronous work")
                 }
@@ -248,13 +254,14 @@ func (b *LockFreeBatcher) worker(workerID int) {
                 if err != nil {
                     b.workErrors.Add(1)
                     log.Error().Err(err).
-                        Int("workerID", workerID).
+                        Int("worker.id", workerID).
                         Uint64("node.id", w.c.NodeID.Uint64()).
                         Str("change", w.c.Change.String()).
                         Msg("failed to apply change")
                 }
             }
         case <-b.ctx.Done():
+            log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
             return
         }
     }
@@ -336,6 +343,7 @@ func (b *LockFreeBatcher) processBatchedChanges() {
 }

 // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
+// TODO(kradalby): reevaluate if we want to keep this.
 func (b *LockFreeBatcher) cleanupOfflineNodes() {
     cleanupThreshold := 15 * time.Minute
     now := time.Now()
@@ -477,6 +485,15 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
     }
 }

+func (mc *multiChannelNodeConn) close() {
+    mc.mutex.Lock()
+    defer mc.mutex.Unlock()
+
+    for _, conn := range mc.connections {
+        close(conn.c)
+    }
+}
+
 // addConnection adds a new connection.
 func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
     mutexWaitStart := time.Now()
@@ -530,6 +547,10 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
 // send broadcasts data to all active connections for the node.
 func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
+    if data == nil {
+        return nil
+    }
+
     mc.mutex.Lock()
     defer mc.mutex.Unlock()
@@ -597,6 +618,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
 // send sends data to a single connection entry with timeout-based stale connection detection.
 func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
+    if data == nil {
+        return nil
+    }
+
     // Use a short timeout to detect stale connections where the client isn't reading the channel.
     // This is critical for detecting Docker containers that are forcefully terminated
     // but still have channels that appear open.
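
A side effect of the new close() above: once the batcher closes a per-node channel, a plain receive no longer blocks and starts yielding nil map responses, so consumers need the comma-ok receive form to detect the close and exit cleanly; that is what the test changes in the next file switch to. A minimal consumer-side sketch (illustrative only; chan string stands in for the batcher's chan *tailcfg.MapResponse):

// Sketch of the comma-ok receive pattern consumers need once the
// batcher starts closing per-node channels on shutdown.
package main

import "fmt"

func consume(ch <-chan string) {
	for {
		data, ok := <-ch
		if !ok {
			// Channel was closed by the producer; exit gracefully.
			fmt.Println("channel closed, stopping consumer")
			return
		}
		fmt.Println("got update:", data)
	}
}

func main() {
	ch := make(chan string, 1)
	ch <- "netmap update"
	close(ch) // stands in for what the batcher's Close() triggers per node
	consume(ch)
}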

View File

@@ -1361,7 +1361,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
         go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
             for {
                 select {
-                case data := <-channel:
+                case data, ok := <-channel:
+                    if !ok {
+                        // Channel was closed, exit gracefully
+                        return
+                    }
                     if valid, reason := validateUpdateContent(data); valid {
                         tracker.recordUpdate(
                             nodeID,
@@ -1419,24 +1423,28 @@ func TestBatcherConcurrentClients(t *testing.T) {
             ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)

             churningChannelsMutex.Lock()
             churningChannels[nodeID] = ch
             churningChannelsMutex.Unlock()

             batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))

             // Consume updates to prevent blocking
             go func() {
                 for {
                     select {
-                    case data := <-ch:
+                    case data, ok := <-ch:
+                        if !ok {
+                            // Channel was closed, exit gracefully
+                            return
+                        }
                         if valid, _ := validateUpdateContent(data); valid {
                             tracker.recordUpdate(
                                 nodeID,
                                 1,
                             ) // Use 1 as update size since we have MapResponse
                         }
-                    case <-time.After(20 * time.Millisecond):
+                    case <-time.After(500 * time.Millisecond):
+                        // Longer timeout to prevent premature exit during heavy load
                         return
                     }
                 }

View File

@@ -186,8 +186,8 @@ func (m *mapSession) serveLongPoll() {
     }()

     // Set up the client stream
-    m.h.pollNetMapStreamWG.Add(1)
-    defer m.h.pollNetMapStreamWG.Done()
+    m.h.clientStreamsOpen.Add(1)
+    defer m.h.clientStreamsOpen.Done()

     ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
     defer cancel()