mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-14 13:51:01 +02:00
move registration into state
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
c24b988247
commit
d78c69e112
@ -327,115 +327,20 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
|
||||||
// with a registrationID to register or reauthenticate a node.
|
// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
|
||||||
// If the node found in the registration cache is not already registered,
|
// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
|
||||||
// it will be registered with the user and the node will be removed from the cache.
|
func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||||
// If the node is already registered, the expiry will be updated.
|
if !testing.Testing() {
|
||||||
// The node, and a boolean indicating if it was a new node or not, will be returned.
|
panic("RegisterNodeForTest can only be called during tests")
|
||||||
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
|
||||||
registrationID types.RegistrationID,
|
|
||||||
userID types.UserID,
|
|
||||||
nodeExpiry *time.Time,
|
|
||||||
registrationMethod string,
|
|
||||||
ipv4 *netip.Addr,
|
|
||||||
ipv6 *netip.Addr,
|
|
||||||
) (*types.Node, change.ChangeSet, error) {
|
|
||||||
var nodeChange change.ChangeSet
|
|
||||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
|
||||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
|
||||||
user, err := GetUserByID(tx, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf(
|
|
||||||
"failed to find user in register node from auth callback, %w",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
|
||||||
Str("registration_id", registrationID.String()).
|
|
||||||
Str("username", user.Username()).
|
|
||||||
Str("registrationMethod", registrationMethod).
|
|
||||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
|
||||||
Msg("Registering node from API/CLI or auth callback")
|
|
||||||
|
|
||||||
// TODO(kradalby): This looks quite wrong? why ID 0?
|
|
||||||
// Why not always?
|
|
||||||
// Registration of expired node with different user
|
|
||||||
if reg.Node.ID != 0 &&
|
|
||||||
reg.Node.UserID != user.ID {
|
|
||||||
return nil, ErrDifferentRegisteredUser
|
|
||||||
}
|
|
||||||
|
|
||||||
reg.Node.UserID = user.ID
|
|
||||||
reg.Node.User = *user
|
|
||||||
reg.Node.RegisterMethod = registrationMethod
|
|
||||||
|
|
||||||
if nodeExpiry != nil {
|
|
||||||
reg.Node.Expiry = nodeExpiry
|
|
||||||
}
|
|
||||||
|
|
||||||
node, err := RegisterNode(
|
|
||||||
tx,
|
|
||||||
reg.Node,
|
|
||||||
ipv4, ipv6,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
hsdb.regCache.Delete(registrationID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal to waiting clients that the machine has been registered.
|
|
||||||
select {
|
|
||||||
case reg.Registered <- node:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
close(reg.Registered)
|
|
||||||
|
|
||||||
nodeChange = change.NodeAdded(node.ID)
|
|
||||||
|
|
||||||
return node, err
|
|
||||||
} else {
|
|
||||||
// If the node is already registered, this is a refresh.
|
|
||||||
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CRITICAL: Reload the node to get the updated expiry
|
|
||||||
// Without this, we return stale node data to NodeStore
|
|
||||||
updatedNode, err := GetNodeByID(tx, node.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to reload node after expiry update: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeChange = change.KeyExpiry(node.ID)
|
|
||||||
|
|
||||||
return updatedNode, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrNodeNotFoundRegistrationCache
|
|
||||||
})
|
|
||||||
|
|
||||||
return node, nodeChange, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
|
||||||
return RegisterNode(tx, node, ipv4, ipv6)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
|
||||||
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str("machine_key", node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str("node_key", node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Username()).
|
Str("user", node.User.Username()).
|
||||||
Msg("Registering node")
|
Msg("Registering test node")
|
||||||
|
|
||||||
// If the a new node is registered with the same machine key, to the same user,
|
// If the a new node is registered with the same machine key, to the same user,
|
||||||
// update the existing node.
|
// update the existing node.
|
||||||
@ -469,7 +374,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str("machine_key", node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str("node_key", node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Username()).
|
Str("user", node.User.Username()).
|
||||||
Msg("Node authorized again")
|
Msg("Test node authorized again")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
}
|
}
|
||||||
@ -478,7 +383,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
node.IPv6 = ipv6
|
node.IPv6 = ipv6
|
||||||
|
|
||||||
if node.GivenName == "" {
|
if node.GivenName == "" {
|
||||||
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
|
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||||
}
|
}
|
||||||
@ -493,7 +398,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("Node registered with the database")
|
Msg("Test node registered with the database")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
}
|
}
|
||||||
@ -566,7 +471,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
|||||||
return len(nodes) == 0, nil
|
return len(nodes) == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureUniqueGivenName(
|
// EnsureUniqueGivenName generates a unique given name for a node based on its hostname.
|
||||||
|
func EnsureUniqueGivenName(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
name string,
|
name string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
@ -797,7 +703,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
|||||||
var registeredNode *types.Node
|
var registeredNode *types.Node
|
||||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
var err error
|
var err error
|
||||||
registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6)
|
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -679,11 +679,11 @@ func TestRenameNode(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
_, err := RegisterNode(tx, node, nil, nil)
|
_, err := RegisterNodeForTest(tx, node, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = RegisterNode(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
@ -780,11 +780,11 @@ func TestListPeers(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
_, err := RegisterNode(tx, node1, nil, nil)
|
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = RegisterNode(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
@ -865,11 +865,11 @@ func TestListNodes(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
_, err := RegisterNode(tx, node1, nil, nil)
|
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = RegisterNode(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
"tailscale.com/util/multierr"
|
"tailscale.com/util/multierr"
|
||||||
)
|
)
|
||||||
@ -66,9 +65,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
|||||||
// Always use batcher's view of online status for self node
|
// Always use batcher's view of online status for self node
|
||||||
// The batcher respects grace periods for logout scenarios
|
// The batcher respects grace periods for logout scenarios
|
||||||
node := nodeView.AsStruct()
|
node := nodeView.AsStruct()
|
||||||
if b.mapper.batcher != nil {
|
// if b.mapper.batcher != nil {
|
||||||
node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||||
}
|
// }
|
||||||
|
|
||||||
_, matchers := b.mapper.state.Filter()
|
_, matchers := b.mapper.state.Filter()
|
||||||
tailnode, err := tailNode(
|
tailnode, err := tailNode(
|
||||||
|
@ -175,7 +175,7 @@ func (m *mapper) fullMapResponse(
|
|||||||
peers := m.state.ListPeers(nodeID)
|
peers := m.state.ListPeers(nodeID)
|
||||||
|
|
||||||
// Add fresh online status to peers from batcher connection state
|
// Add fresh online status to peers from batcher connection state
|
||||||
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
// peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||||
|
|
||||||
return m.NewMapResponseBuilder(nodeID).
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
WithCapabilityVersion(capVer).
|
WithCapabilityVersion(capVer).
|
||||||
@ -186,9 +186,9 @@ func (m *mapper) fullMapResponse(
|
|||||||
WithDebugConfig().
|
WithDebugConfig().
|
||||||
WithSSHPolicy().
|
WithSSHPolicy().
|
||||||
WithDNSConfig().
|
WithDNSConfig().
|
||||||
WithUserProfiles(peersWithOnlineStatus).
|
WithUserProfiles(peers).
|
||||||
WithPacketFilters().
|
WithPacketFilters().
|
||||||
WithPeers(peersWithOnlineStatus).
|
WithPeers(peers).
|
||||||
Build(messages...)
|
Build(messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,13 +220,13 @@ func (m *mapper) peerChangeResponse(
|
|||||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||||
|
|
||||||
// Add fresh online status to peers from batcher connection state
|
// Add fresh online status to peers from batcher connection state
|
||||||
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
// peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
|
||||||
|
|
||||||
return m.NewMapResponseBuilder(nodeID).
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
WithCapabilityVersion(capVer).
|
WithCapabilityVersion(capVer).
|
||||||
WithSelfNode().
|
WithSelfNode().
|
||||||
WithUserProfiles(peersWithOnlineStatus).
|
WithUserProfiles(peers).
|
||||||
WithPeerChanges(peersWithOnlineStatus).
|
WithPeerChanges(peers).
|
||||||
Build()
|
Build()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,6 +381,25 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo
|
|||||||
return nodeView, exists
|
return nodeView, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
|
||||||
|
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()
|
||||||
|
|
||||||
|
snapshot := s.data.Load()
|
||||||
|
// We don't have a byMachineKey map, so we need to iterate
|
||||||
|
// This could be optimized by adding a byMachineKey map if this becomes a hot path
|
||||||
|
for _, node := range snapshot.nodesByID {
|
||||||
|
if node.MachineKey == machineKey {
|
||||||
|
return node.View(), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return types.NodeView{}, false
|
||||||
|
}
|
||||||
|
|
||||||
// DebugString returns debug information about the NodeStore.
|
// DebugString returns debug information about the NodeStore.
|
||||||
func (s *NodeStore) DebugString() string {
|
func (s *NodeStore) DebugString() string {
|
||||||
snapshot := s.data.Load()
|
snapshot := s.data.Load()
|
||||||
|
@ -519,6 +519,14 @@ func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool)
|
|||||||
return s.nodeStore.GetNodeByNodeKey(nodeKey)
|
return s.nodeStore.GetNodeByNodeKey(nodeKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNodeByMachineKey retrieves a node by its machine key.
|
||||||
|
// The bool indicates if the node exists or is available (like "err not found").
|
||||||
|
// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
|
||||||
|
// it isn't an invalid node (this is more of a node error or node is broken).
|
||||||
|
func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
|
||||||
|
return s.nodeStore.GetNodeByMachineKey(machineKey)
|
||||||
|
}
|
||||||
|
|
||||||
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
|
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
|
||||||
func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] {
|
func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] {
|
||||||
if len(nodeIDs) == 0 {
|
if len(nodeIDs) == 0 {
|
||||||
@ -790,6 +798,8 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
|||||||
if exists && existingNode.Valid() {
|
if exists && existingNode.Valid() {
|
||||||
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
|
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
|
||||||
}
|
}
|
||||||
|
// TODO(kradalby): This should just update the IP addresses, nothing else in the node store.
|
||||||
|
// We should avoid PutNode here.
|
||||||
s.nodeStore.PutNode(*node)
|
s.nodeStore.PutNode(*node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -969,68 +979,181 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
expiry *time.Time,
|
expiry *time.Time,
|
||||||
registrationMethod string,
|
registrationMethod string,
|
||||||
) (types.NodeView, change.ChangeSet, error) {
|
) (types.NodeView, change.ChangeSet, error) {
|
||||||
// Get the registration entry to check the machine key
|
s.mu.Lock()
|
||||||
var ipv4, ipv6 *netip.Addr
|
defer s.mu.Unlock()
|
||||||
var err error
|
|
||||||
|
|
||||||
// Check if we have the registration entry to determine if we should reuse IPs
|
// Get the registration entry from cache
|
||||||
if regEntry, ok := s.GetRegistrationCacheEntry(registrationID); ok {
|
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
||||||
// Check if node already exists with same machine key and user
|
if !ok {
|
||||||
// to avoid allocating new IPs unnecessarily
|
return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache
|
||||||
existingNode, _ := s.db.GetNodeByMachineKey(regEntry.Node.MachineKey)
|
}
|
||||||
|
|
||||||
// Only allocate new IPs if:
|
// Get the user
|
||||||
// 1. No existing node found, OR
|
user, err := s.db.GetUserByID(userID)
|
||||||
// 2. Existing node belongs to a different user
|
|
||||||
if existingNode == nil || existingNode.UserID != uint(userID) {
|
|
||||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, err
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
|
||||||
}
|
|
||||||
}
|
|
||||||
// If existing node found for same user, HandleNodeFromAuthPath will reuse its IPs
|
|
||||||
} else {
|
|
||||||
// If no registration entry found, allocate new IPs (shouldn't happen in normal flow)
|
|
||||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
node, nodeChange, err := s.db.HandleNodeFromAuthPath(
|
// Check if node already exists by node key
|
||||||
registrationID,
|
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey)
|
||||||
userID,
|
if exists && existingNodeView.Valid() {
|
||||||
expiry,
|
// Node exists - this is a refresh/re-registration
|
||||||
registrationMethod,
|
log.Debug().
|
||||||
ipv4, ipv6,
|
Str("registration_id", registrationID.String()).
|
||||||
)
|
Str("username", user.Username()).
|
||||||
if err != nil {
|
Str("registrationMethod", registrationMethod).
|
||||||
return types.NodeView{}, change.EmptySet, err
|
Str("node", existingNodeView.Hostname()).
|
||||||
}
|
Msg("Refreshing existing node registration")
|
||||||
|
|
||||||
// Update NodeStore to ensure it has the latest node data
|
// Update NodeStore first with the new expiry
|
||||||
// For re-registrations (key expiry), mark as offline since node is reconnecting
|
s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) {
|
||||||
// For new registrations, leave IsOnline as nil to let batcher manage connection state
|
if expiry != nil {
|
||||||
if nodeChange.Change == change.NodeKeyExpiry {
|
node.Expiry = expiry
|
||||||
// This is a re-registration/key refresh - node was disconnected and is coming back
|
}
|
||||||
|
// Mark as offline since node is reconnecting
|
||||||
node.IsOnline = ptr.To(false)
|
node.IsOnline = ptr.To(false)
|
||||||
node.LastSeen = ptr.To(time.Now())
|
node.LastSeen = ptr.To(time.Now())
|
||||||
}
|
})
|
||||||
// For new registrations (NodeNewOrUpdate), don't set IsOnline - batcher manages it
|
|
||||||
s.nodeStore.PutNode(*node)
|
|
||||||
|
|
||||||
// Update policy manager with the new node if needed
|
// Save to database
|
||||||
|
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Return the node to satisfy the Write signature
|
||||||
|
return hsdb.GetNodeByID(tx, existingNodeView.ID())
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get updated node from NodeStore
|
||||||
|
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
|
||||||
|
return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New node registration
|
||||||
|
log.Debug().
|
||||||
|
Str("registration_id", registrationID.String()).
|
||||||
|
Str("username", user.Username()).
|
||||||
|
Str("registrationMethod", registrationMethod).
|
||||||
|
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
||||||
|
Msg("Registering new node from auth callback")
|
||||||
|
|
||||||
|
// Check if node exists with same machine key
|
||||||
|
var existingMachineNode *types.Node
|
||||||
|
if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() {
|
||||||
|
existingMachineNode = nv.AsStruct()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for different user registration
|
||||||
|
if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) {
|
||||||
|
return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the node for registration
|
||||||
|
nodeToRegister := regEntry.Node
|
||||||
|
nodeToRegister.UserID = uint(userID)
|
||||||
|
nodeToRegister.User = *user
|
||||||
|
nodeToRegister.RegisterMethod = registrationMethod
|
||||||
|
if expiry != nil {
|
||||||
|
nodeToRegister.Expiry = expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle IP allocation
|
||||||
|
var ipv4, ipv6 *netip.Addr
|
||||||
|
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
||||||
|
// Reuse existing IPs and properties
|
||||||
|
nodeToRegister.ID = existingMachineNode.ID
|
||||||
|
nodeToRegister.GivenName = existingMachineNode.GivenName
|
||||||
|
nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes
|
||||||
|
ipv4 = existingMachineNode.IPv4
|
||||||
|
ipv6 = existingMachineNode.IPv6
|
||||||
|
} else {
|
||||||
|
// Allocate new IPs
|
||||||
|
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeToRegister.IPv4 = ipv4
|
||||||
|
nodeToRegister.IPv6 = ipv6
|
||||||
|
|
||||||
|
// Ensure unique given name if not set
|
||||||
|
if nodeToRegister.GivenName == "" {
|
||||||
|
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||||
|
}
|
||||||
|
nodeToRegister.GivenName = givenName
|
||||||
|
}
|
||||||
|
|
||||||
|
var savedNode *types.Node
|
||||||
|
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
||||||
|
// Update existing node - NodeStore first, then database
|
||||||
|
s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
|
||||||
|
node.NodeKey = nodeToRegister.NodeKey
|
||||||
|
node.DiscoKey = nodeToRegister.DiscoKey
|
||||||
|
node.Hostname = nodeToRegister.Hostname
|
||||||
|
node.Hostinfo = nodeToRegister.Hostinfo
|
||||||
|
node.Endpoints = nodeToRegister.Endpoints
|
||||||
|
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||||
|
if expiry != nil {
|
||||||
|
node.Expiry = expiry
|
||||||
|
}
|
||||||
|
node.IsOnline = ptr.To(false)
|
||||||
|
node.LastSeen = ptr.To(time.Now())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
|
}
|
||||||
|
return &nodeToRegister, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New node - database first to get ID, then NodeStore
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
|
}
|
||||||
|
return &nodeToRegister, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to NodeStore after database creates the ID
|
||||||
|
s.nodeStore.PutNode(*savedNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete from registration cache
|
||||||
|
s.registrationCache.Delete(registrationID)
|
||||||
|
|
||||||
|
// Signal to waiting clients
|
||||||
|
select {
|
||||||
|
case regEntry.Registered <- savedNode:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
close(regEntry.Registered)
|
||||||
|
|
||||||
|
// Update policy manager
|
||||||
nodesChange, err := s.updatePolicyManagerNodes()
|
nodesChange, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return node.View(), nodeChange, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If policy manager detected changes, use that instead
|
|
||||||
if !nodesChange.Empty() {
|
if !nodesChange.Empty() {
|
||||||
nodeChange = nodesChange
|
return savedNode.View(), nodesChange, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return node.View(), nodeChange, nil
|
return savedNode.View(), change.NodeAdded(savedNode.ID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
||||||
@ -1038,6 +1161,9 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) (types.NodeView, change.ChangeSet, error) {
|
) (types.NodeView, change.ChangeSet, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, err
|
return types.NodeView{}, change.EmptySet, err
|
||||||
@ -1048,6 +1174,40 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
return types.NodeView{}, change.EmptySet, err
|
return types.NodeView{}, change.EmptySet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is a logout request for an ephemeral node
|
||||||
|
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||||
|
// Find the node to delete
|
||||||
|
var nodeToDelete types.NodeView
|
||||||
|
for _, nv := range s.nodeStore.ListNodes().All() {
|
||||||
|
if nv.Valid() && nv.MachineKey() == machineKey {
|
||||||
|
nodeToDelete = nv
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nodeToDelete.Valid() {
|
||||||
|
c, err := s.DeleteNode(nodeToDelete)
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||||
|
}
|
||||||
|
return types.NodeView{}, c, nil
|
||||||
|
}
|
||||||
|
return types.NodeView{}, change.EmptySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("node", regReq.Hostinfo.Hostname).
|
||||||
|
Str("machine_key", machineKey.ShortString()).
|
||||||
|
Str("node_key", regReq.NodeKey.ShortString()).
|
||||||
|
Str("user", pak.User.Username()).
|
||||||
|
Msg("Registering node with pre-auth key")
|
||||||
|
|
||||||
|
// Check if node already exists with same machine key
|
||||||
|
var existingNode *types.Node
|
||||||
|
if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
|
||||||
|
existingNode = nv.AsStruct()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the node for registration
|
||||||
nodeToRegister := types.Node{
|
nodeToRegister := types.Node{
|
||||||
Hostname: regReq.Hostinfo.Hostname,
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
UserID: pak.User.ID,
|
UserID: pak.User.ID,
|
||||||
@ -1057,10 +1217,6 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
Hostinfo: regReq.Hostinfo,
|
Hostinfo: regReq.Hostinfo,
|
||||||
LastSeen: ptr.To(time.Now()),
|
LastSeen: ptr.To(time.Now()),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
||||||
// TODO(kradalby): This should not be set on the node,
|
|
||||||
// they should be looked up through the key, which is
|
|
||||||
// attached to the node.
|
|
||||||
ForcedTags: pak.Proto().GetAclTags(),
|
ForcedTags: pak.Proto().GetAclTags(),
|
||||||
AuthKey: pak,
|
AuthKey: pak,
|
||||||
AuthKeyID: &pak.ID,
|
AuthKeyID: &pak.ID,
|
||||||
@ -1070,29 +1226,66 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
nodeToRegister.Expiry = ®Req.Expiry
|
nodeToRegister.Expiry = ®Req.Expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if node already exists with same machine key and user
|
// Handle IP allocation and existing node properties
|
||||||
// to avoid allocating new IPs unnecessarily
|
|
||||||
existingNode, _ := s.db.GetNodeByMachineKey(machineKey)
|
|
||||||
var ipv4, ipv6 *netip.Addr
|
var ipv4, ipv6 *netip.Addr
|
||||||
|
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||||
// Only allocate new IPs if:
|
// Reuse existing node properties
|
||||||
// 1. No existing node found, OR
|
nodeToRegister.ID = existingNode.ID
|
||||||
// 2. Existing node belongs to a different user
|
nodeToRegister.GivenName = existingNode.GivenName
|
||||||
if existingNode == nil || existingNode.UserID != pak.User.ID {
|
nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes
|
||||||
|
ipv4 = existingNode.IPv4
|
||||||
|
ipv6 = existingNode.IPv6
|
||||||
|
} else {
|
||||||
|
// Allocate new IPs
|
||||||
ipv4, ipv6, err = s.ipAlloc.Next()
|
ipv4, ipv6, err = s.ipAlloc.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If existing node found for same user, RegisterNode will reuse its IPs
|
|
||||||
|
|
||||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
nodeToRegister.IPv4 = ipv4
|
||||||
node, err := hsdb.RegisterNode(tx,
|
nodeToRegister.IPv6 = ipv6
|
||||||
nodeToRegister,
|
|
||||||
ipv4, ipv6,
|
// Ensure unique given name if not set
|
||||||
)
|
if nodeToRegister.GivenName == "" {
|
||||||
|
givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("registering node: %w", err)
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||||
|
}
|
||||||
|
nodeToRegister.GivenName = givenName
|
||||||
|
}
|
||||||
|
|
||||||
|
var savedNode *types.Node
|
||||||
|
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||||
|
// Update existing node - NodeStore first, then database
|
||||||
|
s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
||||||
|
node.NodeKey = nodeToRegister.NodeKey
|
||||||
|
node.Hostname = nodeToRegister.Hostname
|
||||||
|
node.Hostinfo = nodeToRegister.Hostinfo
|
||||||
|
node.Endpoints = nodeToRegister.Endpoints
|
||||||
|
node.RegisterMethod = nodeToRegister.RegisterMethod
|
||||||
|
node.ForcedTags = nodeToRegister.ForcedTags
|
||||||
|
node.AuthKey = nodeToRegister.AuthKey
|
||||||
|
node.AuthKeyID = nodeToRegister.AuthKeyID
|
||||||
|
if nodeToRegister.Expiry != nil {
|
||||||
|
node.Expiry = nodeToRegister.Expiry
|
||||||
|
}
|
||||||
|
node.IsOnline = ptr.To(false)
|
||||||
|
node.LastSeen = ptr.To(time.Now())
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("node", nodeToRegister.Hostname).
|
||||||
|
Str("machine_key", machineKey.ShortString()).
|
||||||
|
Str("node_key", regReq.NodeKey.ShortString()).
|
||||||
|
Str("user", pak.User.Username()).
|
||||||
|
Msg("Node re-authorized")
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !pak.Reusable {
|
if !pak.Reusable {
|
||||||
@ -1102,58 +1295,54 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return node, nil
|
return &nodeToRegister, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New node - database first to get ID, then NodeStore
|
||||||
|
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pak.Reusable {
|
||||||
|
err = hsdb.UsePreAuthKey(tx, pak)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("using pre auth key: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &nodeToRegister, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is a logout request for an ephemeral node
|
// Add to NodeStore after database creates the ID
|
||||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
s.nodeStore.PutNode(*savedNode)
|
||||||
// This is a logout request for an ephemeral node, delete it immediately
|
|
||||||
c, err := s.DeleteNode(node.View())
|
|
||||||
if err != nil {
|
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.NodeView{}, c, nil
|
// Update policy managers
|
||||||
}
|
|
||||||
|
|
||||||
// Update NodeStore BEFORE updating policy manager so it has the latest node data
|
|
||||||
// CRITICAL: For re-registration of existing nodes, we must update NodeStore
|
|
||||||
// to ensure it has the latest state from the database transaction
|
|
||||||
// For re-registrations of existing nodes, mark as offline since they're reconnecting
|
|
||||||
// For new registrations, leave IsOnline as nil to let batcher manage connection state
|
|
||||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
|
||||||
// This is a re-registration of existing node - was disconnected and is coming back
|
|
||||||
node.IsOnline = ptr.To(false)
|
|
||||||
node.LastSeen = ptr.To(time.Now())
|
|
||||||
}
|
|
||||||
// For new registrations, don't set IsOnline - batcher manages it
|
|
||||||
s.nodeStore.PutNode(*node)
|
|
||||||
|
|
||||||
// Check if policy manager needs updating
|
|
||||||
// This is necessary because we just created a new node.
|
|
||||||
// We need to ensure that the policy manager is aware of this new node.
|
|
||||||
// Also update users to ensure all users are known when evaluating policies.
|
|
||||||
usersChange, err := s.updatePolicyManagerUsers()
|
usersChange, err := s.updatePolicyManagerUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodesChange, err := s.updatePolicyManagerNodes()
|
nodesChange, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var c change.ChangeSet
|
var c change.ChangeSet
|
||||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||||
c = change.PolicyChange()
|
c = change.PolicyChange()
|
||||||
} else {
|
} else {
|
||||||
c = change.NodeAdded(node.ID)
|
c = change.NodeAdded(savedNode.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return node.View(), c, nil
|
return savedNode.View(), c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updatePolicyManagerUsers updates the policy manager with current users.
|
// updatePolicyManagerUsers updates the policy manager with current users.
|
||||||
|
Loading…
Reference in New Issue
Block a user