diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..3bbba29d 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -313,6 +313,7 @@ func (h *Headscale) reqToNewRegisterResponse( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, + Ephemeral: req.Ephemeral, LastSeen: new(time.Now()), }, ) @@ -440,6 +441,7 @@ func (h *Headscale) handleRegisterInteractive( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, + Ephemeral: req.Ephemeral, LastSeen: new(time.Now()), }, ) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..219525bb 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -342,6 +342,35 @@ func TestAuthenticationFlows(t *testing.T) { validateCompleteResponse: true, expectedAuthURLPattern: "/register/", }, + { + name: "full_interactive_workflow_ephemeral_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Ephemeral: true, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-ephemeral-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + validate: func(t *testing.T, _ *tailcfg.RegisterResponse, app *Headscale) { + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node.IsEphemeral(), "interactive node with state=mem should be marked ephemeral") + }, + }, // TEST: Interactive workflow with no Auth struct in request // WHAT: Tests interactive flow when request has no Auth field (nil) // INPUT: Register request with Auth field set to nil diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index c518502e..f0e841c3 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -352,6 +352,7 @@ AND auth_key_id NOT IN ( given_name varchar(63), user_id integer, register_method text, + ephemeral numeric DEFAULT false, forced_tags text, auth_key_id integer, last_seen datetime, @@ -393,8 +394,8 @@ AND auth_key_id NOT IN ( SELECT id, prefix, hash, expiration, last_seen, created_at FROM api_keys_old`, - `INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at) - SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at + `INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, ephemeral, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at) + SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, false, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at FROM nodes_old`, `INSERT INTO policies (id, data, created_at, updated_at, deleted_at) @@ -699,6 +700,32 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + ID: "202602111000-node-ephemeral-column", + Migrate: func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&types.Node{}, "ephemeral") { + err := tx.Migrator().AddColumn(&types.Node{}, "ephemeral") + if err != nil { + return fmt.Errorf("adding nodes.ephemeral column: %w", err) + } + } + + // Backfill existing auth-key ephemeral nodes so historical data keeps behavior. + err := tx.Exec(` +UPDATE nodes +SET ephemeral = true +WHERE auth_key_id IN ( + SELECT id FROM pre_auth_keys WHERE ephemeral = true +); + `).Error + if err != nil { + return fmt.Errorf("backfilling nodes.ephemeral from pre_auth_keys: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 98b73551..a9c1cdf0 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -100,7 +100,10 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes := types.Nodes{} - err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error + err := rx. + Joins("LEFT JOIN pre_auth_keys pak ON pak.id = nodes.auth_key_id"). + Where("nodes.ephemeral = ? OR pak.ephemeral = ?", true, true). + Find(&nodes).Error if err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 5991d494..ce0ee4f4 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -672,25 +672,44 @@ func TestListEphemeralNodes(t *testing.T) { AuthKeyID: new(pakEph.ID), } + nodeEphNoAuthKey := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "ephemeral-oidc", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodOIDC, + Ephemeral: true, + } + err = db.DB.Save(&node).Error require.NoError(t, err) err = db.DB.Save(&nodeEph).Error require.NoError(t, err) + err = db.DB.Save(&nodeEphNoAuthKey).Error + require.NoError(t, err) + nodes, err := db.ListNodes() require.NoError(t, err) ephemeralNodes, err := db.ListEphemeralNodes() require.NoError(t, err) - assert.Len(t, nodes, 2) - assert.Len(t, ephemeralNodes, 1) + assert.Len(t, nodes, 3) + assert.Len(t, ephemeralNodes, 2) - assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID) - assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID) - assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) - assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) + ephemeralByHostname := map[string]types.Node{} + for _, ephemeralNode := range ephemeralNodes { + ephemeralByHostname[ephemeralNode.Hostname] = *ephemeralNode + } + + require.Contains(t, ephemeralByHostname, nodeEph.Hostname) + require.Contains(t, ephemeralByHostname, nodeEphNoAuthKey.Hostname) + + assert.Equal(t, nodeEph.AuthKeyID, ephemeralByHostname[nodeEph.Hostname].AuthKeyID) + assert.Equal(t, nodeEphNoAuthKey.AuthKeyID, ephemeralByHostname[nodeEphNoAuthKey.Hostname].AuthKeyID) } func TestNodeNaming(t *testing.T) { diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index ef0a2a0e..730582a1 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -81,6 +81,7 @@ CREATE TABLE nodes( given_name varchar(63), user_id integer, register_method text, + ephemeral numeric DEFAULT false, tags text, auth_key_id integer, last_seen datetime, diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index abb34eb0..6732820b 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -91,7 +91,7 @@ func (s *State) DebugOverview() string { expiredCount++ } - if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + if node.IsEphemeral() { ephemeralCount++ } } @@ -302,7 +302,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { info.Nodes.Expired++ } - if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + if node.IsEphemeral() { info.Nodes.Ephemeral++ } } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 65ebf905..ec149661 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -629,8 +629,7 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { var ephemeralNodes []types.NodeView for _, node := range allNodes.All() { - // Check if node is ephemeral by checking its AuthKey - if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + if node.IsEphemeral() { ephemeralNodes = append(ephemeralNodes, node) } } @@ -1101,6 +1100,7 @@ type newNodeParams struct { Endpoints []netip.AddrPort Expiry *time.Time RegisterMethod string + Ephemeral bool // Optional: Pre-auth key specific fields PreAuthKey *types.PreAuthKey @@ -1191,6 +1191,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView } else { node.RegisterMethod = params.RegEntry.Node.RegisterMethod } + node.Ephemeral = params.RegEntry.Node.Ephemeral // Track tagged status BEFORE processing tags wasTagged := node.IsTagged() @@ -1286,6 +1287,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro Endpoints: params.Endpoints, LastSeen: new(time.Now()), RegisterMethod: params.RegisterMethod, + Ephemeral: params.Ephemeral, Expiry: params.Expiry, } @@ -1314,6 +1316,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro nodeToRegister.AuthKey = params.PreAuthKey nodeToRegister.AuthKeyID = ¶ms.PreAuthKey.ID + nodeToRegister.Ephemeral = params.PreAuthKey.Ephemeral } else { // Non-PreAuthKey registration (OIDC, CLI) - always user-owned nodeToRegister.UserID = ¶ms.User.ID @@ -1676,6 +1679,7 @@ func (s *State) createNewNodeFromAuth( Endpoints: regEntry.Node.Endpoints, Expiry: cmp.Or(expiry, regEntry.Node.Expiry), RegisterMethod: registrationMethod, + Ephemeral: regEntry.Node.Ephemeral, ExistingNodeForNetinfo: existingNodeForNetinfo, }) } @@ -1820,6 +1824,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID + node.Ephemeral = pak.Ephemeral node.IsOnline = new(false) node.LastSeen = new(time.Now()) @@ -1910,6 +1915,7 @@ func (s *State) HandleNodeFromPreAuthKey( Endpoints: nil, // Endpoints not available in RegisterRequest Expiry: ®Req.Expiry, RegisterMethod: util.RegisterMethodAuthKey, + Ephemeral: pak.Ephemeral, PreAuthKey: pak, ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), }) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index e705a33a..4993d82e 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -113,6 +113,7 @@ type Node struct { User *User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string + Ephemeral bool `gorm:"default:false"` // Tags is the definitive owner for tagged nodes. // When non-empty, the node is "tagged" and tags define its identity. @@ -181,7 +182,7 @@ func (node *Node) IsExpired() bool { // IsEphemeral returns if the node is registered as an Ephemeral node. // https://tailscale.com/kb/1111/ephemeral-nodes/ func (node *Node) IsEphemeral() bool { - return node.AuthKey != nil && node.AuthKey.Ephemeral + return node.Ephemeral || (node.AuthKey != nil && node.AuthKey.Ephemeral) } func (node *Node) IPs() []netip.Addr { diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go index 4dfeedc2..feec6b55 100644 --- a/hscontrol/types/types_clone.go +++ b/hscontrol/types/types_clone.go @@ -96,6 +96,7 @@ var _NodeCloneNeedsRegeneration = Node(struct { UserID *uint User *User RegisterMethod string + Ephemeral bool Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go index e48dd029..7f6cd055 100644 --- a/hscontrol/types/types_view.go +++ b/hscontrol/types/types_view.go @@ -223,6 +223,7 @@ func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerO func (v NodeView) User() UserView { return v.ж.User.View() } func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod } +func (v NodeView) Ephemeral() bool { return v.ж.Ephemeral } // Tags is the definitive owner for tagged nodes. // When non-empty, the node is "tagged" and tags define its identity. @@ -277,6 +278,7 @@ var _NodeViewNeedsRegeneration = Node(struct { UserID *uint User *User RegisterMethod string + Ephemeral bool Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey