1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

graceful shutdown fix

This commit is contained in:
Grigoriy Mikhalkin 2022-06-30 23:35:22 +02:00
parent c6eb7be7fb
commit 889eff265f
2 changed files with 29 additions and 15 deletions

23
app.go
View File

@ -95,6 +95,7 @@ type Headscale struct {
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
shutdownChan chan struct{} shutdownChan chan struct{}
wg sync.WaitGroup
} }
// Look up the TLS constant relative to user-supplied TLS client // Look up the TLS constant relative to user-supplied TLS client
@ -153,6 +154,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
privateKey: privKey, privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache, registrationCache: registrationCache,
wg: sync.WaitGroup{},
} }
err = app.initDB() err = app.initDB()
@ -567,6 +569,8 @@ func (h *Headscale) Serve() error {
// https://github.com/soheilhy/cmux/issues/68 // https://github.com/soheilhy/cmux/issues/68
// https://github.com/soheilhy/cmux/issues/91 // https://github.com/soheilhy/cmux/issues/91
var grpcServer *grpc.Server
var grpcListener net.Listener
if tlsConfig != nil || h.cfg.GRPCAllowInsecure { if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
@ -587,12 +591,12 @@ func (h *Headscale) Serve() error {
log.Warn().Msg("gRPC is running without security") log.Warn().Msg("gRPC is running without security")
} }
grpcServer := grpc.NewServer(grpcOptions...) grpcServer = grpc.NewServer(grpcOptions...)
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
reflection.Register(grpcServer) reflection.Register(grpcServer)
grpcListener, err := net.Listen("tcp", h.cfg.GRPCAddr) grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err) return fmt.Errorf("failed to bind to TCP address: %w", err)
} }
@ -668,7 +672,7 @@ func (h *Headscale) Serve() error {
syscall.SIGTERM, syscall.SIGTERM,
syscall.SIGQUIT, syscall.SIGQUIT,
syscall.SIGHUP) syscall.SIGHUP)
go func(c chan os.Signal) { sig_func := func(c chan os.Signal) {
// Wait for a SIGINT or SIGKILL: // Wait for a SIGINT or SIGKILL:
for { for {
sig := <-c sig := <-c
@ -678,7 +682,7 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config") Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP // TODO(kradalby): Reload config on SIGHUP
if h.cfg.ACL.PolicyPath != "" { if h.cfg.ACL.PolicyPath != "" {
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
@ -698,7 +702,8 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
h.shutdownChan <- struct{}{} close(h.shutdownChan)
h.wg.Wait()
// Gracefully shut down servers // Gracefully shut down servers
ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout)
@ -710,6 +715,11 @@ func (h *Headscale) Serve() error {
} }
grpcSocket.GracefulStop() grpcSocket.GracefulStop()
if grpcServer != nil {
grpcServer.GracefulStop()
grpcListener.Close()
}
// Close network listeners // Close network listeners
promHTTPListener.Close() promHTTPListener.Close()
httpListener.Close() httpListener.Close()
@ -736,7 +746,8 @@ func (h *Headscale) Serve() error {
os.Exit(0) os.Exit(0)
} }
} }
}(sigc) }
errorGroup.Go(func() error { sig_func(sigc); return nil })
return errorGroup.Wait() return errorGroup.Wait()
} }

21
poll.go
View File

@ -290,6 +290,9 @@ func (h *Headscale) PollNetMapStream(
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
) { ) {
h.wg.Add(1)
defer h.wg.Done()
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@ -353,9 +356,9 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Data from pollData channel written successfully") Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine) err = h.UpdateMachineFromDatabase(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -431,9 +434,9 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine) err = h.UpdateMachineFromDatabase(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -588,9 +591,9 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("The client has closed the connection") Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions // TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err := h.UpdateMachineFromDatabase(machine) err := h.UpdateMachineFromDatabase(machine)
if err != nil { if err != nil {
log.Error(). log.Error().