diff --git a/Makefile b/Makefile index 6b4c02ff..69935bc0 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ test: @go test -coverprofile=coverage.out ./... test_integration: - go test -tags integration -timeout 30m ./... + go test -tags integration -timeout 30m -count=1 ./... test_integration_cli: go test -tags integration -v integration_cli_test.go integration_common_test.go diff --git a/poll.go b/poll.go index 1d2db944..94c60dc1 100644 --- a/poll.go +++ b/poll.go @@ -1,8 +1,10 @@ package headscale import ( + "context" "encoding/json" "errors" + "fmt" "io" "net/http" "time" @@ -152,14 +154,33 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("id", ctx.Param("id")). Str("machine", machine.Name). Msg("Loading or creating update channel") - updateChan := make(chan struct{}) - pollDataChan := make(chan []byte) + // TODO: could probably remove all that duplication once generics land. + closeChanWithLog := func(channel interface{}, name string) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("machine", machine.Name). + Str("channel", "Done"). + Msg(fmt.Sprintf("Closing %s channel", name)) + + switch c := channel.(type) { + case (chan struct{}): + close(c) + + case (chan []byte): + close(c) + } + } + + const chanSize = 8 + updateChan := make(chan struct{}, chanSize) + defer closeChanWithLog(updateChan, "updateChan") + + pollDataChan := make(chan []byte, chanSize) + defer closeChanWithLog(pollDataChan, "pollDataChan") keepAliveChan := make(chan []byte) - - cancelKeepAlive := make(chan struct{}) - defer close(cancelKeepAlive) + defer closeChanWithLog(keepAliveChan, "keepAliveChan") if req.OmitPeers && !req.Stream { log.Info(). @@ -172,7 +193,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // even tho the comments in the tailscale code dont explicitly say so. updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update"). Inc() - go func() { updateChan <- struct{}{} }() + updateChan <- struct{}{} return } else if req.OmitPeers && req.Stream { @@ -193,7 +214,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("handler", "PollNetMap"). Str("machine", machine.Name). Msg("Sending initial map") - go func() { pollDataChan <- data }() + pollDataChan <- data log.Info(). Str("handler", "PollNetMap"). @@ -201,7 +222,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Msg("Notifying peers") updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update"). Inc() - go func() { updateChan <- struct{}{} }() + updateChan <- struct{}{} h.PollNetMapStream( ctx, @@ -211,7 +232,6 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { pollDataChan, keepAliveChan, updateChan, - cancelKeepAlive, ) log.Trace(). Str("handler", "PollNetMap"). @@ -231,16 +251,20 @@ func (h *Headscale) PollNetMapStream( pollDataChan chan []byte, keepAliveChan chan []byte, updateChan chan struct{}, - cancelKeepAlive chan struct{}, ) { - go h.scheduledPollWorker( - cancelKeepAlive, - updateChan, - keepAliveChan, - machineKey, - mapRequest, - machine, - ) + { + ctx, cancel := context.WithCancel(ctx.Request.Context()) + defer cancel() + + go h.scheduledPollWorker( + ctx, + updateChan, + keepAliveChan, + machineKey, + mapRequest, + machine, + ) + } ctx.Stream(func(writer io.Writer) bool { log.Trace(). @@ -455,42 +479,13 @@ func (h *Headscale) PollNetMapStream( machine.LastSeen = &now h.db.Save(&machine) - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Name). - Str("channel", "Done"). - Msg("Cancelling keepAlive channel") - cancelKeepAlive <- struct{}{} - - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Name). - Str("channel", "Done"). - Msg("Closing update channel") - // h.closeUpdateChannel(m) - close(updateChan) - - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Name). - Str("channel", "Done"). - Msg("Closing pollData channel") - close(pollDataChan) - - log.Trace(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Name). - Str("channel", "Done"). - Msg("Closing keepAliveChan channel") - close(keepAliveChan) - return false } }) } func (h *Headscale) scheduledPollWorker( - cancelChan <-chan struct{}, + ctx context.Context, updateChan chan<- struct{}, keepAliveChan chan<- []byte, machineKey key.MachinePublic, @@ -502,7 +497,7 @@ func (h *Headscale) scheduledPollWorker( for { select { - case <-cancelChan: + case <-ctx.Done(): return case <-keepAliveTicker.C: