mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-24 13:46:53 +02:00
more reuse
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
cef1728cc3
commit
4d66d1f8d3
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -47,8 +48,9 @@ func CreatePreAuthKey(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove duplicates
|
// Remove duplicates and sort for consistency
|
||||||
aclTags = set.SetOf(aclTags).Slice()
|
aclTags = set.SetOf(aclTags).Slice()
|
||||||
|
slices.Sort(aclTags)
|
||||||
|
|
||||||
// TODO(kradalby): factor out and create a reusable tag validation,
|
// TODO(kradalby): factor out and create a reusable tag validation,
|
||||||
// check if there is one in Tailscale's lib.
|
// check if there is one in Tailscale's lib.
|
||||||
|
@ -994,7 +994,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if node already exists by node key
|
// Check if node already exists by node key (this is a refresh/re-registration)
|
||||||
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey)
|
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey)
|
||||||
if exists && existingNodeView.Valid() {
|
if exists && existingNodeView.Valid() {
|
||||||
// Node exists - this is a refresh/re-registration
|
// Node exists - this is a refresh/re-registration
|
||||||
@ -1010,8 +1010,8 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
if expiry != nil {
|
if expiry != nil {
|
||||||
node.Expiry = expiry
|
node.Expiry = expiry
|
||||||
}
|
}
|
||||||
// Mark as offline since node is reconnecting
|
// Node is re-registering, so it's coming online
|
||||||
node.IsOnline = ptr.To(false)
|
node.IsOnline = ptr.To(true)
|
||||||
node.LastSeen = ptr.To(time.Now())
|
node.LastSeen = ptr.To(time.Now())
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -1041,96 +1041,30 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
||||||
Msg("Registering new node from auth callback")
|
Msg("Registering new node from auth callback")
|
||||||
|
|
||||||
// Check if node exists with same machine key
|
|
||||||
var existingMachineNode *types.Node
|
|
||||||
if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() {
|
|
||||||
existingMachineNode = nv.AsStruct()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for different user registration
|
|
||||||
if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) {
|
|
||||||
return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the node for registration
|
// Prepare the node for registration
|
||||||
nodeToRegister := regEntry.Node
|
nodeToRegister := regEntry.Node
|
||||||
nodeToRegister.UserID = uint(userID)
|
|
||||||
nodeToRegister.User = *user
|
|
||||||
nodeToRegister.RegisterMethod = registrationMethod
|
nodeToRegister.RegisterMethod = registrationMethod
|
||||||
if expiry != nil {
|
|
||||||
nodeToRegister.Expiry = expiry
|
// Custom update function for existing nodes
|
||||||
|
updateFunc := func(node *types.Node) {
|
||||||
|
node.NodeKey = nodeToRegister.NodeKey
|
||||||
|
node.DiscoKey = nodeToRegister.DiscoKey
|
||||||
|
node.Hostname = nodeToRegister.Hostname
|
||||||
|
node.Hostinfo = nodeToRegister.Hostinfo
|
||||||
|
node.Endpoints = nodeToRegister.Endpoints
|
||||||
|
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle IP allocation
|
savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{
|
||||||
var ipv4, ipv6 *netip.Addr
|
node: &nodeToRegister,
|
||||||
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
userID: userID,
|
||||||
// Reuse existing IPs and properties
|
user: user,
|
||||||
nodeToRegister.ID = existingMachineNode.ID
|
expiry: expiry,
|
||||||
nodeToRegister.GivenName = existingMachineNode.GivenName
|
updateExistingNode: updateFunc,
|
||||||
nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes
|
postSaveCallback: nil, // No post-save callback needed
|
||||||
ipv4 = existingMachineNode.IPv4
|
})
|
||||||
ipv6 = existingMachineNode.IPv6
|
if err != nil {
|
||||||
} else {
|
return types.NodeView{}, change.EmptySet, err
|
||||||
// Allocate new IPs
|
|
||||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeToRegister.IPv4 = ipv4
|
|
||||||
nodeToRegister.IPv6 = ipv6
|
|
||||||
|
|
||||||
// Ensure unique given name if not set
|
|
||||||
if nodeToRegister.GivenName == "" {
|
|
||||||
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
|
||||||
}
|
|
||||||
nodeToRegister.GivenName = givenName
|
|
||||||
}
|
|
||||||
|
|
||||||
var savedNode *types.Node
|
|
||||||
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
|
||||||
// Update existing node - NodeStore first, then database
|
|
||||||
s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
|
|
||||||
node.NodeKey = nodeToRegister.NodeKey
|
|
||||||
node.DiscoKey = nodeToRegister.DiscoKey
|
|
||||||
node.Hostname = nodeToRegister.Hostname
|
|
||||||
node.Hostinfo = nodeToRegister.Hostinfo
|
|
||||||
node.Endpoints = nodeToRegister.Endpoints
|
|
||||||
node.RegisterMethod = nodeToRegister.RegisterMethod
|
|
||||||
if expiry != nil {
|
|
||||||
node.Expiry = expiry
|
|
||||||
}
|
|
||||||
node.IsOnline = ptr.To(false)
|
|
||||||
node.LastSeen = ptr.To(time.Now())
|
|
||||||
})
|
|
||||||
|
|
||||||
// Save to database
|
|
||||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
|
||||||
}
|
|
||||||
return &nodeToRegister, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// New node - database first to get ID, then NodeStore
|
|
||||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
|
||||||
}
|
|
||||||
return &nodeToRegister, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to NodeStore after database creates the ID
|
|
||||||
s.nodeStore.PutNode(*savedNode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete from registration cache
|
// Delete from registration cache
|
||||||
@ -1143,17 +1077,13 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
}
|
}
|
||||||
close(regEntry.Registered)
|
close(regEntry.Registered)
|
||||||
|
|
||||||
// Update policy manager
|
// Finalize registration
|
||||||
nodesChange, err := s.updatePolicyManagerNodes()
|
c, err := s.finalizeNodeRegistration(savedNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err)
|
return savedNode.View(), c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !nodesChange.Empty() {
|
return savedNode.View(), c, nil
|
||||||
return savedNode.View(), nodesChange, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return savedNode.View(), change.NodeAdded(savedNode.ID), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
||||||
@ -1178,11 +1108,8 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||||
// Find the node to delete
|
// Find the node to delete
|
||||||
var nodeToDelete types.NodeView
|
var nodeToDelete types.NodeView
|
||||||
for _, nv := range s.nodeStore.ListNodes().All() {
|
if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
|
||||||
if nv.Valid() && nv.MachineKey() == machineKey {
|
nodeToDelete = nv
|
||||||
nodeToDelete = nv
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if nodeToDelete.Valid() {
|
if nodeToDelete.Valid() {
|
||||||
c, err := s.DeleteNode(nodeToDelete)
|
c, err := s.DeleteNode(nodeToDelete)
|
||||||
@ -1194,6 +1121,93 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
return types.NodeView{}, change.EmptySet, nil
|
return types.NodeView{}, change.EmptySet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if node already exists by node key (this is a refresh/re-registration)
|
||||||
|
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regReq.NodeKey)
|
||||||
|
if exists && existingNodeView.Valid() {
|
||||||
|
// Node exists - this is a refresh/re-registration
|
||||||
|
log.Debug().
|
||||||
|
Str("node", regReq.Hostinfo.Hostname).
|
||||||
|
Str("machine_key", machineKey.ShortString()).
|
||||||
|
Str("node_key", regReq.NodeKey.ShortString()).
|
||||||
|
Str("user", pak.User.Username()).
|
||||||
|
Msg("Refreshing existing node registration with pre-auth key")
|
||||||
|
|
||||||
|
// Update NodeStore first with the new expiry and other fields
|
||||||
|
s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) {
|
||||||
|
if !regReq.Expiry.IsZero() {
|
||||||
|
expiry := regReq.Expiry
|
||||||
|
node.Expiry = &expiry
|
||||||
|
}
|
||||||
|
// Update machine key if it changed
|
||||||
|
node.MachineKey = machineKey
|
||||||
|
// Update hostinfo
|
||||||
|
node.Hostinfo = regReq.Hostinfo
|
||||||
|
// Node is re-registering, so it's coming online
|
||||||
|
node.IsOnline = ptr.To(true)
|
||||||
|
node.LastSeen = ptr.To(time.Now())
|
||||||
|
// Update auth key association
|
||||||
|
node.AuthKey = pak
|
||||||
|
node.AuthKeyID = &pak.ID
|
||||||
|
// Update forced tags from the pre-auth key
|
||||||
|
node.ForcedTags = pak.Proto().GetAclTags()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
// Update the node in database
|
||||||
|
node := existingNodeView.AsStruct()
|
||||||
|
if !regReq.Expiry.IsZero() {
|
||||||
|
err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), regReq.Expiry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update machine key if changed
|
||||||
|
if node.MachineKey != machineKey {
|
||||||
|
err := hsdb.NodeSetMachineKey(tx, node, machineKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update tags
|
||||||
|
err := hsdb.SetTags(tx, existingNodeView.ID(), pak.Proto().GetAclTags())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Update last seen
|
||||||
|
err = hsdb.SetLastSeen(tx, existingNodeView.ID(), time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Mark the pre-auth key as used if not reusable
|
||||||
|
if !pak.Reusable {
|
||||||
|
err = hsdb.UsePreAuthKey(tx, pak)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Return the node to satisfy the Write signature
|
||||||
|
return hsdb.GetNodeByID(tx, existingNodeView.ID())
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get updated node from NodeStore
|
||||||
|
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
|
||||||
|
|
||||||
|
// Check if policy manager needs updating
|
||||||
|
c, err := s.updatePolicyManagerNodes()
|
||||||
|
if err != nil {
|
||||||
|
return updatedNode, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||||
|
}
|
||||||
|
if !c.IsFull() {
|
||||||
|
c = change.KeyExpiry(existingNodeView.ID())
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedNode, c, nil
|
||||||
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("node", regReq.Hostinfo.Hostname).
|
Str("node", regReq.Hostinfo.Hostname).
|
||||||
Str("machine_key", machineKey.ShortString()).
|
Str("machine_key", machineKey.ShortString()).
|
||||||
@ -1201,17 +1215,9 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
Str("user", pak.User.Username()).
|
Str("user", pak.User.Username()).
|
||||||
Msg("Registering node with pre-auth key")
|
Msg("Registering node with pre-auth key")
|
||||||
|
|
||||||
// Check if node already exists with same machine key
|
|
||||||
var existingNode *types.Node
|
|
||||||
if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
|
|
||||||
existingNode = nv.AsStruct()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the node for registration
|
// Prepare the node for registration
|
||||||
nodeToRegister := types.Node{
|
nodeToRegister := types.Node{
|
||||||
Hostname: regReq.Hostinfo.Hostname,
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
UserID: pak.User.ID,
|
|
||||||
User: pak.User,
|
|
||||||
MachineKey: machineKey,
|
MachineKey: machineKey,
|
||||||
NodeKey: regReq.NodeKey,
|
NodeKey: regReq.NodeKey,
|
||||||
Hostinfo: regReq.Hostinfo,
|
Hostinfo: regReq.Hostinfo,
|
||||||
@ -1222,58 +1228,52 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
AuthKeyID: &pak.ID,
|
AuthKeyID: &pak.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var expiry *time.Time
|
||||||
if !regReq.Expiry.IsZero() {
|
if !regReq.Expiry.IsZero() {
|
||||||
nodeToRegister.Expiry = ®Req.Expiry
|
expiry = ®Req.Expiry
|
||||||
|
nodeToRegister.Expiry = expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle IP allocation and existing node properties
|
// Custom update function for existing nodes
|
||||||
var ipv4, ipv6 *netip.Addr
|
updateFunc := func(node *types.Node) {
|
||||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
node.NodeKey = nodeToRegister.NodeKey
|
||||||
// Reuse existing node properties
|
node.Hostname = nodeToRegister.Hostname
|
||||||
nodeToRegister.ID = existingNode.ID
|
node.Hostinfo = nodeToRegister.Hostinfo
|
||||||
nodeToRegister.GivenName = existingNode.GivenName
|
node.Endpoints = nodeToRegister.Endpoints
|
||||||
nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes
|
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||||
ipv4 = existingNode.IPv4
|
node.ForcedTags = nodeToRegister.ForcedTags
|
||||||
ipv6 = existingNode.IPv6
|
node.AuthKey = nodeToRegister.AuthKey
|
||||||
} else {
|
node.AuthKeyID = nodeToRegister.AuthKeyID
|
||||||
// Allocate new IPs
|
}
|
||||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
|
||||||
if err != nil {
|
// Post-save callback to use the pre-auth key
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
postSaveFunc := func(tx *gorm.DB, savedNode *types.Node) error {
|
||||||
|
if !pak.Reusable {
|
||||||
|
return hsdb.UsePreAuthKey(tx, pak)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister.IPv4 = ipv4
|
// Check if node already exists with same machine key for logging
|
||||||
nodeToRegister.IPv6 = ipv6
|
var existingNode *types.Node
|
||||||
|
if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
|
||||||
// Ensure unique given name if not set
|
existingNode = nv.AsStruct()
|
||||||
if nodeToRegister.GivenName == "" {
|
|
||||||
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
|
||||||
}
|
|
||||||
nodeToRegister.GivenName = givenName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var savedNode *types.Node
|
savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{
|
||||||
|
node: &nodeToRegister,
|
||||||
|
userID: types.UserID(pak.User.ID),
|
||||||
|
user: &pak.User,
|
||||||
|
expiry: expiry,
|
||||||
|
updateExistingNode: updateFunc,
|
||||||
|
postSaveCallback: postSaveFunc,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("registering node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log re-authorization if it was an existing node
|
||||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||||
// Update existing node - NodeStore first, then database
|
|
||||||
s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
|
||||||
node.NodeKey = nodeToRegister.NodeKey
|
|
||||||
node.Hostname = nodeToRegister.Hostname
|
|
||||||
node.Hostinfo = nodeToRegister.Hostinfo
|
|
||||||
node.Endpoints = nodeToRegister.Endpoints
|
|
||||||
node.RegisterMethod = nodeToRegister.RegisterMethod
|
|
||||||
node.ForcedTags = nodeToRegister.ForcedTags
|
|
||||||
node.AuthKey = nodeToRegister.AuthKey
|
|
||||||
node.AuthKeyID = nodeToRegister.AuthKeyID
|
|
||||||
if nodeToRegister.Expiry != nil {
|
|
||||||
node.Expiry = nodeToRegister.Expiry
|
|
||||||
}
|
|
||||||
node.IsOnline = ptr.To(false)
|
|
||||||
node.LastSeen = ptr.To(time.Now())
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", nodeToRegister.Hostname).
|
Str("node", nodeToRegister.Hostname).
|
||||||
@ -1281,65 +1281,12 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
Str("node_key", regReq.NodeKey.ShortString()).
|
Str("node_key", regReq.NodeKey.ShortString()).
|
||||||
Str("user", pak.User.Username()).
|
Str("user", pak.User.Username()).
|
||||||
Msg("Node re-authorized")
|
Msg("Node re-authorized")
|
||||||
|
|
||||||
// Save to database
|
|
||||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pak.Reusable {
|
|
||||||
err = hsdb.UsePreAuthKey(tx, pak)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("using pre auth key: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &nodeToRegister, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// New node - database first to get ID, then NodeStore
|
|
||||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pak.Reusable {
|
|
||||||
err = hsdb.UsePreAuthKey(tx, pak)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("using pre auth key: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &nodeToRegister, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to NodeStore after database creates the ID
|
|
||||||
s.nodeStore.PutNode(*savedNode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update policy managers
|
// Finalize registration
|
||||||
usersChange, err := s.updatePolicyManagerUsers()
|
c, err := s.finalizeNodeRegistration(savedNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err)
|
return savedNode.View(), c, err
|
||||||
}
|
|
||||||
|
|
||||||
nodesChange, err := s.updatePolicyManagerNodes()
|
|
||||||
if err != nil {
|
|
||||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var c change.ChangeSet
|
|
||||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
|
||||||
c = change.PolicyChange()
|
|
||||||
} else {
|
|
||||||
c = change.NodeAdded(savedNode.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return savedNode.View(), c, nil
|
return savedNode.View(), c, nil
|
||||||
@ -1622,3 +1569,157 @@ func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
|||||||
peerChange.LastSeen == nil &&
|
peerChange.LastSeen == nil &&
|
||||||
peerChange.KeyExpiry == nil
|
peerChange.KeyExpiry == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nodeRegistrationHelper contains common logic for registering or updating nodes.
|
||||||
|
// It handles IP allocation, given name generation, and the NodeStore vs Database update pattern.
|
||||||
|
type nodeRegistrationHelper struct {
|
||||||
|
node *types.Node
|
||||||
|
userID types.UserID
|
||||||
|
user *types.User
|
||||||
|
expiry *time.Time
|
||||||
|
updateExistingNode func(*types.Node)
|
||||||
|
postSaveCallback func(tx *gorm.DB, savedNode *types.Node) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerOrUpdateNode is a common helper for node registration that both HandleNodeFromAuthPath
|
||||||
|
// and HandleNodeFromPreAuthKey can use. It encapsulates the complex logic of handling
|
||||||
|
// existing vs new nodes, IP allocation, and the NodeStore/Database update pattern.
|
||||||
|
func (s *State) registerOrUpdateNode(helper nodeRegistrationHelper) (*types.Node, error) {
|
||||||
|
// Check if node exists with same machine key
|
||||||
|
var existingNode *types.Node
|
||||||
|
if nv, exists := s.nodeStore.GetNodeByMachineKey(helper.node.MachineKey); exists && nv.Valid() {
|
||||||
|
existingNode = nv.AsStruct()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for different user registration
|
||||||
|
if existingNode != nil && existingNode.UserID != uint(helper.userID) {
|
||||||
|
return nil, hsdb.ErrDifferentRegisteredUser
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle IP allocation and existing node properties
|
||||||
|
var ipv4, ipv6 *netip.Addr
|
||||||
|
if existingNode != nil && existingNode.UserID == uint(helper.userID) {
|
||||||
|
// Reuse existing node properties
|
||||||
|
helper.node.ID = existingNode.ID
|
||||||
|
helper.node.GivenName = existingNode.GivenName
|
||||||
|
helper.node.ApprovedRoutes = existingNode.ApprovedRoutes
|
||||||
|
ipv4 = existingNode.IPv4
|
||||||
|
ipv6 = existingNode.IPv6
|
||||||
|
} else {
|
||||||
|
// Allocate new IPs
|
||||||
|
var err error
|
||||||
|
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("allocating IPs: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
helper.node.IPv4 = ipv4
|
||||||
|
helper.node.IPv6 = ipv6
|
||||||
|
helper.node.UserID = uint(helper.userID)
|
||||||
|
helper.node.User = *helper.user
|
||||||
|
if helper.expiry != nil {
|
||||||
|
helper.node.Expiry = helper.expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure unique given name if not set
|
||||||
|
if helper.node.GivenName == "" {
|
||||||
|
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, helper.node.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||||
|
}
|
||||||
|
helper.node.GivenName = givenName
|
||||||
|
}
|
||||||
|
|
||||||
|
var savedNode *types.Node
|
||||||
|
var err error
|
||||||
|
if existingNode != nil && existingNode.UserID == uint(helper.userID) {
|
||||||
|
// Update existing node - NodeStore first, then database
|
||||||
|
s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
||||||
|
// Apply common updates
|
||||||
|
node.UserID = helper.node.UserID
|
||||||
|
node.User = helper.node.User
|
||||||
|
node.IPv4 = helper.node.IPv4
|
||||||
|
node.IPv6 = helper.node.IPv6
|
||||||
|
node.IsOnline = ptr.To(false)
|
||||||
|
node.LastSeen = ptr.To(time.Now())
|
||||||
|
if helper.expiry != nil {
|
||||||
|
node.Expiry = helper.expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom updates from caller
|
||||||
|
if helper.updateExistingNode != nil {
|
||||||
|
helper.updateExistingNode(node)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(helper.node).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run post-save callback if provided
|
||||||
|
if helper.postSaveCallback != nil {
|
||||||
|
if err := helper.postSaveCallback(tx, helper.node); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return helper.node, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New node - database first to get ID, then NodeStore
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(helper.node).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run post-save callback if provided
|
||||||
|
if helper.postSaveCallback != nil {
|
||||||
|
if err := helper.postSaveCallback(tx, helper.node); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return helper.node, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to NodeStore after database creates the ID
|
||||||
|
s.nodeStore.PutNode(*savedNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return savedNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// finalizeNodeRegistration handles the common final steps after node registration:
|
||||||
|
// updating policy managers and generating the appropriate change set.
|
||||||
|
func (s *State) finalizeNodeRegistration(savedNode *types.Node) (change.ChangeSet, error) {
|
||||||
|
// Update policy managers
|
||||||
|
usersChange, err := s.updatePolicyManagerUsers()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to update policy manager users after node registration")
|
||||||
|
// Don't fail the registration, just log the error
|
||||||
|
}
|
||||||
|
|
||||||
|
nodesChange, err := s.updatePolicyManagerNodes()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to update policy manager nodes after node registration")
|
||||||
|
// Don't fail the registration, just log the error
|
||||||
|
}
|
||||||
|
|
||||||
|
var c change.ChangeSet
|
||||||
|
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||||
|
c = change.PolicyChange()
|
||||||
|
} else {
|
||||||
|
c = change.NodeAdded(savedNode.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
@ -354,7 +354,11 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
|
assert.Equal(
|
||||||
|
t,
|
||||||
|
[]string{"tag:test1", "tag:test2"},
|
||||||
|
listedPreAuthKeys[index].GetAclTags(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test key expiry
|
// Test key expiry
|
||||||
@ -386,6 +390,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
|||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
|
||||||
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
|
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
|
||||||
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
|
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
|
||||||
assert.True(t, listedPreAuthKeysAfterExpire[3].GetExpiration().AsTime().After(time.Now()))
|
assert.True(t, listedPreAuthKeysAfterExpire[3].GetExpiration().AsTime().After(time.Now()))
|
||||||
@ -445,6 +450,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
|||||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||||
assert.Len(t, listedPreAuthKeys, 2)
|
assert.Len(t, listedPreAuthKeys, 2)
|
||||||
|
|
||||||
|
|
||||||
assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now()))
|
assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now()))
|
||||||
assert.True(
|
assert.True(
|
||||||
t,
|
t,
|
||||||
|
Loading…
Reference in New Issue
Block a user