From 1abc68ccf42794756a61e7c2e850051e79b5a1e0 Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby <kradalby@kradalby.no>
Date: Thu, 5 Aug 2021 22:14:37 +0100
Subject: [PATCH] Removes locks causing deadlock

This commit removes most of the locks in the polling map handling, as
some combinations of them caused deadlocks. Instead of using a plain map
and doing the locking ourselves, we use sync.Map, which handles the
synchronization for us.
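
A minimal sketch of the pattern the handlers now follow (the names
"clients" and "notify" are illustrative, not the actual Headscale
fields): the stored values are chan []byte, so each Load needs a type
assertion to recover the channel.

    package main

    import "sync"

    func main() {
        var clients sync.Map // replaces map[uint64]chan []byte + sync.Mutex

        // Register a polling client; no explicit lock is needed.
        notify := make(chan []byte, 1)
        clients.Store(uint64(1), notify)

        // Notify the client if it is still polling. Load returns an
        // interface{} value, so assert it back to chan []byte.
        if v, ok := clients.Load(uint64(1)); ok {
            v.(chan []byte) <- []byte{}
        }
        <-notify // the poller would receive the empty "wake up" payload here

        // Clean up when the client disconnects.
        clients.Delete(uint64(1))
    }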
---
 api.go        | 14 ++++----------
 app.go        |  3 +--
 namespaces.go |  6 ++----
 routes.go     |  4 ++--
 4 files changed, 9 insertions(+), 18 deletions(-)

diff --git a/api.go b/api.go
index 9fd8c7bf..8589d07c 100644
--- a/api.go
+++ b/api.go
@@ -299,9 +299,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 		Str("id", c.Param("id")).
 		Str("machine", m.Name).
 		Msg("Locking poll mutex")
-	h.pollMu.Lock()
-	h.clientsPolling[m.ID] = update
-	h.pollMu.Unlock()
+	h.clientsPolling.Store(m.ID, update)
 	log.Trace().
 		Str("handler", "PollNetMap").
 		Str("id", c.Param("id")).
@@ -373,9 +371,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 		Str("machine", m.Name).
 		Msg("Notifying peers")
 	peers, _ := h.getPeers(m)
-	h.pollMu.Lock()
 	for _, p := range *peers {
-		pUp, ok := h.clientsPolling[uint64(p.ID)]
+		pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 		if ok {
 			log.Info().
 				Str("handler", "PollNetMap").
@@ -383,7 +380,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 				Str("peer", m.Name).
 				Str("address", p.Addresses[0].String()).
 				Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
-			pUp <- []byte{}
+			pUp.(chan []byte) <- []byte{}
 		} else {
 			log.Info().
 				Str("handler", "PollNetMap").
@@ -392,7 +389,6 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 				Msgf("Peer %s does not appear to be polling", p.Name)
 		}
 	}
-	h.pollMu.Unlock()
 
 	go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
 
@@ -448,11 +444,9 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 			now := time.Now().UTC()
 			m.LastSeen = &now
 			h.db.Save(&m)
-			h.pollMu.Lock()
 			cancelKeepAlive <- []byte{}
-			delete(h.clientsPolling, m.ID)
+			h.clientsPolling.Delete(m.ID)
 			close(update)
-			h.pollMu.Unlock()
 			return false
 
 		}
diff --git a/app.go b/app.go
index 45df01c0..668e23b9 100644
--- a/app.go
+++ b/app.go
@@ -59,7 +59,7 @@ type Headscale struct {
 	aclRules  *[]tailcfg.FilterRule
 
 	pollMu         sync.Mutex
-	clientsPolling map[uint64]chan []byte // this is by all means a hackity hack
+	clientsPolling sync.Map
 }
 
 // NewHeadscale returns the Headscale app
@@ -99,7 +99,6 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
 		return nil, err
 	}
 
-	h.clientsPolling = make(map[uint64]chan []byte)
 	return &h, nil
 }
 
diff --git a/namespaces.go b/namespaces.go
index 9bbb6b32..9b8d1904 100644
--- a/namespaces.go
+++ b/namespaces.go
@@ -170,9 +170,8 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
 		}
 		for _, m := range *machines {
 			peers, _ := h.getPeers(m)
-			h.pollMu.Lock()
 			for _, p := range *peers {
-				pUp, ok := h.clientsPolling[uint64(p.ID)]
+				pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 				if ok {
 					log.Info().
 						Str("func", "checkForNamespacesPendingUpdates").
@@ -180,7 +179,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
 						Str("peer", m.Name).
 						Str("address", p.Addresses[0].String()).
 						Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
-					pUp <- []byte{}
+					pUp.(chan []byte) <- []byte{}
 				} else {
 					log.Info().
 						Str("func", "checkForNamespacesPendingUpdates").
@@ -189,7 +188,6 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
 						Msgf("Peer %s does not appear to be polling", p.Name)
 				}
 			}
-			h.pollMu.Unlock()
 		}
 	}
 	newV, err := h.getValue("namespaces_pending_updates")
diff --git a/routes.go b/routes.go
index a02bed30..e188b91c 100644
--- a/routes.go
+++ b/routes.go
@@ -52,8 +52,8 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
 			peers, _ := h.getPeers(*m)
 			h.pollMu.Lock()
 			for _, p := range *peers {
-				if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok {
-					pUp <- []byte{}
+				if pUp, ok := h.clientsPolling.Load(uint64(p.ID)); ok {
+					pUp.(chan []byte) <- []byte{}
 				}
 			}
 			h.pollMu.Unlock()