From 4d66d1f8d3da84e144e2663132b286d33e1d924d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 6 Aug 2025 14:56:05 +0200 Subject: [PATCH] more reuse Signed-off-by: Kristoffer Dalby --- hscontrol/db/preauth_keys.go | 4 +- hscontrol/state/state.go | 517 +++++++++++++++++++++-------------- integration/cli_test.go | 8 +- 3 files changed, 319 insertions(+), 210 deletions(-) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 2e60de2e..a36c1f13 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "slices" "strings" "time" @@ -47,8 +48,9 @@ func CreatePreAuthKey( return nil, err } - // Remove duplicates + // Remove duplicates and sort for consistency aclTags = set.SetOf(aclTags).Slice() + slices.Sort(aclTags) // TODO(kradalby): factor out and create a reusable tag validation, // check if there is one in Tailscale's lib. diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d7e037c9..b2cbfc94 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -994,7 +994,7 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) } - // Check if node already exists by node key + // Check if node already exists by node key (this is a refresh/re-registration) existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey) if exists && existingNodeView.Valid() { // Node exists - this is a refresh/re-registration @@ -1010,8 +1010,8 @@ func (s *State) HandleNodeFromAuthPath( if expiry != nil { node.Expiry = expiry } - // Mark as offline since node is reconnecting - node.IsOnline = ptr.To(false) + // Node is re-registering, so it's coming online + node.IsOnline = ptr.To(true) node.LastSeen = ptr.To(time.Now()) }) @@ -1041,96 +1041,30 @@ func (s *State) HandleNodeFromAuthPath( Str("expiresAt", fmt.Sprintf("%v", expiry)). Msg("Registering new node from auth callback") - // Check if node exists with same machine key - var existingMachineNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() { - existingMachineNode = nv.AsStruct() - } - - // Check for different user registration - if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) { - return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser - } - // Prepare the node for registration nodeToRegister := regEntry.Node - nodeToRegister.UserID = uint(userID) - nodeToRegister.User = *user nodeToRegister.RegisterMethod = registrationMethod - if expiry != nil { - nodeToRegister.Expiry = expiry + + // Custom update function for existing nodes + updateFunc := func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.DiscoKey = nodeToRegister.DiscoKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod } - // Handle IP allocation - var ipv4, ipv6 *netip.Addr - if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { - // Reuse existing IPs and properties - nodeToRegister.ID = existingMachineNode.ID - nodeToRegister.GivenName = existingMachineNode.GivenName - nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes - ipv4 = existingMachineNode.IPv4 - ipv6 = existingMachineNode.IPv6 - } else { - // Allocate new IPs - ipv4, ipv6, err = s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) - } - } - - nodeToRegister.IPv4 = ipv4 - nodeToRegister.IPv6 = ipv6 - - // Ensure unique given name if not set - if nodeToRegister.GivenName == "" { - givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) - } - nodeToRegister.GivenName = givenName - } - - var savedNode *types.Node - if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { - // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { - node.NodeKey = nodeToRegister.NodeKey - node.DiscoKey = nodeToRegister.DiscoKey - node.Hostname = nodeToRegister.Hostname - node.Hostinfo = nodeToRegister.Hostinfo - node.Endpoints = nodeToRegister.Endpoints - node.RegisterMethod = nodeToRegister.RegisterMethod - if expiry != nil { - node.Expiry = expiry - } - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) - }) - - // Save to database - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, err - } - } else { - // New node - database first to get ID, then NodeStore - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, err - } - - // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) + savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{ + node: &nodeToRegister, + userID: userID, + user: user, + expiry: expiry, + updateExistingNode: updateFunc, + postSaveCallback: nil, // No post-save callback needed + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err } // Delete from registration cache @@ -1143,17 +1077,13 @@ func (s *State) HandleNodeFromAuthPath( } close(regEntry.Registered) - // Update policy manager - nodesChange, err := s.updatePolicyManagerNodes() + // Finalize registration + c, err := s.finalizeNodeRegistration(savedNode) if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err) + return savedNode.View(), c, err } - if !nodesChange.Empty() { - return savedNode.View(), nodesChange, nil - } - - return savedNode.View(), change.NodeAdded(savedNode.ID), nil + return savedNode.View(), c, nil } // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. @@ -1178,11 +1108,8 @@ func (s *State) HandleNodeFromPreAuthKey( if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { // Find the node to delete var nodeToDelete types.NodeView - for _, nv := range s.nodeStore.ListNodes().All() { - if nv.Valid() && nv.MachineKey() == machineKey { - nodeToDelete = nv - break - } + if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { + nodeToDelete = nv } if nodeToDelete.Valid() { c, err := s.DeleteNode(nodeToDelete) @@ -1194,6 +1121,93 @@ func (s *State) HandleNodeFromPreAuthKey( return types.NodeView{}, change.EmptySet, nil } + // Check if node already exists by node key (this is a refresh/re-registration) + existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regReq.NodeKey) + if exists && existingNodeView.Valid() { + // Node exists - this is a refresh/re-registration + log.Debug(). + Str("node", regReq.Hostinfo.Hostname). + Str("machine_key", machineKey.ShortString()). + Str("node_key", regReq.NodeKey.ShortString()). + Str("user", pak.User.Username()). + Msg("Refreshing existing node registration with pre-auth key") + + // Update NodeStore first with the new expiry and other fields + s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { + if !regReq.Expiry.IsZero() { + expiry := regReq.Expiry + node.Expiry = &expiry + } + // Update machine key if it changed + node.MachineKey = machineKey + // Update hostinfo + node.Hostinfo = regReq.Hostinfo + // Node is re-registering, so it's coming online + node.IsOnline = ptr.To(true) + node.LastSeen = ptr.To(time.Now()) + // Update auth key association + node.AuthKey = pak + node.AuthKeyID = &pak.ID + // Update forced tags from the pre-auth key + node.ForcedTags = pak.Proto().GetAclTags() + }) + + // Save to database + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + // Update the node in database + node := existingNodeView.AsStruct() + if !regReq.Expiry.IsZero() { + err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), regReq.Expiry) + if err != nil { + return nil, err + } + } + // Update machine key if changed + if node.MachineKey != machineKey { + err := hsdb.NodeSetMachineKey(tx, node, machineKey) + if err != nil { + return nil, err + } + } + // Update tags + err := hsdb.SetTags(tx, existingNodeView.ID(), pak.Proto().GetAclTags()) + if err != nil { + return nil, err + } + // Update last seen + err = hsdb.SetLastSeen(tx, existingNodeView.ID(), time.Now()) + if err != nil { + return nil, err + } + // Mark the pre-auth key as used if not reusable + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, err + } + } + // Return the node to satisfy the Write signature + return hsdb.GetNodeByID(tx, existingNodeView.ID()) + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node: %w", err) + } + + // Get updated node from NodeStore + updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return updatedNode, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } + if !c.IsFull() { + c = change.KeyExpiry(existingNodeView.ID()) + } + + return updatedNode, c, nil + } + log.Debug(). Str("node", regReq.Hostinfo.Hostname). Str("machine_key", machineKey.ShortString()). @@ -1201,17 +1215,9 @@ func (s *State) HandleNodeFromPreAuthKey( Str("user", pak.User.Username()). Msg("Registering node with pre-auth key") - // Check if node already exists with same machine key - var existingNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { - existingNode = nv.AsStruct() - } - // Prepare the node for registration nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, - UserID: pak.User.ID, - User: pak.User, MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, @@ -1222,58 +1228,52 @@ func (s *State) HandleNodeFromPreAuthKey( AuthKeyID: &pak.ID, } + var expiry *time.Time if !regReq.Expiry.IsZero() { - nodeToRegister.Expiry = ®Req.Expiry + expiry = ®Req.Expiry + nodeToRegister.Expiry = expiry } - // Handle IP allocation and existing node properties - var ipv4, ipv6 *netip.Addr - if existingNode != nil && existingNode.UserID == pak.User.ID { - // Reuse existing node properties - nodeToRegister.ID = existingNode.ID - nodeToRegister.GivenName = existingNode.GivenName - nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes - ipv4 = existingNode.IPv4 - ipv6 = existingNode.IPv6 - } else { - // Allocate new IPs - ipv4, ipv6, err = s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + // Custom update function for existing nodes + updateFunc := func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + node.ForcedTags = nodeToRegister.ForcedTags + node.AuthKey = nodeToRegister.AuthKey + node.AuthKeyID = nodeToRegister.AuthKeyID + } + + // Post-save callback to use the pre-auth key + postSaveFunc := func(tx *gorm.DB, savedNode *types.Node) error { + if !pak.Reusable { + return hsdb.UsePreAuthKey(tx, pak) } + return nil } - nodeToRegister.IPv4 = ipv4 - nodeToRegister.IPv6 = ipv6 - - // Ensure unique given name if not set - if nodeToRegister.GivenName == "" { - givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) - } - nodeToRegister.GivenName = givenName + // Check if node already exists with same machine key for logging + var existingNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { + existingNode = nv.AsStruct() } - var savedNode *types.Node + savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{ + node: &nodeToRegister, + userID: types.UserID(pak.User.ID), + user: &pak.User, + expiry: expiry, + updateExistingNode: updateFunc, + postSaveCallback: postSaveFunc, + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("registering node: %w", err) + } + + // Log re-authorization if it was an existing node if existingNode != nil && existingNode.UserID == pak.User.ID { - // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { - node.NodeKey = nodeToRegister.NodeKey - node.Hostname = nodeToRegister.Hostname - node.Hostinfo = nodeToRegister.Hostinfo - node.Endpoints = nodeToRegister.Endpoints - node.RegisterMethod = nodeToRegister.RegisterMethod - node.ForcedTags = nodeToRegister.ForcedTags - node.AuthKey = nodeToRegister.AuthKey - node.AuthKeyID = nodeToRegister.AuthKeyID - if nodeToRegister.Expiry != nil { - node.Expiry = nodeToRegister.Expiry - } - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) - }) - log.Trace(). Caller(). Str("node", nodeToRegister.Hostname). @@ -1281,65 +1281,12 @@ func (s *State) HandleNodeFromPreAuthKey( Str("node_key", regReq.NodeKey.ShortString()). Str("user", pak.User.Username()). Msg("Node re-authorized") - - // Save to database - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - - if !pak.Reusable { - err = hsdb.UsePreAuthKey(tx, pak) - if err != nil { - return nil, fmt.Errorf("using pre auth key: %w", err) - } - } - - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) - } - } else { - // New node - database first to get ID, then NodeStore - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - - if !pak.Reusable { - err = hsdb.UsePreAuthKey(tx, pak) - if err != nil { - return nil, fmt.Errorf("using pre auth key: %w", err) - } - } - - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) - } - - // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) } - // Update policy managers - usersChange, err := s.updatePolicyManagerUsers() + // Finalize registration + c, err := s.finalizeNodeRegistration(savedNode) if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) - } - - nodesChange, err := s.updatePolicyManagerNodes() - if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) - } - - var c change.ChangeSet - if !usersChange.Empty() || !nodesChange.Empty() { - c = change.PolicyChange() - } else { - c = change.NodeAdded(savedNode.ID) + return savedNode.View(), c, err } return savedNode.View(), c, nil @@ -1622,3 +1569,157 @@ func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { peerChange.LastSeen == nil && peerChange.KeyExpiry == nil } + +// nodeRegistrationHelper contains common logic for registering or updating nodes. +// It handles IP allocation, given name generation, and the NodeStore vs Database update pattern. +type nodeRegistrationHelper struct { + node *types.Node + userID types.UserID + user *types.User + expiry *time.Time + updateExistingNode func(*types.Node) + postSaveCallback func(tx *gorm.DB, savedNode *types.Node) error +} + +// registerOrUpdateNode is a common helper for node registration that both HandleNodeFromAuthPath +// and HandleNodeFromPreAuthKey can use. It encapsulates the complex logic of handling +// existing vs new nodes, IP allocation, and the NodeStore/Database update pattern. +func (s *State) registerOrUpdateNode(helper nodeRegistrationHelper) (*types.Node, error) { + // Check if node exists with same machine key + var existingNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(helper.node.MachineKey); exists && nv.Valid() { + existingNode = nv.AsStruct() + } + + // Check for different user registration + if existingNode != nil && existingNode.UserID != uint(helper.userID) { + return nil, hsdb.ErrDifferentRegisteredUser + } + + // Handle IP allocation and existing node properties + var ipv4, ipv6 *netip.Addr + if existingNode != nil && existingNode.UserID == uint(helper.userID) { + // Reuse existing node properties + helper.node.ID = existingNode.ID + helper.node.GivenName = existingNode.GivenName + helper.node.ApprovedRoutes = existingNode.ApprovedRoutes + ipv4 = existingNode.IPv4 + ipv6 = existingNode.IPv6 + } else { + // Allocate new IPs + var err error + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return nil, fmt.Errorf("allocating IPs: %w", err) + } + } + + helper.node.IPv4 = ipv4 + helper.node.IPv6 = ipv6 + helper.node.UserID = uint(helper.userID) + helper.node.User = *helper.user + if helper.expiry != nil { + helper.node.Expiry = helper.expiry + } + + // Ensure unique given name if not set + if helper.node.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, helper.node.Hostname) + if err != nil { + return nil, fmt.Errorf("failed to ensure unique given name: %w", err) + } + helper.node.GivenName = givenName + } + + var savedNode *types.Node + var err error + if existingNode != nil && existingNode.UserID == uint(helper.userID) { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { + // Apply common updates + node.UserID = helper.node.UserID + node.User = helper.node.User + node.IPv4 = helper.node.IPv4 + node.IPv6 = helper.node.IPv6 + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + if helper.expiry != nil { + node.Expiry = helper.expiry + } + + // Apply custom updates from caller + if helper.updateExistingNode != nil { + helper.updateExistingNode(node) + } + }) + + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(helper.node).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + // Run post-save callback if provided + if helper.postSaveCallback != nil { + if err := helper.postSaveCallback(tx, helper.node); err != nil { + return nil, err + } + } + + return helper.node, nil + }) + if err != nil { + return nil, err + } + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(helper.node).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + // Run post-save callback if provided + if helper.postSaveCallback != nil { + if err := helper.postSaveCallback(tx, helper.node); err != nil { + return nil, err + } + } + + return helper.node, nil + }) + if err != nil { + return nil, err + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) + } + + return savedNode, nil +} + +// finalizeNodeRegistration handles the common final steps after node registration: +// updating policy managers and generating the appropriate change set. +func (s *State) finalizeNodeRegistration(savedNode *types.Node) (change.ChangeSet, error) { + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() + if err != nil { + log.Error().Err(err).Msg("failed to update policy manager users after node registration") + // Don't fail the registration, just log the error + } + + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + log.Error().Err(err).Msg("failed to update policy manager nodes after node registration") + // Don't fail the registration, just log the error + } + + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(savedNode.ID) + } + + return c, nil +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 42d191e0..064bb583 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -354,7 +354,11 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags()) + assert.Equal( + t, + []string{"tag:test1", "tag:test2"}, + listedPreAuthKeys[index].GetAclTags(), + ) } // Test key expiry @@ -386,6 +390,7 @@ func TestPreAuthKeyCommand(t *testing.T) { ) assertNoErr(t, err) + assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[3].GetExpiration().AsTime().After(time.Now())) @@ -445,6 +450,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) + assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now())) assert.True( t,