1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-30 00:09:42 +01:00

ensure expire routines are cleaned up (#1924)

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-05-02 17:57:53 +02:00 committed by GitHub
parent a9c568c801
commit 622aa82da2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -70,7 +70,7 @@ var (
const ( const (
AuthPrefix = "Bearer " AuthPrefix = "Bearer "
updateInterval = 5000 updateInterval = 5 * time.Second
privateKeyFileMode = 0o600 privateKeyFileMode = 0o600
headscaleDirPerm = 0o700 headscaleDirPerm = 0o700
@ -219,10 +219,15 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been // deleteExpireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) { func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(every)
for range ticker.C { for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
var removed []types.NodeID var removed []types.NodeID
var changed []types.NodeID var changed []types.NodeID
if err := h.db.Write(func(tx *gorm.DB) error { if err := h.db.Write(func(tx *gorm.DB) error {
@ -250,19 +255,24 @@ func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
}) })
} }
} }
}
} }
// expireExpiredMachines expires nodes that have an explicit expiry set // expireExpiredNodes expires nodes that have an explicit expiry set
// after that expiry time has passed. // after that expiry time has passed.
func (h *Headscale) expireExpiredMachines(intervalMs int64) { func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) {
interval := time.Duration(intervalMs) * time.Millisecond ticker := time.NewTicker(every)
ticker := time.NewTicker(interval)
lastCheck := time.Unix(0, 0) lastCheck := time.Unix(0, 0)
var update types.StateUpdate var update types.StateUpdate
var changed bool var changed bool
for range ticker.C { for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
if err := h.db.Write(func(tx *gorm.DB) error { if err := h.db.Write(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
@ -279,6 +289,7 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
h.nodeNotifier.NotifyAll(ctx, update) h.nodeNotifier.NotifyAll(ctx, update)
} }
} }
}
} }
// scheduledDERPMapUpdateWorker refreshes the DERPMap stored on the global object // scheduledDERPMapUpdateWorker refreshes the DERPMap stored on the global object
@ -538,10 +549,13 @@ func (h *Headscale) Serve() error {
return errEmptyInitialDERPMap return errEmptyInitialDERPMap
} }
// TODO(kradalby): These should have cancel channels and be cleaned expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background())
// up on shutdown. defer expireEphemeralCancel()
go h.deleteExpireEphemeralNodes(updateInterval) go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval)
go h.expireExpiredMachines(updateInterval)
expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background())
defer expireNodeCancel()
go h.expireExpiredNodes(expireNodeCtx, updateInterval)
if zl.GlobalLevel() == zl.TraceLevel { if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true zerolog.RespLog = true
@ -805,6 +819,9 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
expireNodeCancel()
expireEphemeralCancel()
trace("closing map sessions") trace("closing map sessions")
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
for _, mapSess := range h.mapSessions { for _, mapSess := range h.mapSessions {