From ed3a9c8d6d3c0f45f46c24e6b42d404fb4456a09 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 17 Sep 2025 14:23:21 +0200 Subject: [PATCH] mapper: send change instead of full update (#2775) --- hscontrol/db/node.go | 37 -------------------------------- hscontrol/mapper/batcher.go | 37 ++++++++++++++++++++++---------- hscontrol/mapper/batcher_test.go | 12 ++++++++--- hscontrol/mapper/builder.go | 14 ++++-------- hscontrol/mapper/mapper.go | 21 +++++++++++++++++- hscontrol/policy/v2/policy.go | 18 +++++++++++++++- hscontrol/state/state.go | 10 ++++++--- hscontrol/types/change/change.go | 16 +++++++++++--- 8 files changed, 96 insertions(+), 69 deletions(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f899ddd3..e54011c5 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -13,11 +13,9 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" ) @@ -494,41 +492,6 @@ func EnsureUniqueGivenName( return givenName, nil } -// ExpireExpiredNodes checks for nodes that have expired since the last check -// and returns a time to be used for the next check, a StateUpdate -// containing the expired nodes, and a boolean indicating if any nodes were found. -func ExpireExpiredNodes(tx *gorm.DB, - lastCheck time.Time, -) (time.Time, []change.ChangeSet, bool) { - // use the time of the start of the function to ensure we - // dont miss some nodes by returning it _after_ we have - // checked everything. - started := time.Now() - - expired := make([]*tailcfg.PeerChange, 0) - var updates []change.ChangeSet - - nodes, err := ListNodes(tx) - if err != nil { - return time.Unix(0, 0), nil, false - } - for _, node := range nodes { - if node.IsExpired() && node.Expiry.After(lastCheck) { - expired = append(expired, &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: node.Expiry, - }) - updates = append(updates, change.KeyExpiry(node.ID)) - } - } - - if len(expired) > 0 { - return started, updates, true - } - - return started, nil, false -} - // EphemeralGarbageCollector is a garbage collector that will delete nodes after // a certain amount of time. // It is used to delete ephemeral nodes that have disconnected and should be diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 91564a3a..b56bca08 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -88,16 +88,9 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // TODO(kradalby): This can potentially be a peer update of the old and new subnet router. mapResp, err = mapper.fullMapResponse(nodeID, version) } else { - // CRITICAL FIX: Read actual online status from NodeStore when available, - // fall back to deriving from change type for unit tests or when NodeStore is empty - var onlineStatus bool - if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() { - // Use actual NodeStore status when available (production case) - onlineStatus = node.IsOnline().Get() - } else { - // Fall back to deriving from change type (unit test case or initial setup) - onlineStatus = c.Change == change.NodeCameOnline - } + // Trust the change type for online/offline status to avoid race conditions + // between NodeStore updates and change processing + onlineStatus := c.Change == change.NodeCameOnline mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ { @@ -108,11 +101,33 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, } case change.NodeNewOrUpdate: - mapResp, err = mapper.fullMapResponse(nodeID, version) + // If the node is the one being updated, we send a self update that preserves peer information + // to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes) + // without losing its view of peer status during rapid reconnection cycles + if c.IsSelfUpdate(nodeID) { + mapResp, err = mapper.selfMapResponse(nodeID, version) + } else { + mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID) + } case change.NodeRemove: mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) + case change.NodeKeyExpiry: + // If the node is the one whose key is expiring, we send a "full" self update + // as nodes will ignore patch updates about themselves (?). + if c.IsSelfUpdate(nodeID) { + mapResp, err = mapper.selfMapResponse(nodeID, version) + // mapResp, err = mapper.fullMapResponse(nodeID, version) + } else { + mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ + { + NodeID: c.NodeID.NodeID(), + KeyExpiry: c.NodeExpiry, + }, + }) + } + default: // The following will always hit this: // change.Full, change.Policy diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 74277c6c..30e75f48 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -1028,7 +1028,9 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // Add multiple changes rapidly to test batching batcher.AddWork(change.DERPSet) - batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry)) batcher.AddWork(change.DERPSet) batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) batcher.AddWork(change.DERPSet) @@ -1278,7 +1280,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Add node-specific work occasionally if i%10 == 0 { - batcher.AddWork(change.KeyExpiry(testNode.n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry)) } // Rapid removal creates race between worker and removal @@ -1493,7 +1497,9 @@ func TestBatcherConcurrentClients(t *testing.T) { if i%7 == 0 && len(allNodes) > 0 { // Node-specific changes using real nodes node := allNodes[i%len(allNodes)] - batcher.AddWork(change.KeyExpiry(node.n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry)) } // Small delay to allow some batching diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 819d23a3..1177accb 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -28,6 +28,7 @@ type debugType string const ( fullResponseDebug debugType = "full" + selfResponseDebug debugType = "self" patchResponseDebug debugType = "patch" removeResponseDebug debugType = "remove" changeResponseDebug debugType = "change" @@ -68,24 +69,17 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers // WithSelfNode adds the requesting node to the response. func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { - nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID) + nv, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { b.addError(errors.New("node not found")) return b } - // Always use batcher's view of online status for self node - // The batcher respects grace periods for logout scenarios - node := nodeView.AsStruct() - // if b.mapper.batcher != nil { - // node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID)) - // } - _, matchers := b.mapper.state.Filter() tailnode, err := tailNode( - node.View(), b.capVer, b.mapper.state, + nv, b.capVer, b.mapper.state, func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) + return policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers) }, b.mapper.cfg) if err != nil { diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5e9b9a13..372bb557 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -158,6 +158,26 @@ func (m *mapper) fullMapResponse( Build() } +func (m *mapper) selfMapResponse( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, +) (*tailcfg.MapResponse, error) { + ma, err := m.NewMapResponseBuilder(nodeID). + WithDebugType(selfResponseDebug). + WithCapabilityVersion(capVer). + WithSelfNode(). + Build() + if err != nil { + return nil, err + } + + // Set the peers to nil, to ensure the node does not think + // its getting a new list. + ma.Peers = nil + + return ma, err +} + func (m *mapper) derpMapResponse( nodeID types.NodeID, ) (*tailcfg.MapResponse, error) { @@ -190,7 +210,6 @@ func (m *mapper) peerChangeResponse( return m.NewMapResponseBuilder(nodeID). WithDebugType(changeResponseDebug). WithCapabilityVersion(capVer). - WithSelfNode(). WithUserProfiles(peers). WithPeerChanges(peers). Build() diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 4215485a..ae3c100e 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -228,7 +228,23 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { defer pm.mu.Unlock() pm.users = users - return pm.updateLocked() + // Clear SSH policy map when users change to force SSH policy recomputation + // This ensures that if SSH policy compilation previously failed due to missing users, + // it will be retried with the new user list + clear(pm.sshPolicyMap) + + changed, err := pm.updateLocked() + if err != nil { + return false, err + } + + // If SSH policies exist, force a policy change when users are updated + // This ensures nodes get updated SSH policies even if other policy hashes didn't change + if pm.pol != nil && pm.pol.SSHs != nil && len(pm.pol.SSHs) > 0 { + return true, nil + } + + return changed, nil } // SetNodes updates the nodes in the policy manager and updates the filter rules. diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b445f4e1..15597706 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -650,7 +650,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node } if !c.IsFull() { - c = change.KeyExpiry(nodeID) + c = change.KeyExpiry(nodeID, expiry) } return n, c, nil @@ -898,7 +898,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // Why check After(lastCheck): We only want to notify about nodes that // expired since the last check to avoid duplicate notifications if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { - updates = append(updates, change.KeyExpiry(node.ID())) + updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get())) } } @@ -1118,7 +1118,11 @@ func (s *State) HandleNodeFromAuthPath( // Get updated node from NodeStore updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) - return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil + if expiry != nil { + return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil + } + + return updatedNode, change.FullSet, nil } // New node registration diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 5c5ea8b8..36cf8a4f 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -3,6 +3,7 @@ package change import ( "errors" + "time" "github.com/juanfont/headscale/hscontrol/types" ) @@ -68,6 +69,9 @@ type ChangeSet struct { // IsSubnetRouter indicates whether the node is a subnet router. IsSubnetRouter bool + + // NodeExpiry is set if the change is NodeKeyExpiry. + NodeExpiry *time.Time } func (c *ChangeSet) Validate() error { @@ -126,6 +130,11 @@ func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) { return ret } +// IsSelfUpdate reports whether this ChangeSet represents an update to the given node itself. +func (c ChangeSet) IsSelfUpdate(nodeID types.NodeID) bool { + return c.NodeID == nodeID +} + func (c ChangeSet) AlsoSelf() bool { // If NodeID is 0, it means this ChangeSet is not related to a specific node, // so we consider it as a change that should be sent to all nodes. @@ -179,10 +188,11 @@ func NodeOffline(id types.NodeID) ChangeSet { } } -func KeyExpiry(id types.NodeID) ChangeSet { +func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet { return ChangeSet{ - Change: NodeKeyExpiry, - NodeID: id, + Change: NodeKeyExpiry, + NodeID: id, + NodeExpiry: &expiry, } }