From 88d7ac04bf7a78f378cd3015c1b1bb083ba54cb3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 20 Aug 2021 16:52:34 +0100 Subject: [PATCH] Account for race condition in deleting/closing update channel This commit tries to address the possible race condition that can happen if a client closes its connection after we have fetched its update channel from the syncmap but before sending the message. To try to avoid introducing new deadlock conditions, all sends to updateChannel have been moved into a single function, which handles the locking (instead of sending from all over the place). The same lock is used around the delete/close function. --- app.go | 3 ++- machine.go | 42 ++++++++++++++++++++++++++++++++++++++++-- poll.go | 29 ++++------------------------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/app.go b/app.go index fe1b954b..e5f44103 100644 --- a/app.go +++ b/app.go @@ -58,7 +58,8 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsUpdateChannels sync.Map + clientsUpdateChannels sync.Map + clientsUpdateChannelMutex sync.Mutex lastStateChange sync.Map } diff --git a/machine.go b/machine.go index 5352f741..57c48ba8 100644 --- a/machine.go +++ b/machine.go @@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { Str("peer", p.Name). Str("address", p.Addresses[0].String()). Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - err := h.requestUpdate(p) + err := h.sendRequestOnUpdateChannel(p) if err != nil { log.Info(). Str("func", "notifyChangesToPeers"). @@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) { } } -func (h *Headscale) requestUpdate(m *tailcfg.Node) error { +func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + updateChan = unwrapped + } else { + log.Error(). + Str("handler", "openUpdateChannel"). 
+ Str("machine", m.Name). + Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + return updateChan +} + +func (h *Headscale) closeUpdateChannel(m *Machine) { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + close(unwrapped) + } + } + h.clientsUpdateChannels.Delete(m.ID) +} + +func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) if ok { log.Info(). diff --git a/poll.go b/poll.go index e85c7a9f..d086fc44 100644 --- a/poll.go +++ b/poll.go @@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("id", c.Param("id")). Str("machine", m.Name). Msg("Loading or creating update channel") - var updateChan chan struct{} - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if wrapped, ok := storedChan.(chan struct{}); ok { - updateChan = wrapped - } else { - log.Error(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Failed to convert update channel to struct{}") - } - } else { - log.Debug(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). 
- Msg("Update channel not found, creating") - - updateChan = make(chan struct{}) - h.clientsUpdateChannels.Store(m.ID, updateChan) - } + updateChan := h.getOrOpenUpdateChannel(&m) pollDataChan := make(chan []byte) // defer close(pollData) @@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream( mKey wgkey.Key, pollDataChan chan []byte, keepAliveChan chan []byte, - updateChan chan struct{}, + updateChan <-chan struct{}, cancelKeepAlive chan struct{}, ) { go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) @@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream( cancelKeepAlive <- struct{}{} - h.clientsUpdateChannels.Delete(m.ID) - // close(updateChan) + h.closeUpdateChannel(&m) close(pollDataChan) @@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker( // Send an update request regardless of outdated or not, if data is sent // to the node is determined in the updateChan consumer block n, _ := m.toNode() - err := h.requestUpdate(n) + err := h.sendRequestOnUpdateChannel(n) if err != nil { log.Error(). Str("func", "keepAlive").