From f534d429db90ab790212cf2095fb3cc37971e6b4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 13 Sep 2025 12:14:04 +0200 Subject: [PATCH] mapper: send expiry as patch Signed-off-by: Kristoffer Dalby --- hscontrol/db/node.go | 37 -------------------------------- hscontrol/mapper/batcher.go | 15 +++++++++++++ hscontrol/mapper/batcher_test.go | 12 ++++++++--- hscontrol/mapper/builder.go | 1 + hscontrol/mapper/mapper.go | 20 +++++++++++++++++ hscontrol/state/state.go | 10 ++++++--- hscontrol/types/change/change.go | 11 +++++++--- 7 files changed, 60 insertions(+), 46 deletions(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f899ddd3..e54011c5 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -13,11 +13,9 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" ) @@ -494,41 +492,6 @@ func EnsureUniqueGivenName( return givenName, nil } -// ExpireExpiredNodes checks for nodes that have expired since the last check -// and returns a time to be used for the next check, a StateUpdate -// containing the expired nodes, and a boolean indicating if any nodes were found. -func ExpireExpiredNodes(tx *gorm.DB, - lastCheck time.Time, -) (time.Time, []change.ChangeSet, bool) { - // use the time of the start of the function to ensure we - // dont miss some nodes by returning it _after_ we have - // checked everything. - started := time.Now() - - expired := make([]*tailcfg.PeerChange, 0) - var updates []change.ChangeSet - - nodes, err := ListNodes(tx) - if err != nil { - return time.Unix(0, 0), nil, false - } - for _, node := range nodes { - if node.IsExpired() && node.Expiry.After(lastCheck) { - expired = append(expired, &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: node.Expiry, - }) - updates = append(updates, change.KeyExpiry(node.ID)) - } - } - - if len(expired) > 0 { - return started, updates, true - } - - return started, nil, false -} - // EphemeralGarbageCollector is a garbage collector that will delete nodes after // a certain amount of time. // It is used to delete ephemeral nodes that have disconnected and should be diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index a24b25ea..7f71676d 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -113,6 +113,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, case change.NodeRemove: mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) + case change.NodeKeyExpiry: + // If the node is the one whose key is expiring, we send a "full" self update + // as nodes will ignore patch updates about themselves (?). + if nodeID == c.NodeID { + mapResp, err = mapper.selfMapResponse(nodeID, version) + // mapResp, err = mapper.fullMapResponse(nodeID, version) + } else { + mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ + { + NodeID: c.NodeID.NodeID(), + KeyExpiry: c.NodeExpiry, + }, + }) + } + default: // The following will always hit this: // change.Full, change.Policy diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 74277c6c..30e75f48 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -1028,7 +1028,9 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // Add multiple changes rapidly to test batching batcher.AddWork(change.DERPSet) - batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry)) batcher.AddWork(change.DERPSet) batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) batcher.AddWork(change.DERPSet) @@ -1278,7 +1280,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Add node-specific work occasionally if i%10 == 0 { - batcher.AddWork(change.KeyExpiry(testNode.n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry)) } // Rapid removal creates race between worker and removal @@ -1493,7 +1497,9 @@ func TestBatcherConcurrentClients(t *testing.T) { if i%7 == 0 && len(allNodes) > 0 { // Node-specific changes using real nodes node := allNodes[i%len(allNodes)] - batcher.AddWork(change.KeyExpiry(node.n.ID)) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry)) } // Small delay to allow some batching diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 161efef2..1177accb 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -28,6 +28,7 @@ type debugType string const ( fullResponseDebug debugType = "full" + selfResponseDebug debugType = "self" patchResponseDebug debugType = "patch" removeResponseDebug debugType = "remove" changeResponseDebug debugType = "change" diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 880b4608..372bb557 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -158,6 +158,26 @@ func (m *mapper) fullMapResponse( Build() } +func (m *mapper) selfMapResponse( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, +) (*tailcfg.MapResponse, error) { + ma, err := m.NewMapResponseBuilder(nodeID). + WithDebugType(selfResponseDebug). + WithCapabilityVersion(capVer). + WithSelfNode(). + Build() + if err != nil { + return nil, err + } + + // Set the peers to nil, to ensure the node does not think + // its getting a new list. + ma.Peers = nil + + return ma, err +} + func (m *mapper) derpMapResponse( nodeID types.NodeID, ) (*tailcfg.MapResponse, error) { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b445f4e1..15597706 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -650,7 +650,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node } if !c.IsFull() { - c = change.KeyExpiry(nodeID) + c = change.KeyExpiry(nodeID, expiry) } return n, c, nil @@ -898,7 +898,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // Why check After(lastCheck): We only want to notify about nodes that // expired since the last check to avoid duplicate notifications if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { - updates = append(updates, change.KeyExpiry(node.ID())) + updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get())) } } @@ -1118,7 +1118,11 @@ func (s *State) HandleNodeFromAuthPath( // Get updated node from NodeStore updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) - return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil + if expiry != nil { + return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil + } + + return updatedNode, change.FullSet, nil } // New node registration diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 5c5ea8b8..aac8acd6 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -3,6 +3,7 @@ package change import ( "errors" + "time" "github.com/juanfont/headscale/hscontrol/types" ) @@ -68,6 +69,9 @@ type ChangeSet struct { // IsSubnetRouter indicates whether the node is a subnet router. IsSubnetRouter bool + + // NodeExpiry is set if the change is NodeKeyExpiry. + NodeExpiry *time.Time } func (c *ChangeSet) Validate() error { @@ -179,10 +183,11 @@ func NodeOffline(id types.NodeID) ChangeSet { } } -func KeyExpiry(id types.NodeID) ChangeSet { +func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet { return ChangeSet{ - Change: NodeKeyExpiry, - NodeID: id, + Change: NodeKeyExpiry, + NodeID: id, + NodeExpiry: &expiry, } }