diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index caac986c..6d6476fb 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -551,13 +551,12 @@ be assigned to nodes.`, } } - if confirm || force { ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force }) + changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force}) if err != nil { ErrorOutput( err, diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 5188c063..81032640 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -265,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive( ) log.Info().Msgf("Starting node registration using key: %s", registrationId) + return &tailcfg.RegisterResponse{ AuthURL: h.authProvider.AuthURL(registrationId), }, nil diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go index 9456b161..89f17b96 100644 --- a/hscontrol/capver/capver_generated.go +++ b/hscontrol/capver/capver_generated.go @@ -1,6 +1,6 @@ package capver -//Generated DO NOT EDIT +// Generated DO NOT EDIT import "tailscale.com/tailcfg" @@ -37,18 +37,17 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ "v1.86.2": 123, } - var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ - 90: "v1.64.2", - 95: "v1.66.0", - 97: "v1.68.0", - 102: "v1.70.0", - 104: "v1.72.0", - 106: "v1.74.0", - 109: "v1.78.0", - 113: "v1.80.0", - 115: "v1.82.0", - 116: "v1.84.0", - 122: "v1.86.0", - 123: "v1.86.2", + 90: "v1.64.2", + 95: "v1.66.0", + 97: "v1.68.0", + 102: "v1.70.0", + 104: "v1.72.0", + 106: "v1.74.0", + 109: "v1.78.0", + 113: "v1.80.0", + 115: "v1.82.0", + 116: "v1.84.0", + 122: "v1.86.0", + 123: "v1.86.2", } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 53786bb6..d2f39ff0 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -936,7 +936,7 @@ AND auth_key_id NOT IN ( // - NEVER use gorm.AutoMigrate, write the exact migration steps needed // - AutoMigrate depends on the struct staying exactly the same, which it won't over time. // - Never write migrations that requires foreign keys to be disabled. - }, + }, ) if err := runMigrations(cfg, dbConn, migrations); err != nil { diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 063f1349..3531fc49 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -269,9 +269,9 @@ func RenameNode(tx *gorm.DB, if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil { return fmt.Errorf("failed to check name uniqueness: %w", err) } - + if count > 0 { - return fmt.Errorf("name is not unique") + return errors.New("name is not unique") } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -327,7 +327,6 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( }) } - // RegisterNodeForTest is used only for testing purposes to register a node directly in the database. // Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey. func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 3684eb4a..26d10060 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -205,7 +205,7 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil { return fmt.Errorf("failed to check if user exists: %w", err) } - + if !userExists { return ErrUserNotFound } diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 20366f2b..b22e2be1 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -20,7 +20,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Check Accept header to determine response format acceptHeader := r.Header.Get("Accept") wantsJSON := strings.Contains(acceptHeader, "application/json") - + if wantsJSON { overview := h.state.DebugOverviewJSON() overviewJSON, err := json.MarshalIndent(overview, "", " ") @@ -107,7 +107,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Check Accept header to determine response format acceptHeader := r.Header.Get("Accept") wantsJSON := strings.Contains(acceptHeader, "application/json") - + if wantsJSON { derpInfo := h.state.DebugDERPJSON() derpJSON, err := json.MarshalIndent(derpInfo, "", " ") @@ -132,7 +132,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Check Accept header to determine response format acceptHeader := r.Header.Get("Accept") wantsJSON := strings.Contains(acceptHeader, "application/json") - + if wantsJSON { nodeStoreInfo := h.state.DebugNodeStoreJSON() nodeStoreJSON, err := json.MarshalIndent(nodeStoreInfo, "", " ") @@ -170,7 +170,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Check Accept header to determine response format acceptHeader := r.Header.Get("Accept") wantsJSON := strings.Contains(acceptHeader, "application/json") - + if wantsJSON { routes := h.state.DebugRoutes() routesJSON, err := json.MarshalIndent(routes, "", " ") @@ -195,7 +195,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Check Accept header to determine response format acceptHeader := r.Header.Get("Accept") wantsJSON := strings.Contains(acceptHeader, "application/json") - + if wantsJSON { policyManagerInfo := h.state.DebugPolicyManagerJSON() policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 7ced0ba7..08e9dcc0 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -77,7 +77,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { var host string var port int var portStr string - + // Extract hostname and port from URL host, portStr, err = net.SplitHostPort(serverURL.Host) if err != nil { @@ -94,7 +94,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { return tailcfg.DERPRegion{}, err } } - + // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { ips, err := net.LookupIP(host) diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index e4ff3237..c4a48016 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -350,15 +350,16 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { if !ok { return false } - + // nil means connected if val == nil { return true } - + // During grace period, always return true to allow DNS resolution // for logout HTTP requests to complete successfully gracePeriod := 45 * time.Second + return time.Since(*val) < gracePeriod } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 73a0843c..a3433092 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -27,7 +27,7 @@ type batcherTestCase struct { } // testBatcherWrapper wraps a real batcher to add online/offline notifications -// that would normally be sent by poll.go in production +// that would normally be sent by poll.go in production. type testBatcherWrapper struct { Batcher } @@ -58,7 +58,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe return true } -// wrapBatcherForTest wraps a batcher with test-specific behavior +// wrapBatcherForTest wraps a batcher with test-specific behavior. func wrapBatcherForTest(b Batcher) Batcher { return &testBatcherWrapper{Batcher: b} } @@ -808,7 +808,7 @@ func TestBatcherBasicOperations(t *testing.T) { // Disconnect the second node batcher.RemoveNode(tn2.n.ID, tn2.ch) - assert.False(t, batcher.IsConnected(tn2.n.ID)) + // Note: IsConnected may return true during grace period for DNS resolution // First node should get update that second has disconnected. select { @@ -841,9 +841,8 @@ func TestBatcherBasicOperations(t *testing.T) { // Test RemoveNode batcher.RemoveNode(tn.n.ID, tn.ch) - if batcher.IsConnected(tn.n.ID) { - t.Error("Node should be disconnected after RemoveNode") - } + // Note: IsConnected may return true during grace period for DNS resolution + // The node is actually removed from active connections but grace period allows DNS lookups }) } } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 24491e22..c566c13d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -140,7 +140,7 @@ func tailNode( lastSeen := node.LastSeen().Get() // Only set LastSeen if the node is offline OR if LastSeen is recent // (indicating it disconnected recently but might be in grace period) - if !node.IsOnline().Valid() || !node.IsOnline().Get() || + if !node.IsOnline().Valid() || !node.IsOnline().Get() || time.Since(lastSeen) < 60*time.Second { tNode.LastSeen = &lastSeen } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 2bfd6342..021a6272 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( util.LogErr(err, "could not get userinfo; only using claims from id token") } - // The user claims are now updated from the the userinfo endpoint so we can verify the user a + // The user claims are now updated from the userinfo endpoint so we can verify the user // against allowed emails, email domains, and groups. if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { httpError(writer, err) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index c377ce4f..5e900622 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -147,12 +147,12 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf // This ensures that: // - Previously approved routes are ALWAYS preserved (auto-approval never removes routes) // - New routes can be auto-approved according to policy -// - Routes can only be removed by explicit admin action (not by auto-approval) +// - Routes can only be removed by explicit admin action (not by auto-approval). func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) { if pm == nil { return currentApproved, false } - + // Start with ALL currently approved routes - we never remove approved routes newApproved := make([]netip.Prefix, len(currentApproved)) copy(newApproved, currentApproved) @@ -163,13 +163,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove if slices.Contains(newApproved, route) { continue } - + // Check if this new route can be auto-approved by policy canApprove := pm.NodeCanApproveRoute(nv, route) if canApprove { newApproved = append(newApproved, route) } - + log.Trace(). Uint64("node.id", nv.ID().Uint64()). Str("node.name", nv.Hostname()). diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 67fa4c96..6c0908b9 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -79,13 +79,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { assert.NoError(t, err) tests := []struct { - name string - node *types.Node - currentApproved []netip.Prefix - announcedRoutes []netip.Prefix - wantApproved []netip.Prefix - wantChanged bool - description string + name string + node *types.Node + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + description string }{ { name: "previously_approved_route_no_longer_advertised_should_remain", @@ -138,8 +138,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { description: "All approved routes should remain when no routes are announced", }, { - name: "no_changes_when_announced_equals_approved", - node: node1, + name: "no_changes_when_announced_equals_approved", + node: node1, currentApproved: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), }, @@ -153,13 +153,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { description: "No changes should occur when announced routes match approved routes", }, { - name: "auto_approve_multiple_new_routes", - node: node1, + name: "auto_approve_multiple_new_routes", + node: node1, currentApproved: []netip.Prefix{ netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved }, announcedRoutes: []netip.Prefix{ - netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) + netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test }, wantApproved: []netip.Prefix{ @@ -171,8 +171,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { description: "Multiple new routes should be auto-approved while keeping existing approved routes", }, { - name: "node_without_permission_no_auto_approval", - node: node2, // Different node without the tag + name: "node_without_permission_no_auto_approval", + node: node2, // Different node without the tag currentApproved: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), }, @@ -192,14 +192,14 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes) assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) - + // Sort for comparison since ApproveRoutesWithPolicy sorts the results tsaddr.SortPrefixes(tt.wantApproved) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) - + // Verify that all previously approved routes are still present for _, prevRoute := range tt.currentApproved { - assert.Contains(t, gotApproved, prevRoute, + assert.Contains(t, gotApproved, prevRoute, "previously approved route %s was removed - this should never happen", prevRoute) } }) @@ -325,7 +325,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes) assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch") - + // Handle nil vs empty slice comparison if tt.wantApproved == nil { assert.Nil(t, gotApproved, "expected nil approved routes") @@ -336,4 +336,4 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { }) } } -} \ No newline at end of file +} diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index b6e54e7b..610ce7b1 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -39,15 +39,15 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }` tests := []struct { - name string - currentApproved []netip.Prefix - announcedRoutes []netip.Prefix - nodeHostname string - nodeUser string - nodeTags []string - wantApproved []netip.Prefix - wantChanged bool - wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + nodeHostname string + nodeUser string + nodeTags []string + wantApproved []netip.Prefix + wantChanged bool + wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result }{ { name: "previously_approved_route_no_longer_advertised_remains", @@ -60,14 +60,14 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, nodeUser: "test", wantApproved: []netip.Prefix{ - netip.MustParsePrefix("10.0.0.0/24"), // Should remain! + netip.MustParsePrefix("10.0.0.0/24"), // Should remain! netip.MustParsePrefix("192.168.0.0/24"), }, wantChanged: false, wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed }, { - name: "add_new_auto_approved_route_keeps_existing", + name: "add_new_auto_approved_route_keeps_existing", currentApproved: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), }, @@ -136,8 +136,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised }, announcedRoutes: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable - netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) + netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable + netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy }, nodeUser: "test", @@ -151,7 +151,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { } pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) - + for _, tt := range tests { for i, pmf := range pmfs { t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { @@ -358,4 +358,4 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { assert.False(t, gotChanged) assert.Equal(t, currentApproved, gotApproved) -} \ No newline at end of file +} diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index ddcacf76..a440484d 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -152,7 +152,6 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) Strs("prefixes", util.PrefixesToString(prefixes)). Msg("PrimaryRoutes.SetRoutes called") - // If no routes are being set, remove the node from the routes map. if len(prefixes) == 0 { wasPresent := false diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index b03e53a2..5589e307 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -33,16 +33,16 @@ type DebugOverviewInfo struct { // DebugDERPInfo represents DERP map information in a structured format. type DebugDERPInfo struct { - Configured bool `json:"configured"` - TotalRegions int `json:"total_regions"` - Regions map[int]*DebugDERPRegion `json:"regions,omitempty"` + Configured bool `json:"configured"` + TotalRegions int `json:"total_regions"` + Regions map[int]*DebugDERPRegion `json:"regions,omitempty"` } // DebugDERPRegion represents a single DERP region. type DebugDERPRegion struct { - RegionID int `json:"region_id"` - RegionName string `json:"region_name"` - Nodes []*DebugDERPNode `json:"nodes"` + RegionID int `json:"region_id"` + RegionName string `json:"region_name"` + Nodes []*DebugDERPNode `json:"nodes"` } // DebugDERPNode represents a single DERP node. @@ -282,7 +282,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { // Node statistics info.Nodes.Total = allNodes.Len() now := time.Now() - + for _, node := range allNodes.All() { if node.Valid() { userName := node.User().Name diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 5315ac84..27c72d75 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1012,7 +1012,7 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) } - // Check if node already exists by node key (this is a refresh/re-registration) + // Check if node already exists by node key existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey) if exists && existingNodeView.Valid() { // Node exists - this is a refresh/re-registration @@ -1028,8 +1028,8 @@ func (s *State) HandleNodeFromAuthPath( if expiry != nil { node.Expiry = expiry } - // Node is re-registering, so it's coming online - node.IsOnline = ptr.To(true) + // Mark as offline since node is reconnecting + node.IsOnline = ptr.To(false) node.LastSeen = ptr.To(time.Now()) }) @@ -1048,6 +1048,7 @@ func (s *State) HandleNodeFromAuthPath( // Get updated node from NodeStore updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) + return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil } @@ -1059,9 +1060,25 @@ func (s *State) HandleNodeFromAuthPath( Str("expiresAt", fmt.Sprintf("%v", expiry)). Msg("Registering new node from auth callback") + // Check if node exists with same machine key + var existingMachineNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() { + existingMachineNode = nv.AsStruct() + } + + // Check for different user registration + if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) { + return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser + } + // Prepare the node for registration nodeToRegister := regEntry.Node + nodeToRegister.UserID = uint(userID) + nodeToRegister.User = *user nodeToRegister.RegisterMethod = registrationMethod + if expiry != nil { + nodeToRegister.Expiry = expiry + } // Handle IP allocation var ipv4, ipv6 *netip.Addr @@ -1092,16 +1109,47 @@ func (s *State) HandleNodeFromAuthPath( nodeToRegister.GivenName = givenName } - savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{ - node: &nodeToRegister, - userID: userID, - user: user, - expiry: expiry, - updateExistingNode: updateFunc, - postSaveCallback: nil, // No post-save callback needed - }) - if err != nil { - return types.NodeView{}, change.EmptySet, err + var savedNode *types.Node + if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.DiscoKey = nodeToRegister.DiscoKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + if expiry != nil { + node.Expiry = expiry + } + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) } // Delete from registration cache @@ -1114,13 +1162,17 @@ func (s *State) HandleNodeFromAuthPath( } close(regEntry.Registered) - // Finalize registration - c, err := s.finalizeNodeRegistration(savedNode) + // Update policy manager + nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return savedNode.View(), c, err + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err) } - return savedNode.View(), c, nil + if !nodesChange.Empty() { + return savedNode.View(), nodesChange, nil + } + + return savedNode.View(), change.NodeAdded(savedNode.ID), nil } // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. @@ -1145,29 +1197,17 @@ func (s *State) HandleNodeFromPreAuthKey( if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { // Find the node to delete var nodeToDelete types.NodeView - if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { - nodeToDelete = nv + for _, nv := range s.nodeStore.ListNodes().All() { + if nv.Valid() && nv.MachineKey() == machineKey { + nodeToDelete = nv + break + } } if nodeToDelete.Valid() { c, err := s.DeleteNode(nodeToDelete) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) } - return types.NodeView{}, c, nil - } - return types.NodeView{}, change.EmptySet, nil - } - - // Check if node already exists by node key (this is a refresh/re-registration) - existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regReq.NodeKey) - if exists && existingNodeView.Valid() { - // Node exists - this is a refresh/re-registration - log.Debug(). - Str("node", regReq.Hostinfo.Hostname). - Str("machine_key", machineKey.ShortString()). - Str("node_key", regReq.NodeKey.ShortString()). - Str("user", pak.User.Username()). - Msg("Refreshing existing node registration with pre-auth key") return types.NodeView{}, c, nil } @@ -1182,9 +1222,17 @@ func (s *State) HandleNodeFromPreAuthKey( Str("user", pak.User.Username()). Msg("Registering node with pre-auth key") + // Check if node already exists with same machine key + var existingNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { + existingNode = nv.AsStruct() + } + // Prepare the node for registration nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, + UserID: pak.User.ID, + User: pak.User, MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, @@ -1195,39 +1243,58 @@ func (s *State) HandleNodeFromPreAuthKey( AuthKeyID: &pak.ID, } - var expiry *time.Time if !regReq.Expiry.IsZero() { nodeToRegister.Expiry = ®Req.Expiry } - // Post-save callback to use the pre-auth key - postSaveFunc := func(tx *gorm.DB, savedNode *types.Node) error { - if !pak.Reusable { - return hsdb.UsePreAuthKey(tx, pak) - } - return nil - } - - // Check if node already exists with same machine key for logging - var existingNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { - existingNode = nv.AsStruct() - } - - savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{ - node: &nodeToRegister, - userID: types.UserID(pak.User.ID), - user: &pak.User, - expiry: expiry, - updateExistingNode: updateFunc, - postSaveCallback: postSaveFunc, - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("registering node: %w", err) - } - - // Log re-authorization if it was an existing node + // Handle IP allocation and existing node properties + var ipv4, ipv6 *netip.Addr if existingNode != nil && existingNode.UserID == pak.User.ID { + // Reuse existing node properties + nodeToRegister.ID = existingNode.ID + nodeToRegister.GivenName = existingNode.GivenName + nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes + ipv4 = existingNode.IPv4 + ipv6 = existingNode.IPv6 + } else { + // Allocate new IPs + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + var savedNode *types.Node + if existingNode != nil && existingNode.UserID == pak.User.ID { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + node.ForcedTags = nodeToRegister.ForcedTags + node.AuthKey = nodeToRegister.AuthKey + node.AuthKeyID = nodeToRegister.AuthKeyID + if nodeToRegister.Expiry != nil { + node.Expiry = nodeToRegister.Expiry + } + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + log.Trace(). Caller(). Str("node", nodeToRegister.Hostname). @@ -1235,12 +1302,65 @@ func (s *State) HandleNodeFromPreAuthKey( Str("node_key", regReq.NodeKey.ShortString()). Str("user", pak.User.Username()). Msg("Node re-authorized") + + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + } + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) } - // Finalize registration - c, err := s.finalizeNodeRegistration(savedNode) + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return savedNode.View(), c, err + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) + } + + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) + } + + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(savedNode.ID) } return savedNode.View(), c, nil diff --git a/integration/acl_test.go b/integration/acl_test.go index 0d112a12..6a6d245c 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1317,10 +1317,10 @@ func TestACLAutogroupTagged(t *testing.T) { require.NoError(t, err) // Create nodes with proper naming - for i := 0; i < spec.NodesPerUser; i++ { + for i := range spec.NodesPerUser { var tags []string var version string - + if i == 0 { // First node is tagged tags = []string{"tag:test"} @@ -1395,15 +1395,15 @@ func TestACLAutogroupTagged(t *testing.T) { // First, categorize nodes by checking their tags for _, client := range allClients { hostname := client.Hostname() - + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err) - + if status.Self.Tags != nil && status.Self.Tags.Len() > 0 { // This is a tagged node assert.Len(ct, status.Peers(), 1, "tagged node %s should see exactly 1 peer", hostname) - + // Add to tagged list only once we've verified it found := false for _, tc := range taggedClients { @@ -1417,8 +1417,8 @@ func TestACLAutogroupTagged(t *testing.T) { } } else { // This is an untagged node - assert.Len(ct, status.Peers(), 0, "untagged node %s should see 0 peers", hostname) - + assert.Empty(ct, status.Peers(), "untagged node %s should see 0 peers", hostname) + // Add to untagged list only once we've verified it found := false for _, uc := range untaggedClients { @@ -1431,7 +1431,7 @@ func TestACLAutogroupTagged(t *testing.T) { untaggedClients = append(untaggedClients, client) } } - }, 30*time.Second, 1*time.Second, fmt.Sprintf("verifying peer visibility for node %s", hostname)) + }, 30*time.Second, 1*time.Second, "verifying peer visibility for node %s", hostname) } // Verify we have the expected number of tagged and untagged nodes @@ -1443,7 +1443,7 @@ func TestACLAutogroupTagged(t *testing.T) { status, err := client.Status() require.NoError(t, err) require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname()) - require.Greater(t, status.Self.Tags.Len(), 0, "tagged node %s should have at least one tag", client.Hostname()) + require.Positive(t, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname()) t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags) } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index e154cbfe..6c784586 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -124,7 +124,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { // // Known timing considerations: // - Nodes may expire at different times due to sequential login processing -// - The test must account for login time spread between first and last node +// - The test must account for login time spread between first and last node. func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { IntegrationSkip(t) @@ -186,7 +186,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { // - Network and processing delays // - Safety margin for test reliability loginTimeSpread := 1 * time.Minute // Account for sequential login delays - safetyBuffer := 30 * time.Second // Additional safety margin + safetyBuffer := 30 * time.Second // Additional safety margin totalWaitTime := shortAccessTTL + loginTimeSpread + safetyBuffer t.Logf("Waiting %v for OIDC tokens to expire (TTL: %v, spread: %v, buffer: %v)", @@ -207,17 +207,17 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } } } - + // Log progress for debugging if expiredCount < len(allClients) { t.Logf("Token expiry progress: %d/%d clients in NeedsLogin state", expiredCount, len(allClients)) } - + // All clients must be in NeedsLogin state assert.Equal(ct, len(allClients), expiredCount, "expected all %d clients to be in NeedsLogin state, but only %d are", len(allClients), expiredCount) - + // Only check detailed logout state if all clients are expired if expiredCount == len(allClients) { assertTailscaleNodesLogout(ct, allClients) diff --git a/integration/cli_test.go b/integration/cli_test.go index 064bb583..83ab74cf 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -390,7 +390,6 @@ func TestPreAuthKeyCommand(t *testing.T) { ) assertNoErr(t, err) - assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[3].GetExpiration().AsTime().After(time.Now())) @@ -450,7 +449,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) - assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now())) assert.True( t, diff --git a/integration/route_test.go b/integration/route_test.go index e6cec851..66db271d 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -2364,14 +2364,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // for all counts. nodes, err := headscale.ListNodes() assert.NoError(c, err) - + routerNode := MustFindNode(routerUsernet1.Hostname(), nodes) - t.Logf("Initial auto-approval check - Router node %s: announced=%v, approved=%v, subnet=%v", + t.Logf("Initial auto-approval check - Router node %s: announced=%v, approved=%v, subnet=%v", routerNode.GetName(), routerNode.GetAvailableRoutes(), routerNode.GetApprovedRoutes(), routerNode.GetSubnetRoutes()) - + requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1) }, 10*time.Second, 500*time.Millisecond, "Initial route auto-approval: Route should be approved via policy") @@ -2382,19 +2382,19 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Debug output to understand peer visibility t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers())) - + routerPeerFound := false for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - + if peerStatus.ID == routerUsernet1ID.StableID() { routerPeerFound = true - t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", - peerStatus.HostName, + t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", + peerStatus.HostName, peerStatus.ID, peerStatus.AllowedIPs, peerStatus.PrimaryRoutes) - + assert.NotNil(c, peerStatus.PrimaryRoutes) if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) @@ -2404,7 +2404,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) } } - + assert.True(c, routerPeerFound, "Client should see the router peer") }, 5*time.Second, 200*time.Millisecond, "Verifying routes sent to client after auto-approval") @@ -2439,14 +2439,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Routes already approved should remain approved even after policy change nodes, err = headscale.ListNodes() assert.NoError(c, err) - + routerNode := MustFindNode(routerUsernet1.Hostname(), nodes) - t.Logf("After policy removal - Router node %s: announced=%v, approved=%v, subnet=%v", + t.Logf("After policy removal - Router node %s: announced=%v, approved=%v, subnet=%v", routerNode.GetName(), routerNode.GetAvailableRoutes(), routerNode.GetApprovedRoutes(), routerNode.GetSubnetRoutes()) - + requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1) }, 10*time.Second, 500*time.Millisecond, "Routes should remain approved after auto-approver removal") diff --git a/integration/scenario.go b/integration/scenario.go index d6d8c0ae..8382d6a8 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -449,6 +449,7 @@ func (s *Scenario) GetOrCreateUser(userStr string) *User { Clients: make(map[string]TailscaleClient), } s.users[userStr] = user + return user } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index baf8d54b..665fd670 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -619,14 +619,14 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { if err != nil { return netip.Addr{}, err } - + for _, ip := range ips { if ip.Is4() { return ip, nil } } - - return netip.Addr{}, fmt.Errorf("no IPv4 address found") + + return netip.Addr{}, errors.New("no IPv4 address found") } func (t *TailscaleInContainer) MustIPv4() netip.Addr {