diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index b56bca08..e808df7e 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -8,12 +8,22 @@ import ( "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/types/ptr" ) +var ( + mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "headscale", + Name: "mapresponse_generated_total", + Help: "total count of mapresponses generated by response type and change type", + }, []string{"response_type", "change_type"}) +) + type batcherFunc func(cfg *types.Config, state *state.State) Batcher // Batcher defines the common interface for all batcher implementations. @@ -75,21 +85,32 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, } var ( - mapResp *tailcfg.MapResponse - err error + mapResp *tailcfg.MapResponse + err error + responseType string ) + // Record metric when function exits + defer func() { + if err == nil && mapResp != nil && responseType != "" { + mapResponseGenerated.WithLabelValues(responseType, c.Change.String()).Inc() + } + }() + switch c.Change { case change.DERP: + responseType = "derp" mapResp, err = mapper.derpMapResponse(nodeID) case change.NodeCameOnline, change.NodeWentOffline: if c.IsSubnetRouter { // TODO(kradalby): This can potentially be a peer update of the old and new subnet router. + responseType = "full" mapResp, err = mapper.fullMapResponse(nodeID, version) } else { // Trust the change type for online/offline status to avoid race conditions // between NodeStore updates and change processing + responseType = "patch" onlineStatus := c.Change == change.NodeCameOnline mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ @@ -105,21 +126,26 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes) // without losing its view of peer status during rapid reconnection cycles if c.IsSelfUpdate(nodeID) { + responseType = "self" mapResp, err = mapper.selfMapResponse(nodeID, version) } else { + responseType = "change" mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID) } case change.NodeRemove: + responseType = "remove" mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) case change.NodeKeyExpiry: // If the node is the one whose key is expiring, we send a "full" self update // as nodes will ignore patch updates about themselves (?). if c.IsSelfUpdate(nodeID) { + responseType = "self" mapResp, err = mapper.selfMapResponse(nodeID, version) // mapResp, err = mapper.fullMapResponse(nodeID, version) } else { + responseType = "patch" mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ { NodeID: c.NodeID.NodeID(), @@ -128,9 +154,34 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, }) } + case change.NodeEndpoint, change.NodeDERP: + // Endpoint or DERP changes can be sent as lightweight patches. + // Query the NodeStore for the current peer state to construct the PeerChange. 
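+	// A patch carries only the changed peer fields, so clients can apply it without receiving a full map response.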
+ // Even if only endpoint or only DERP changed, we include both in the patch + // since they're often updated together and it's minimal overhead. + responseType = "patch" + peer, found := mapper.state.GetNodeByID(c.NodeID) + if !found { + return nil, fmt.Errorf("node not found in NodeStore: %d", c.NodeID) + } + + peerChange := &tailcfg.PeerChange{ + NodeID: c.NodeID.NodeID(), + Endpoints: peer.Endpoints().AsSlice(), + DERPRegion: 0, // Will be set below if available + } + + // Extract DERP region from Hostinfo if available + if hi := peer.AsStruct().Hostinfo; hi != nil && hi.NetInfo != nil { + peerChange.DERPRegion = hi.NetInfo.PreferredDERP + } + + mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{peerChange}) + default: // The following will always hit this: // change.Full, change.Policy + responseType = "full" mapResp, err = mapper.fullMapResponse(nodeID, version) } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 30e75f48..a327a8f9 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -50,7 +50,11 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo // Send the online notification that poll.go would normally send // This ensures other nodes get notified about this node coming online - t.AddWork(change.NodeOnline(id)) + node, ok := t.state.GetNodeByID(id) + if !ok { + return fmt.Errorf("node not found after adding to batcher: %d", id) + } + t.AddWork(change.NodeOnline(node)) return nil } @@ -65,7 +69,10 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe // Send the offline notification that poll.go would normally send // Do this BEFORE removing from batcher so the change can be processed - t.AddWork(change.NodeOffline(id)) + node, ok := t.state.GetNodeByID(id) + if ok { + t.AddWork(change.NodeOffline(node)) + } // Finally remove from the real batcher removed := t.Batcher.RemoveNode(id, c) diff --git a/hscontrol/state/endpoint_test.go b/hscontrol/state/endpoint_test.go new file mode 100644 index 00000000..119933e6 --- /dev/null +++ b/hscontrol/state/endpoint_test.go @@ -0,0 +1,108 @@ +package state + +import ( + "net/netip" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +// TestEndpointStorageInNodeStore verifies that endpoints sent in MapRequest via ApplyPeerChange +// are correctly stored in the NodeStore and can be retrieved for sending to peers. 
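+// It drives PeerChangeFromMapRequest and ApplyPeerChange directly, then checks the stored endpoints via GetNode and ListPeers.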
+// This test reproduces the issue reported in https://github.com/juanfont/headscale/issues/2846 +func TestEndpointStorageInNodeStore(t *testing.T) { + // Create two test nodes + node1 := createTestNode(1, 1, "test-user", "node1") + node2 := createTestNode(2, 1, "test-user", "node2") + + // Create NodeStore with allow-all peers function + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Add both nodes to NodeStore + store.PutNode(node1) + store.PutNode(node2) + + // Create a MapRequest with endpoints for node1 + endpoints := []netip.AddrPort{ + netip.MustParseAddrPort("192.168.1.1:41641"), + netip.MustParseAddrPort("10.0.0.1:41641"), + } + + mapReq := tailcfg.MapRequest{ + NodeKey: node1.NodeKey, + DiscoKey: node1.DiscoKey, + Endpoints: endpoints, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "node1", + }, + } + + // Simulate what UpdateNodeFromMapRequest does: create PeerChange and apply it + peerChange := node1.PeerChangeFromMapRequest(mapReq) + + // Verify PeerChange has endpoints + require.NotNil(t, peerChange.Endpoints, "PeerChange should contain endpoints") + assert.Equal(t, len(endpoints), len(peerChange.Endpoints), + "PeerChange should have same number of endpoints as MapRequest") + + // Apply the PeerChange via NodeStore.UpdateNode + updatedNode, ok := store.UpdateNode(node1.ID, func(n *types.Node) { + n.ApplyPeerChange(&peerChange) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "Updated node should be valid") + + // Verify endpoints are in the updated node view + storedEndpoints := updatedNode.Endpoints().AsSlice() + assert.Equal(t, len(endpoints), len(storedEndpoints), + "NodeStore should have same number of endpoints as sent") + + if len(storedEndpoints) == len(endpoints) { + for i, ep := range endpoints { + assert.Equal(t, ep, storedEndpoints[i], + "Endpoint %d should match", i) + } + } + + // Verify we can retrieve the node again and endpoints are still there + retrievedNode, found := store.GetNode(node1.ID) + require.True(t, found, "node1 should exist in NodeStore") + + retrievedEndpoints := retrievedNode.Endpoints().AsSlice() + assert.Equal(t, len(endpoints), len(retrievedEndpoints), + "Retrieved node should have same number of endpoints") + + // Verify that when we get node1 as a peer of node2, it has endpoints + // This is the critical part that was failing in the bug report + peers := store.ListPeers(node2.ID) + require.Greater(t, peers.Len(), 0, "node2 should have at least one peer") + + // Find node1 in the peer list + var node1Peer types.NodeView + foundPeer := false + for _, peer := range peers.All() { + if peer.ID() == node1.ID { + node1Peer = peer + foundPeer = true + break + } + } + require.True(t, foundPeer, "node1 should be in node2's peer list") + + // Check that node1's endpoints are available in the peer view + peerEndpoints := node1Peer.Endpoints().AsSlice() + assert.Equal(t, len(endpoints), len(peerEndpoints), + "Peer view should have same number of endpoints as sent") + + if len(peerEndpoints) == len(endpoints) { + for i, ep := range endpoints { + assert.Equal(t, ep, peerEndpoints[i], + "Peer endpoint %d should match", i) + } + } +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index c340adc2..e88d2641 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -451,7 +451,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { if !ok { return nil } - c := []change.ChangeSet{change.NodeOnline(id)} + c := 
[]change.ChangeSet{change.NodeOnline(node)} log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") @@ -497,7 +497,7 @@ func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { // announced are served to any nodes. routeChange := s.primaryRoutes.SetRoutes(id) - cs := []change.ChangeSet{change.NodeOffline(id), c} + cs := []change.ChangeSet{change.NodeOffline(node), c} // If we have a policy change or route change, return that as it's more comprehensive // Otherwise, return the NodeOffline change to ensure nodes are notified @@ -1583,10 +1583,16 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest var routeChange bool var hostinfoChanged bool var needsRouteApproval bool + var endpointChanged bool + var derpChanged bool // We need to ensure we update the node as it is in the NodeStore at // the time of the request. updatedNode, ok := s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { peerChange := currentNode.PeerChangeFromMapRequest(req) + + // Track what specifically changed + endpointChanged = peerChange.Endpoints != nil + derpChanged = peerChange.DERPRegion != 0 hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) // Get the correct NetInfo to use @@ -1736,6 +1742,23 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest return nodeRouteChange, nil } + // Determine the most specific change type based on what actually changed. + // This allows us to send lightweight patch updates instead of full map responses. + // Hostinfo changes require NodeAdded (full update) as they may affect many fields. + if hostinfoChanged { + return change.NodeAdded(id), nil + } + + // Return specific change types for endpoint and/or DERP updates. + // The batcher will query NodeStore for current state and include both in PeerChange if both changed. + // Prioritize endpoint changes as they're more common and important for connectivity. + if endpointChanged { + return change.EndpointUpdate(id), nil + } + if derpChanged { + return change.DERPUpdate(id), nil + } + return change.NodeAdded(id), nil } @@ -1768,6 +1791,16 @@ func routesChanged(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { return !slices.Equal(oldRoutes, newRoutes) } +func endpointsChanged(oldNode types.NodeView, newEndpoints []netip.AddrPort) bool { + var oldEndpoints []netip.AddrPort + if oldNode.Valid() { + oldEndpoints = oldNode.Endpoints().AsSlice() + } + + // Use the same comparison logic as PeerChangeFromMapRequest + return types.EndpointsChanged(oldEndpoints, newEndpoints) +} + func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { return peerChange.Key == nil && peerChange.DiscoKey == nil && diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 36cf8a4f..ef1d4d01 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -29,11 +29,13 @@ const ( ExtraRecords Change = 13 // Node changes. - NodeCameOnline Change = 21 - NodeWentOffline Change = 22 - NodeRemove Change = 23 - NodeKeyExpiry Change = 24 - NodeNewOrUpdate Change = 25 + NodeCameOnline Change = 21 + NodeWentOffline Change = 22 + NodeRemove Change = 23 + NodeKeyExpiry Change = 24 + NodeNewOrUpdate Change = 25 + NodeEndpoint Change = 26 + NodeDERP Change = 27 // User changes. 
UserNewOrUpdate Change = 51 @@ -174,17 +176,19 @@ func NodeRemoved(id types.NodeID) ChangeSet { } } -func NodeOnline(id types.NodeID) ChangeSet { +func NodeOnline(node types.NodeView) ChangeSet { return ChangeSet{ - Change: NodeCameOnline, - NodeID: id, + Change: NodeCameOnline, + NodeID: node.ID(), + IsSubnetRouter: node.IsSubnetRouter(), } } -func NodeOffline(id types.NodeID) ChangeSet { +func NodeOffline(node types.NodeView) ChangeSet { return ChangeSet{ - Change: NodeWentOffline, - NodeID: id, + Change: NodeWentOffline, + NodeID: node.ID(), + IsSubnetRouter: node.IsSubnetRouter(), } } @@ -196,6 +200,20 @@ func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet { } } +func EndpointUpdate(id types.NodeID) ChangeSet { + return ChangeSet{ + Change: NodeEndpoint, + NodeID: id, + } +} + +func DERPUpdate(id types.NodeID) ChangeSet { + return ChangeSet{ + Change: NodeDERP, + NodeID: id, + } +} + func UserAdded(id types.UserID) ChangeSet { return ChangeSet{ Change: UserNewOrUpdate, diff --git a/hscontrol/types/change/change_string.go b/hscontrol/types/change/change_string.go index dbf9d17e..fd6059d5 100644 --- a/hscontrol/types/change/change_string.go +++ b/hscontrol/types/change/change_string.go @@ -18,6 +18,8 @@ func _() { _ = x[NodeRemove-23] _ = x[NodeKeyExpiry-24] _ = x[NodeNewOrUpdate-25] + _ = x[NodeEndpoint-26] + _ = x[NodeDERP-27] _ = x[UserNewOrUpdate-51] _ = x[UserRemove-52] } @@ -26,13 +28,13 @@ const ( _Change_name_0 = "ChangeUnknown" _Change_name_1 = "Full" _Change_name_2 = "PolicyDERPExtraRecords" - _Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate" + _Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdateNodeEndpointNodeDERP" _Change_name_4 = "UserNewOrUpdateUserRemove" ) var ( _Change_index_2 = [...]uint8{0, 6, 10, 22} - _Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67} + _Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67, 79, 87} _Change_index_4 = [...]uint8{0, 15, 25} ) @@ -45,7 +47,7 @@ func (i Change) String() string { case 11 <= i && i <= 13: i -= 11 return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]] - case 21 <= i && i <= 25: + case 21 <= i && i <= 27: i -= 21 return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]] case 51 <= i && i <= 52: diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index c6429669..05eb8a35 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -535,8 +535,10 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC } } - // TODO(kradalby): Find a good way to compare updates - ret.Endpoints = req.Endpoints + // Compare endpoints using order-independent comparison + if EndpointsChanged(node.Endpoints, req.Endpoints) { + ret.Endpoints = req.Endpoints + } now := time.Now() ret.LastSeen = &now @@ -544,6 +546,32 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC return ret } +// EndpointsChanged compares two endpoint slices and returns true if they differ. +// The comparison is order-independent - endpoints are sorted before comparison. 
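+// Nil and empty slices compare as equal, and both inputs are cloned before sorting so the caller's slices are never reordered.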
+func EndpointsChanged(oldEndpoints, newEndpoints []netip.AddrPort) bool { + if len(oldEndpoints) != len(newEndpoints) { + return true + } + + if len(oldEndpoints) == 0 { + return false + } + + // Make copies to avoid modifying the original slices + oldCopy := slices.Clone(oldEndpoints) + newCopy := slices.Clone(newEndpoints) + + // Sort both slices to enable order-independent comparison + slices.SortFunc(oldCopy, func(a, b netip.AddrPort) int { + return a.Compare(b) + }) + slices.SortFunc(newCopy, func(a, b netip.AddrPort) int { + return a.Compare(b) + }) + + return !slices.Equal(oldCopy, newCopy) +} + func (node *Node) RegisterMethodToV1Enum() v1.RegisterMethod { switch node.RegisterMethod { case "authkey":