diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go new file mode 100644 index 00000000..9d6f1a09 --- /dev/null +++ b/hscontrol/state/maprequest.go @@ -0,0 +1,50 @@ +// Package state provides pure functions for processing MapRequest data. +// These functions are extracted from UpdateNodeFromMapRequest to improve +// testability and maintainability. + +package state + +import ( + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" +) + +// NetInfoFromMapRequest determines the correct NetInfo to use. +// Returns the NetInfo that should be used for this request. +func NetInfoFromMapRequest( + nodeID types.NodeID, + currentHostinfo *tailcfg.Hostinfo, + reqHostinfo *tailcfg.Hostinfo, +) *tailcfg.NetInfo { + // If request has NetInfo, use it + if reqHostinfo != nil && reqHostinfo.NetInfo != nil { + return reqHostinfo.NetInfo + } + + // Otherwise, use current NetInfo if available + if currentHostinfo != nil && currentHostinfo.NetInfo != nil { + log.Debug(). + Caller(). + Uint64("node.id", nodeID.Uint64()). + Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). + Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo + } + + // No NetInfo available anywhere - log for debugging + var hostname string + if reqHostinfo != nil { + hostname = reqHostinfo.Hostname + } else if currentHostinfo != nil { + hostname = currentHostinfo.Hostname + } + + log.Debug(). + Caller(). + Uint64("node.id", nodeID.Uint64()). + Str("node.hostname", hostname). + Msg("node sent update but has no NetInfo in request or database") + + return nil +} diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go new file mode 100644 index 00000000..dfb2abd0 --- /dev/null +++ b/hscontrol/state/maprequest_test.go @@ -0,0 +1,134 @@ +package state + +import ( + "net/netip" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func TestNetInfoFromMapRequest(t *testing.T) { + nodeID := types.NodeID(1) + + tests := []struct { + name string + currentHostinfo *tailcfg.Hostinfo + reqHostinfo *tailcfg.Hostinfo + expectNetInfo *tailcfg.NetInfo + }{ + { + name: "no current NetInfo - return nil", + currentHostinfo: nil, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + expectNetInfo: nil, + }, + { + name: "current has NetInfo, request has NetInfo - use request", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 1}, + }, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 2}, + }, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 2}, + }, + { + name: "current has NetInfo, request has no NetInfo - use current", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 3}, + }, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 3}, + }, + { + name: "current has NetInfo, no request Hostinfo - use current", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 4}, + }, + reqHostinfo: nil, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 4}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NetInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo) + + if tt.expectNetInfo == nil { + assert.Nil(t, result, "expected nil NetInfo") + } else { + require.NotNil(t, result, "expected non-nil NetInfo") + assert.Equal(t, tt.expectNetInfo.PreferredDERP, result.PreferredDERP, "DERP mismatch") + } + }) + } +} + +func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { + nodeID := types.NodeID(1) + + // This test reproduces the bug in registration flows where NetInfo was lost + // because we used the wrong hostinfo reference when calling NetInfoFromMapRequest + t.Run("registration_flow_bug_reproduction", func(t *testing.T) { + // Simulate existing node with NetInfo (before re-registration) + existingNodeHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 5}, + } + + // Simulate new registration request (no NetInfo) + newRegistrationHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + // NetInfo is nil - this is what comes from the registration request + } + + // Simulate what was happening in the bug: we passed the "current node being modified" + // hostinfo (which has no NetInfo) instead of the existing node's hostinfo + nodeBeingModifiedHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + // NetInfo is nil because this node is being modified/reset + } + + // BUG: Using the node being modified (no NetInfo) instead of existing node (has NetInfo) + buggyResult := NetInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo) + assert.Nil(t, buggyResult, "Bug: Should return nil when using wrong hostinfo reference") + + // CORRECT: Using the existing node's hostinfo (has NetInfo) + correctResult := NetInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo) + assert.NotNil(t, correctResult, "Fix: Should preserve NetInfo when using correct hostinfo reference") + assert.Equal(t, 5, correctResult.PreferredDERP, "Should preserve the DERP region from existing node") + }) +} + +// Simple helper function for tests +func createTestNodeSimple(id types.NodeID) *types.Node { + user := types.User{ + Name: "test-user", + } + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + node := &types.Node{ + ID: id, + Hostname: "test-node", + UserID: uint(id), + User: user, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + IPv4: &netip.Addr{}, + IPv6: &netip.Addr{}, + } + + return node +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d74814b0..b445f4e1 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -852,10 +852,25 @@ func (s *State) BackfillNodeIPs() ([]string, error) { } for _, node := range nodes { - // Preserve online status when refreshing from database + // Preserve online status and NetInfo when refreshing from database existingNode, exists := s.nodeStore.GetNode(node.ID) if exists && existingNode.Valid() { node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). + + // Preserve NetInfo from existing node to prevent loss during backfill + netInfo := NetInfoFromMapRequest(node.ID, existingNode.AsStruct().Hostinfo, node.Hostinfo) + if netInfo != nil { + if node.Hostinfo != nil { + hostinfoCopy := *node.Hostinfo + hostinfoCopy.NetInfo = netInfo + node.Hostinfo = &hostinfoCopy + } else { + node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} + } + } } // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. // We should avoid PutNode here. @@ -1166,7 +1181,24 @@ func (s *State) HandleNodeFromAuthPath( node.NodeKey = nodeToRegister.NodeKey node.DiscoKey = nodeToRegister.DiscoKey node.Hostname = nodeToRegister.Hostname - node.Hostinfo = nodeToRegister.Hostinfo + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). + + // Preserve NetInfo from existing node when re-registering + netInfo := NetInfoFromMapRequest(existingMachineNode.ID, existingMachineNode.Hostinfo, nodeToRegister.Hostinfo) + if netInfo != nil { + if nodeToRegister.Hostinfo != nil { + hostinfoCopy := *nodeToRegister.Hostinfo + hostinfoCopy.NetInfo = netInfo + node.Hostinfo = &hostinfoCopy + } else { + node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} + } + } else { + node.Hostinfo = nodeToRegister.Hostinfo + } + node.Endpoints = nodeToRegister.Endpoints node.RegisterMethod = nodeToRegister.RegisterMethod if expiry != nil { @@ -1333,7 +1365,24 @@ func (s *State) HandleNodeFromPreAuthKey( s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { node.NodeKey = nodeToRegister.NodeKey node.Hostname = nodeToRegister.Hostname - node.Hostinfo = nodeToRegister.Hostinfo + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). + + // Preserve NetInfo from existing node when re-registering + netInfo := NetInfoFromMapRequest(existingNode.ID, existingNode.Hostinfo, nodeToRegister.Hostinfo) + if netInfo != nil { + if nodeToRegister.Hostinfo != nil { + hostinfoCopy := *nodeToRegister.Hostinfo + hostinfoCopy.NetInfo = netInfo + node.Hostinfo = &hostinfoCopy + } else { + node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} + } + } else { + node.Hostinfo = nodeToRegister.Hostinfo + } + node.Endpoints = nodeToRegister.Endpoints node.RegisterMethod = nodeToRegister.RegisterMethod node.ForcedTags = nodeToRegister.ForcedTags @@ -1527,6 +1576,12 @@ func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { // - node.ApplyPeerChange // - logTracePeerChange in poll.go. func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) { + log.Trace(). + Caller(). + Uint64("node.id", id.Uint64()). + Interface("request", req). + Msg("Processing MapRequest for node") + var routeChange bool var hostinfoChanged bool var needsRouteApproval bool @@ -1536,6 +1591,27 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest peerChange := currentNode.PeerChangeFromMapRequest(req) hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) + // Get the correct NetInfo to use + netInfo := NetInfoFromMapRequest(id, currentNode.Hostinfo, req.Hostinfo) + + // Apply NetInfo to request Hostinfo + if req.Hostinfo != nil { + if netInfo != nil { + // Create a copy to avoid modifying the original + hostinfoCopy := *req.Hostinfo + hostinfoCopy.NetInfo = netInfo + req.Hostinfo = &hostinfoCopy + } + } else if netInfo != nil { + // Create minimal Hostinfo with NetInfo + req.Hostinfo = &tailcfg.Hostinfo{ + NetInfo: netInfo, + } + } + + // Re-check hostinfoChanged after potential NetInfo preservation + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) + // If there is no changes and nothing to save, // return early. if peerChangeEmpty(peerChange) && !hostinfoChanged { @@ -1544,31 +1620,43 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest // Calculate route approval before NodeStore update to avoid calling View() inside callback var autoApprovedRoutes []netip.Prefix - hasNewRoutes := req.Hostinfo != nil && len(req.Hostinfo.RoutableIPs) > 0 + var hasNewRoutes bool + if hi := req.Hostinfo; hi != nil { + hasNewRoutes = len(hi.RoutableIPs) > 0 + } needsRouteApproval = hostinfoChanged && (routesChanged(currentNode.View(), req.Hostinfo) || (hasNewRoutes && len(currentNode.ApprovedRoutes) == 0)) if needsRouteApproval { - autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( - s.polMan, - currentNode.View(), - // We need to preserve currently approved routes to ensure - // routes outside of the policy approver is persisted. - currentNode.ApprovedRoutes, - // However, the node has updated its routable IPs, so we - // need to approve them using that as a context. - req.Hostinfo.RoutableIPs, - ) + // Extract announced routes from request + var announcedRoutes []netip.Prefix + if req.Hostinfo != nil { + announcedRoutes = req.Hostinfo.RoutableIPs + } + + // Apply policy-based auto-approval if routes are announced + if len(announcedRoutes) > 0 { + autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( + s.polMan, + currentNode.View(), + currentNode.ApprovedRoutes, + announcedRoutes, + ) + } } // Log when routes change but approval doesn't - if hostinfoChanged && req.Hostinfo != nil && routesChanged(currentNode.View(), req.Hostinfo) && !routeChange { - log.Debug(). - Caller(). - Uint64("node.id", id.Uint64()). - Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). - Strs("newAnnouncedRoutes", util.PrefixesToString(req.Hostinfo.RoutableIPs)). - Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). - Bool("routeChange", routeChange). - Msg("announced routes changed but approved routes did not") + if hostinfoChanged && !routeChange { + if hi := req.Hostinfo; hi != nil { + if routesChanged(currentNode.View(), hi) { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). + Strs("newAnnouncedRoutes", util.PrefixesToString(hi.RoutableIPs)). + Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Bool("routeChange", routeChange). + Msg("announced routes changed but approved routes did not") + } + } } currentNode.ApplyPeerChange(&peerChange) @@ -1581,27 +1669,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 // TODO(kradalby): evaluate if we need better comparing of hostinfo // before we take the changes. - // Preserve NetInfo only if the existing node actually has valid NetInfo - // This prevents copying nil NetInfo which would lose DERP relay assignments - if req.Hostinfo != nil && req.Hostinfo.NetInfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { - log.Debug(). - Caller(). - Uint64("node.id", id.Uint64()). - Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). - Msg("preserving NetInfo from previous Hostinfo in MapRequest") - req.Hostinfo.NetInfo = currentNode.Hostinfo.NetInfo - } else if req.Hostinfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { - // When MapRequest has no Hostinfo but we have existing NetInfo, create a minimal - // Hostinfo to preserve the NetInfo to maintain DERP connectivity - log.Debug(). - Caller(). - Uint64("node.id", id.Uint64()). - Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). - Msg("creating minimal Hostinfo to preserve NetInfo in MapRequest") - req.Hostinfo = &tailcfg.Hostinfo{ - NetInfo: currentNode.Hostinfo.NetInfo, - } - } + // NetInfo preservation has already been handled above before early return check currentNode.Hostinfo = req.Hostinfo currentNode.ApplyHostnameFromHostInfo(req.Hostinfo) @@ -1630,7 +1698,12 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest // 2. The announced routes changed (even if approved routes stayed the same) // This is because SubnetRoutes is the intersection of announced AND approved routes. needsRouteUpdate := false - routesChangedButNotApproved := hostinfoChanged && req.Hostinfo != nil && needsRouteApproval && !routeChange + var routesChangedButNotApproved bool + if hostinfoChanged && needsRouteApproval && !routeChange { + if hi := req.Hostinfo; hi != nil { + routesChangedButNotApproved = true + } + } if routeChange { needsRouteUpdate = true log.Debug(). diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 26c6becf..90034434 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -66,6 +66,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second) + // Validate that all nodes have NetInfo and DERP servers before logout + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 1*time.Minute) + // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) @@ -149,6 +152,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second) + // Validate that all nodes have NetInfo and DERP servers after reconnection + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 1*time.Minute) + err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) @@ -191,6 +197,60 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } +// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database +// and a valid DERP server based on the NetInfo. This function follows the pattern of +// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + t.Logf("requireAllClientsNetInfoAndDERP: Starting validation at %s - %s", startTime.Format(TimestampFormat), message) + + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate node counts first + expectedCount := len(expectedNodes) + assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") + + // Check each expected node + for _, nodeID := range expectedNodes { + node, exists := nodeStore[nodeID] + assert.True(c, exists, "Node %d not found in nodestore", nodeID) + if !exists { + continue + } + + // Validate that the node has Hostinfo + assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo", nodeID, node.Hostname) + if node.Hostinfo == nil { + continue + } + + // Validate that the node has NetInfo + assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { + continue + } + + // Validate that the node has a valid DERP server (PreferredDERP should be > 0) + preferredDERP := node.Hostinfo.NetInfo.PreferredDERP + assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0), got %d", nodeID, node.Hostname, preferredDERP) + + t.Logf("Node %d (%s) has valid NetInfo with DERP server %d", nodeID, node.Hostname, preferredDERP) + } + }, timeout, 2*time.Second, message) + + endTime := time.Now() + duration := endTime.Sub(startTime) + t.Logf("requireAllClientsNetInfoAndDERP: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message) +} + func assertLastSeenSet(t *testing.T, node *v1.Node) { assert.NotNil(t, node) assert.NotNil(t, node.GetLastSeen())