From 0199b468efa20ce4a797fea8211c55e3a3b00a46 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 22 Sep 2025 17:45:05 +0200 Subject: [PATCH] state: reorganise auth key path Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 13 ++- hscontrol/state/state.go | 174 +++++++++++++++++++-------------------- 2 files changed, 98 insertions(+), 89 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 81032640..2a7258d5 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -227,12 +227,21 @@ func (h *Headscale) handleRegisterWithAuthKey( user := node.User() - return &tailcfg.RegisterResponse{ + resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), User: *user.TailscaleUser(), Login: *user.TailscaleLogin(), - }, nil + } + + log.Trace(). + Interface("reg.resp", resp). + Interface("reg.req", regReq). + Str("node.name", node.Hostname()). + Uint64("node.id", node.ID().Uint64()). + Msg("RegisterResponse") + + return resp, nil } func (h *Headscale) handleRegisterInteractive( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 43f54c0e..1e159e14 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1138,113 +1138,65 @@ func (s *State) HandleNodeFromPreAuthKey( Str("user.name", pak.User.Username()). Msg("Registering node with pre-auth key") + var finalNode types.NodeView + // Check if node already exists with same machine key - var existingNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { - existingNode = nv.AsStruct() - } + existingNode, exists := s.nodeStore.GetNodeByMachineKey(machineKey) - // Prepare the node for registration - nodeToRegister := types.Node{ - Hostname: regReq.Hostinfo.Hostname, - UserID: pak.User.ID, - User: pak.User, - MachineKey: machineKey, - NodeKey: regReq.NodeKey, - Hostinfo: regReq.Hostinfo, - LastSeen: ptr.To(time.Now()), - RegisterMethod: util.RegisterMethodAuthKey, - ForcedTags: pak.Proto().GetAclTags(), - AuthKey: pak, - AuthKeyID: &pak.ID, - } + // If this node exist and belongs to the same user as the pre-auth key, update the node in place. + if exists && existingNode.Valid() && existingNode.User().ID == pak.User.ID { + log.Trace(). + Caller(). + Str("node.name", existingNode.Hostname()). + Uint64("node.id", existingNode.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", existingNode.NodeKey().ShortString()). + Str("user.name", pak.User.Username()). + Msg("Node re-registering with existing machine key and user, updating in place") - if !regReq.Expiry.IsZero() { - nodeToRegister.Expiry = ®Req.Expiry - } - - // Handle IP allocation and existing node properties - var ipv4, ipv6 *netip.Addr - if existingNode != nil && existingNode.UserID == pak.User.ID { - // Reuse existing node properties - nodeToRegister.ID = existingNode.ID - nodeToRegister.GivenName = existingNode.GivenName - nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes - ipv4 = existingNode.IPv4 - ipv6 = existingNode.IPv6 - } else { - // Allocate new IPs - ipv4, ipv6, err = s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) - } - } - - nodeToRegister.IPv4 = ipv4 - nodeToRegister.IPv6 = ipv6 - - // Ensure unique given name if not set - if nodeToRegister.GivenName == "" { - givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) - } - nodeToRegister.GivenName = givenName - } - - var savedNode *types.Node - if existingNode != nil && existingNode.UserID == pak.User.ID { // Update existing node - NodeStore first, then database - updatedNodeView, ok := s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { - node.NodeKey = nodeToRegister.NodeKey - node.Hostname = nodeToRegister.Hostname + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNode.ID(), func(node *types.Node) { + node.NodeKey = regReq.NodeKey + node.Hostname = regReq.Hostinfo.Hostname // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). // Preserve NetInfo from existing node when re-registering - netInfo := NetInfoFromMapRequest(existingNode.ID, existingNode.Hostinfo, nodeToRegister.Hostinfo) + netInfo := NetInfoFromMapRequest(existingNode.ID(), node.Hostinfo, regReq.Hostinfo) if netInfo != nil { - if nodeToRegister.Hostinfo != nil { - hostinfoCopy := *nodeToRegister.Hostinfo + if node.Hostinfo != nil { + hostinfoCopy := *node.Hostinfo hostinfoCopy.NetInfo = netInfo node.Hostinfo = &hostinfoCopy } else { node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} } } else { - node.Hostinfo = nodeToRegister.Hostinfo + node.Hostinfo = node.Hostinfo } - node.Endpoints = nodeToRegister.Endpoints - node.RegisterMethod = nodeToRegister.RegisterMethod - node.ForcedTags = nodeToRegister.ForcedTags - node.AuthKey = nodeToRegister.AuthKey - node.AuthKeyID = nodeToRegister.AuthKeyID - if nodeToRegister.Expiry != nil { - node.Expiry = nodeToRegister.Expiry - } + node.RegisterMethod = util.RegisterMethodAuthKey + + // TODO(kradalby): This might need a rework as part of #2417 + node.ForcedTags = pak.Proto().GetAclTags() + node.AuthKey = pak + node.AuthKeyID = &pak.ID node.IsOnline = ptr.To(false) node.LastSeen = ptr.To(time.Now()) + + // Update expiry, if it is zero, it means that the node will + // not have an expiry anymore. If it is non-zero, we set that. + node.Expiry = ®Req.Expiry }) if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNode.ID) } - log.Trace(). - Caller(). - Str("node.name", nodeToRegister.Hostname). - Uint64("node.id", existingNode.ID.Uint64()). - Str("machine.key", machineKey.ShortString()). - Str("node.key", regReq.NodeKey.ShortString()). - Str("user.name", pak.User.Username()). - Msg("Node re-authorized") - // Use the node from UpdateNode to save to database - nodePtr := updatedNodeView.AsStruct() - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(nodePtr).Error; err != nil { + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } @@ -1255,14 +1207,62 @@ func (s *State) HandleNodeFromPreAuthKey( } } - return nodePtr, nil + return nil, nil }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) } + + log.Trace(). + Caller(). + Str("node.name", updatedNodeView.Hostname()). + Uint64("node.id", updatedNodeView.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", updatedNodeView.NodeKey().ShortString()). + Str("user.name", pak.User.Username()). + Msg("Node re-authorized") + + finalNode = updatedNodeView } else { + // This is a new node, or an existing node that belongs to a different user. + // Prepare the node for registration + nodeToRegister := types.Node{ + Hostname: regReq.Hostinfo.Hostname, + UserID: pak.User.ID, + User: pak.User, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + RegisterMethod: util.RegisterMethodAuthKey, + ForcedTags: pak.Proto().GetAclTags(), + AuthKey: pak, + AuthKeyID: &pak.ID, + } + + if !regReq.Expiry.IsZero() { + nodeToRegister.Expiry = ®Req.Expiry + } + + ipv4, ipv6, err := s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + // New node - database first to get ID, then NodeStore - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { if err := tx.Save(&nodeToRegister).Error; err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } @@ -1281,28 +1281,28 @@ func (s *State) HandleNodeFromPreAuthKey( } // Add to NodeStore after database creates the ID - _ = s.nodeStore.PutNode(*savedNode) + finalNode = s.nodeStore.PutNode(*savedNode) } // Update policy managers usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager users: %w", err) } nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) } var c change.ChangeSet if !usersChange.Empty() || !nodesChange.Empty() { c = change.PolicyChange() } else { - c = change.NodeAdded(savedNode.ID) + c = change.NodeAdded(finalNode.ID()) } - return savedNode.View(), c, nil + return finalNode, c, nil } // updatePolicyManagerUsers updates the policy manager with current users.