diff --git a/app.go b/app.go index fe1b954b..e5f44103 100644 --- a/app.go +++ b/app.go @@ -58,7 +58,8 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsUpdateChannels sync.Map + clientsUpdateChannels sync.Map + clientsUpdateChannelMutex sync.Mutex lastStateChange sync.Map } diff --git a/machine.go b/machine.go index 5352f741..57c48ba8 100644 --- a/machine.go +++ b/machine.go @@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { Str("peer", p.Name). Str("address", p.Addresses[0].String()). Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - err := h.requestUpdate(p) + err := h.sendRequestOnUpdateChannel(p) if err != nil { log.Info(). Str("func", "notifyChangesToPeers"). @@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { } } -func (h *Headscale) requestUpdate(m *tailcfg.Node) error { +func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + updateChan = unwrapped + } else { + log.Error(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + return updateChan +} + +func (h *Headscale) closeUpdateChannel(m *Machine) { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + close(unwrapped) + } + } + h.clientsUpdateChannels.Delete(m.ID) +} + +func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) if ok { log.Info(). diff --git a/poll.go b/poll.go index e85c7a9f..d086fc44 100644 --- a/poll.go +++ b/poll.go @@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("id", c.Param("id")). Str("machine", m.Name). Msg("Loading or creating update channel") - var updateChan chan struct{} - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if wrapped, ok := storedChan.(chan struct{}); ok { - updateChan = wrapped - } else { - log.Error(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Failed to convert update channel to struct{}") - } - } else { - log.Debug(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Update channel not found, creating") - - updateChan = make(chan struct{}) - h.clientsUpdateChannels.Store(m.ID, updateChan) - } + updateChan := h.getOrOpenUpdateChannel(&m) pollDataChan := make(chan []byte) // defer close(pollData) @@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream( mKey wgkey.Key, pollDataChan chan []byte, keepAliveChan chan []byte, - updateChan chan struct{}, + updateChan <-chan struct{}, cancelKeepAlive chan struct{}, ) { go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) @@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream( cancelKeepAlive <- struct{}{} - h.clientsUpdateChannels.Delete(m.ID) - // close(updateChan) + h.closeUpdateChannel(&m) close(pollDataChan) @@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker( // Send an update request regardless of outdated or not, if data is sent // to the node is determined in the updateChan consumer block n, _ := m.toNode() - err := h.requestUpdate(n) + err := h.sendRequestOnUpdateChannel(n) if err != nil { log.Error(). Str("func", "keepAlive").