mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-14 13:51:01 +02:00
move registration into state
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
c24b988247
commit
d78c69e112
@ -327,115 +327,20 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
})
|
||||
}
|
||||
|
||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
||||
// with a registrationID to register or reauthenticate a node.
|
||||
// If the node found in the registration cache is not already registered,
|
||||
// it will be registered with the user and the node will be removed from the cache.
|
||||
// If the node is already registered, the expiry will be updated.
|
||||
// The node, and a boolean indicating if it was a new node or not, will be returned.
|
||||
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationID types.RegistrationID,
|
||||
userID types.UserID,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
var nodeChange change.ChangeSet
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
user, err := GetUserByID(tx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registration_id", registrationID.String()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||
Msg("Registering node from API/CLI or auth callback")
|
||||
// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
|
||||
// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
|
||||
func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
if !testing.Testing() {
|
||||
panic("RegisterNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
// TODO(kradalby): This looks quite wrong? why ID 0?
|
||||
// Why not always?
|
||||
// Registration of expired node with different user
|
||||
if reg.Node.ID != 0 &&
|
||||
reg.Node.UserID != user.ID {
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
reg.Node.UserID = user.ID
|
||||
reg.Node.User = *user
|
||||
reg.Node.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
reg.Node.Expiry = nodeExpiry
|
||||
}
|
||||
|
||||
node, err := RegisterNode(
|
||||
tx,
|
||||
reg.Node,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
hsdb.regCache.Delete(registrationID)
|
||||
}
|
||||
|
||||
// Signal to waiting clients that the machine has been registered.
|
||||
select {
|
||||
case reg.Registered <- node:
|
||||
default:
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
nodeChange = change.NodeAdded(node.ID)
|
||||
|
||||
return node, err
|
||||
} else {
|
||||
// If the node is already registered, this is a refresh.
|
||||
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// CRITICAL: Reload the node to get the updated expiry
|
||||
// Without this, we return stale node data to NodeStore
|
||||
updatedNode, err := GetNodeByID(tx, node.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to reload node after expiry update: %w", err)
|
||||
}
|
||||
|
||||
nodeChange = change.KeyExpiry(node.ID)
|
||||
|
||||
return updatedNode, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, nodeChange, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, ipv4, ipv6)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Registering node")
|
||||
Msg("Registering test node")
|
||||
|
||||
// If the a new node is registered with the same machine key, to the same user,
|
||||
// update the existing node.
|
||||
@ -469,7 +374,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Node authorized again")
|
||||
Msg("Test node authorized again")
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
@ -478,7 +383,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
node.IPv6 = ipv6
|
||||
|
||||
if node.GivenName == "" {
|
||||
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
|
||||
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
}
|
||||
@ -493,7 +398,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Msg("Node registered with the database")
|
||||
Msg("Test node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
@ -566,7 +471,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
||||
return len(nodes) == 0, nil
|
||||
}
|
||||
|
||||
func ensureUniqueGivenName(
|
||||
// EnsureUniqueGivenName generates a unique given name for a node based on its hostname.
|
||||
func EnsureUniqueGivenName(
|
||||
tx *gorm.DB,
|
||||
name string,
|
||||
) (string, error) {
|
||||
@ -797,7 +703,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
||||
var registeredNode *types.Node
|
||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6)
|
||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -679,11 +679,11 @@ func TestRenameNode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
@ -780,11 +780,11 @@ func TestListPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node1, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
@ -865,11 +865,11 @@ func TestListNodes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node1, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
@ -66,9 +65,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
// Always use batcher's view of online status for self node
|
||||
// The batcher respects grace periods for logout scenarios
|
||||
node := nodeView.AsStruct()
|
||||
if b.mapper.batcher != nil {
|
||||
node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||
}
|
||||
// if b.mapper.batcher != nil {
|
||||
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||
// }
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
|
@ -175,7 +175,7 @@ func (m *mapper) fullMapResponse(
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
// Add fresh online status to peers from batcher connection state
|
||||
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||
// peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
@ -186,9 +186,9 @@ func (m *mapper) fullMapResponse(
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peersWithOnlineStatus).
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peersWithOnlineStatus).
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
}
|
||||
|
||||
@ -220,13 +220,13 @@ func (m *mapper) peerChangeResponse(
|
||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||
|
||||
// Add fresh online status to peers from batcher connection state
|
||||
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||
// peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peersWithOnlineStatus).
|
||||
WithPeerChanges(peersWithOnlineStatus).
|
||||
WithUserProfiles(peers).
|
||||
WithPeerChanges(peers).
|
||||
Build()
|
||||
}
|
||||
|
||||
|
@ -381,6 +381,25 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo
|
||||
return nodeView, exists
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
|
||||
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()
|
||||
|
||||
snapshot := s.data.Load()
|
||||
// We don't have a byMachineKey map, so we need to iterate
|
||||
// This could be optimized by adding a byMachineKey map if this becomes a hot path
|
||||
for _, node := range snapshot.nodesByID {
|
||||
if node.MachineKey == machineKey {
|
||||
return node.View(), true
|
||||
}
|
||||
}
|
||||
|
||||
return types.NodeView{}, false
|
||||
}
|
||||
|
||||
// DebugString returns debug information about the NodeStore.
|
||||
func (s *NodeStore) DebugString() string {
|
||||
snapshot := s.data.Load()
|
||||
|
@ -99,7 +99,7 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading nodes: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// On startup, all nodes should be marked as offline until they reconnect
|
||||
// This ensures we don't have stale online status from previous runs
|
||||
for _, node := range nodes {
|
||||
@ -519,6 +519,14 @@ func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool)
|
||||
return s.nodeStore.GetNodeByNodeKey(nodeKey)
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey retrieves a node by its machine key.
|
||||
// The bool indicates if the node exists or is available (like "err not found").
|
||||
// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
|
||||
// it isn't an invalid node (this is more of a node error or node is broken).
|
||||
func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
|
||||
return s.nodeStore.GetNodeByMachineKey(machineKey)
|
||||
}
|
||||
|
||||
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
|
||||
func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] {
|
||||
if len(nodeIDs) == 0 {
|
||||
@ -790,6 +798,8 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
if exists && existingNode.Valid() {
|
||||
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
|
||||
}
|
||||
// TODO(kradalby): This should just update the IP addresses, nothing else in the node store.
|
||||
// We should avoid PutNode here.
|
||||
s.nodeStore.PutNode(*node)
|
||||
}
|
||||
}
|
||||
@ -969,68 +979,181 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
// Get the registration entry to check the machine key
|
||||
var ipv4, ipv6 *netip.Addr
|
||||
var err error
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we have the registration entry to determine if we should reuse IPs
|
||||
if regEntry, ok := s.GetRegistrationCacheEntry(registrationID); ok {
|
||||
// Check if node already exists with same machine key and user
|
||||
// to avoid allocating new IPs unnecessarily
|
||||
existingNode, _ := s.db.GetNodeByMachineKey(regEntry.Node.MachineKey)
|
||||
// Get the registration entry from cache
|
||||
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
// Only allocate new IPs if:
|
||||
// 1. No existing node found, OR
|
||||
// 2. Existing node belongs to a different user
|
||||
if existingNode == nil || existingNode.UserID != uint(userID) {
|
||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
// Get the user
|
||||
user, err := s.db.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
// Check if node already exists by node key
|
||||
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey)
|
||||
if exists && existingNodeView.Valid() {
|
||||
// Node exists - this is a refresh/re-registration
|
||||
log.Debug().
|
||||
Str("registration_id", registrationID.String()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("node", existingNodeView.Hostname()).
|
||||
Msg("Refreshing existing node registration")
|
||||
|
||||
// Update NodeStore first with the new expiry
|
||||
s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) {
|
||||
if expiry != nil {
|
||||
node.Expiry = expiry
|
||||
}
|
||||
// Mark as offline since node is reconnecting
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
// Save to database
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Return the node to satisfy the Write signature
|
||||
return hsdb.GetNodeByID(tx, existingNodeView.ID())
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err)
|
||||
}
|
||||
// If existing node found for same user, HandleNodeFromAuthPath will reuse its IPs
|
||||
|
||||
// Get updated node from NodeStore
|
||||
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
|
||||
return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil
|
||||
}
|
||||
|
||||
// New node registration
|
||||
log.Debug().
|
||||
Str("registration_id", registrationID.String()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
||||
Msg("Registering new node from auth callback")
|
||||
|
||||
// Check if node exists with same machine key
|
||||
var existingMachineNode *types.Node
|
||||
if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() {
|
||||
existingMachineNode = nv.AsStruct()
|
||||
}
|
||||
|
||||
// Check for different user registration
|
||||
if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) {
|
||||
return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
// Prepare the node for registration
|
||||
nodeToRegister := regEntry.Node
|
||||
nodeToRegister.UserID = uint(userID)
|
||||
nodeToRegister.User = *user
|
||||
nodeToRegister.RegisterMethod = registrationMethod
|
||||
if expiry != nil {
|
||||
nodeToRegister.Expiry = expiry
|
||||
}
|
||||
|
||||
// Handle IP allocation
|
||||
var ipv4, ipv6 *netip.Addr
|
||||
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
||||
// Reuse existing IPs and properties
|
||||
nodeToRegister.ID = existingMachineNode.ID
|
||||
nodeToRegister.GivenName = existingMachineNode.GivenName
|
||||
nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes
|
||||
ipv4 = existingMachineNode.IPv4
|
||||
ipv6 = existingMachineNode.IPv6
|
||||
} else {
|
||||
// If no registration entry found, allocate new IPs (shouldn't happen in normal flow)
|
||||
// Allocate new IPs
|
||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nodeToRegister.IPv4 = ipv4
|
||||
nodeToRegister.IPv6 = ipv6
|
||||
|
||||
// Ensure unique given name if not set
|
||||
if nodeToRegister.GivenName == "" {
|
||||
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
}
|
||||
nodeToRegister.GivenName = givenName
|
||||
}
|
||||
|
||||
var savedNode *types.Node
|
||||
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
||||
// Update existing node - NodeStore first, then database
|
||||
s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
|
||||
node.NodeKey = nodeToRegister.NodeKey
|
||||
node.DiscoKey = nodeToRegister.DiscoKey
|
||||
node.Hostname = nodeToRegister.Hostname
|
||||
node.Hostinfo = nodeToRegister.Hostinfo
|
||||
node.Endpoints = nodeToRegister.Endpoints
|
||||
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||
if expiry != nil {
|
||||
node.Expiry = expiry
|
||||
}
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
// Save to database
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
return &nodeToRegister, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
} else {
|
||||
// New node - database first to get ID, then NodeStore
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
return &nodeToRegister, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Add to NodeStore after database creates the ID
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
node, nodeChange, err := s.db.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
userID,
|
||||
expiry,
|
||||
registrationMethod,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
// Delete from registration cache
|
||||
s.registrationCache.Delete(registrationID)
|
||||
|
||||
// Update NodeStore to ensure it has the latest node data
|
||||
// For re-registrations (key expiry), mark as offline since node is reconnecting
|
||||
// For new registrations, leave IsOnline as nil to let batcher manage connection state
|
||||
if nodeChange.Change == change.NodeKeyExpiry {
|
||||
// This is a re-registration/key refresh - node was disconnected and is coming back
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
// Signal to waiting clients
|
||||
select {
|
||||
case regEntry.Registered <- savedNode:
|
||||
default:
|
||||
}
|
||||
// For new registrations (NodeNewOrUpdate), don't set IsOnline - batcher manages it
|
||||
s.nodeStore.PutNode(*node)
|
||||
close(regEntry.Registered)
|
||||
|
||||
// Update policy manager with the new node if needed
|
||||
// Update policy manager
|
||||
nodesChange, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node.View(), nodeChange, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err)
|
||||
}
|
||||
|
||||
// If policy manager detected changes, use that instead
|
||||
if !nodesChange.Empty() {
|
||||
nodeChange = nodesChange
|
||||
return savedNode.View(), nodesChange, nil
|
||||
}
|
||||
|
||||
return node.View(), nodeChange, nil
|
||||
return savedNode.View(), change.NodeAdded(savedNode.ID), nil
|
||||
}
|
||||
|
||||
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
||||
@ -1038,6 +1161,9 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
@ -1048,6 +1174,40 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// Find the node to delete
|
||||
var nodeToDelete types.NodeView
|
||||
for _, nv := range s.nodeStore.ListNodes().All() {
|
||||
if nv.Valid() && nv.MachineKey() == machineKey {
|
||||
nodeToDelete = nv
|
||||
break
|
||||
}
|
||||
}
|
||||
if nodeToDelete.Valid() {
|
||||
c, err := s.DeleteNode(nodeToDelete)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
return types.NodeView{}, c, nil
|
||||
}
|
||||
return types.NodeView{}, change.EmptySet, nil
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("node", regReq.Hostinfo.Hostname).
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
Str("node_key", regReq.NodeKey.ShortString()).
|
||||
Str("user", pak.User.Username()).
|
||||
Msg("Registering node with pre-auth key")
|
||||
|
||||
// Check if node already exists with same machine key
|
||||
var existingNode *types.Node
|
||||
if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
|
||||
existingNode = nv.AsStruct()
|
||||
}
|
||||
|
||||
// Prepare the node for registration
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
UserID: pak.User.ID,
|
||||
@ -1057,103 +1217,132 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
||||
// TODO(kradalby): This should not be set on the node,
|
||||
// they should be looked up through the key, which is
|
||||
// attached to the node.
|
||||
ForcedTags: pak.Proto().GetAclTags(),
|
||||
AuthKey: pak,
|
||||
AuthKeyID: &pak.ID,
|
||||
ForcedTags: pak.Proto().GetAclTags(),
|
||||
AuthKey: pak,
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
}
|
||||
|
||||
// Check if node already exists with same machine key and user
|
||||
// to avoid allocating new IPs unnecessarily
|
||||
existingNode, _ := s.db.GetNodeByMachineKey(machineKey)
|
||||
// Handle IP allocation and existing node properties
|
||||
var ipv4, ipv6 *netip.Addr
|
||||
|
||||
// Only allocate new IPs if:
|
||||
// 1. No existing node found, OR
|
||||
// 2. Existing node belongs to a different user
|
||||
if existingNode == nil || existingNode.UserID != pak.User.ID {
|
||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||
// Reuse existing node properties
|
||||
nodeToRegister.ID = existingNode.ID
|
||||
nodeToRegister.GivenName = existingNode.GivenName
|
||||
nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes
|
||||
ipv4 = existingNode.IPv4
|
||||
ipv6 = existingNode.IPv6
|
||||
} else {
|
||||
// Allocate new IPs
|
||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
}
|
||||
// If existing node found for same user, RegisterNode will reuse its IPs
|
||||
|
||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
node, err := hsdb.RegisterNode(tx,
|
||||
nodeToRegister,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
nodeToRegister.IPv4 = ipv4
|
||||
nodeToRegister.IPv6 = ipv6
|
||||
|
||||
// Ensure unique given name if not set
|
||||
if nodeToRegister.GivenName == "" {
|
||||
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering node: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
}
|
||||
|
||||
if !pak.Reusable {
|
||||
err = hsdb.UsePreAuthKey(tx, pak)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("using pre auth key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
nodeToRegister.GivenName = givenName
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// This is a logout request for an ephemeral node, delete it immediately
|
||||
c, err := s.DeleteNode(node.View())
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
|
||||
return types.NodeView{}, c, nil
|
||||
}
|
||||
|
||||
// Update NodeStore BEFORE updating policy manager so it has the latest node data
|
||||
// CRITICAL: For re-registration of existing nodes, we must update NodeStore
|
||||
// to ensure it has the latest state from the database transaction
|
||||
// For re-registrations of existing nodes, mark as offline since they're reconnecting
|
||||
// For new registrations, leave IsOnline as nil to let batcher manage connection state
|
||||
var savedNode *types.Node
|
||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||
// This is a re-registration of existing node - was disconnected and is coming back
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
}
|
||||
// For new registrations, don't set IsOnline - batcher manages it
|
||||
s.nodeStore.PutNode(*node)
|
||||
// Update existing node - NodeStore first, then database
|
||||
s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
||||
node.NodeKey = nodeToRegister.NodeKey
|
||||
node.Hostname = nodeToRegister.Hostname
|
||||
node.Hostinfo = nodeToRegister.Hostinfo
|
||||
node.Endpoints = nodeToRegister.Endpoints
|
||||
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||
node.ForcedTags = nodeToRegister.ForcedTags
|
||||
node.AuthKey = nodeToRegister.AuthKey
|
||||
node.AuthKeyID = nodeToRegister.AuthKeyID
|
||||
if nodeToRegister.Expiry != nil {
|
||||
node.Expiry = nodeToRegister.Expiry
|
||||
}
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
// Check if policy manager needs updating
|
||||
// This is necessary because we just created a new node.
|
||||
// We need to ensure that the policy manager is aware of this new node.
|
||||
// Also update users to ensure all users are known when evaluating policies.
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", nodeToRegister.Hostname).
|
||||
Str("machine_key", machineKey.ShortString()).
|
||||
Str("node_key", regReq.NodeKey.ShortString()).
|
||||
Str("user", pak.User.Username()).
|
||||
Msg("Node re-authorized")
|
||||
|
||||
// Save to database
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
|
||||
if !pak.Reusable {
|
||||
err = hsdb.UsePreAuthKey(tx, pak)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("using pre auth key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &nodeToRegister, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
} else {
|
||||
// New node - database first to get ID, then NodeStore
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
|
||||
if !pak.Reusable {
|
||||
err = hsdb.UsePreAuthKey(tx, pak)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("using pre auth key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &nodeToRegister, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
// Add to NodeStore after database creates the ID
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
// Update policy managers
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err)
|
||||
}
|
||||
|
||||
nodesChange, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err)
|
||||
}
|
||||
|
||||
var c change.ChangeSet
|
||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
c = change.NodeAdded(node.ID)
|
||||
c = change.NodeAdded(savedNode.ID)
|
||||
}
|
||||
|
||||
return node.View(), c, nil
|
||||
return savedNode.View(), c, nil
|
||||
}
|
||||
|
||||
// updatePolicyManagerUsers updates the policy manager with current users.
|
||||
@ -1289,7 +1478,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
req.Hostinfo.RoutableIPs,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// Log when routes change but approval doesn't
|
||||
if hostinfoChanged && req.Hostinfo != nil && routesChanged(currentNode.View(), req.Hostinfo) && !routeChange {
|
||||
log.Debug().
|
||||
@ -1341,7 +1530,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
})
|
||||
|
||||
nodeRouteChange := change.EmptySet
|
||||
|
||||
|
||||
// Handle route changes after NodeStore update
|
||||
// We need to update node routes if either:
|
||||
// 1. The approved routes changed (routeChange is true), OR
|
||||
@ -1360,7 +1549,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msg("updating routes because announced routes changed but approved routes did not")
|
||||
}
|
||||
|
||||
|
||||
if needsRouteUpdate {
|
||||
// Get the updated node to access its subnet routes
|
||||
updatedNode, exists := s.GetNodeByID(id)
|
||||
|
Loading…
Reference in New Issue
Block a user