1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-07 20:04:00 +01:00

fix(auth): persist ephemeral flag for interactive nodes

This commit is contained in:
Louis Shawn 2026-02-11 10:03:48 +08:00
parent e0d8c3c877
commit a14d4601e4
11 changed files with 105 additions and 14 deletions

View File

@ -313,6 +313,7 @@ func (h *Headscale) reqToNewRegisterResponse(
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
Ephemeral: req.Ephemeral,
LastSeen: new(time.Now()),
},
)
@ -440,6 +441,7 @@ func (h *Headscale) handleRegisterInteractive(
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
Ephemeral: req.Ephemeral,
LastSeen: new(time.Now()),
},
)

View File

@ -342,6 +342,35 @@ func TestAuthenticationFlows(t *testing.T) {
validateCompleteResponse: true,
expectedAuthURLPattern: "/register/",
},
{
name: "full_interactive_workflow_ephemeral_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
return "", nil
},
request: func(_ string) tailcfg.RegisterRequest {
return tailcfg.RegisterRequest{
NodeKey: nodeKey1.Public(),
Ephemeral: true,
Hostinfo: &tailcfg.Hostinfo{
Hostname: "interactive-ephemeral-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
{stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false},
},
validateCompleteResponse: true,
expectedAuthURLPattern: "/register/",
validate: func(t *testing.T, _ *tailcfg.RegisterResponse, app *Headscale) {
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
assert.True(t, node.IsEphemeral(), "interactive node with state=mem should be marked ephemeral")
},
},
// TEST: Interactive workflow with no Auth struct in request
// WHAT: Tests interactive flow when request has no Auth field (nil)
// INPUT: Register request with Auth field set to nil

View File

@ -352,6 +352,7 @@ AND auth_key_id NOT IN (
given_name varchar(63),
user_id integer,
register_method text,
ephemeral numeric DEFAULT false,
forced_tags text,
auth_key_id integer,
last_seen datetime,
@ -393,8 +394,8 @@ AND auth_key_id NOT IN (
SELECT id, prefix, hash, expiration, last_seen, created_at
FROM api_keys_old`,
`INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at)
SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at
`INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, ephemeral, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at)
SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, false, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at
FROM nodes_old`,
`INSERT INTO policies (id, data, created_at, updated_at, deleted_at)
@ -699,6 +700,32 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202602111000-node-ephemeral-column",
Migrate: func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&types.Node{}, "ephemeral") {
err := tx.Migrator().AddColumn(&types.Node{}, "ephemeral")
if err != nil {
return fmt.Errorf("adding nodes.ephemeral column: %w", err)
}
}
// Backfill existing auth-key ephemeral nodes so historical data keeps behavior.
err := tx.Exec(`
UPDATE nodes
SET ephemeral = true
WHERE auth_key_id IN (
SELECT id FROM pre_auth_keys WHERE ephemeral = true
);
`).Error
if err != nil {
return fmt.Errorf("backfilling nodes.ephemeral from pre_auth_keys: %w", err)
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View File

@ -100,7 +100,10 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
err := rx.
Joins("LEFT JOIN pre_auth_keys pak ON pak.id = nodes.auth_key_id").
Where("nodes.ephemeral = ? OR pak.ephemeral = ?", true, true).
Find(&nodes).Error
if err != nil {
return nil, err
}

View File

@ -672,25 +672,44 @@ func TestListEphemeralNodes(t *testing.T) {
AuthKeyID: new(pakEph.ID),
}
nodeEphNoAuthKey := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "ephemeral-oidc",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodOIDC,
Ephemeral: true,
}
err = db.DB.Save(&node).Error
require.NoError(t, err)
err = db.DB.Save(&nodeEph).Error
require.NoError(t, err)
err = db.DB.Save(&nodeEphNoAuthKey).Error
require.NoError(t, err)
nodes, err := db.ListNodes()
require.NoError(t, err)
ephemeralNodes, err := db.ListEphemeralNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Len(t, ephemeralNodes, 1)
assert.Len(t, nodes, 3)
assert.Len(t, ephemeralNodes, 2)
assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID)
assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID)
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
ephemeralByHostname := map[string]types.Node{}
for _, ephemeralNode := range ephemeralNodes {
ephemeralByHostname[ephemeralNode.Hostname] = *ephemeralNode
}
require.Contains(t, ephemeralByHostname, nodeEph.Hostname)
require.Contains(t, ephemeralByHostname, nodeEphNoAuthKey.Hostname)
assert.Equal(t, nodeEph.AuthKeyID, ephemeralByHostname[nodeEph.Hostname].AuthKeyID)
assert.Equal(t, nodeEphNoAuthKey.AuthKeyID, ephemeralByHostname[nodeEphNoAuthKey.Hostname].AuthKeyID)
}
func TestNodeNaming(t *testing.T) {

View File

@ -81,6 +81,7 @@ CREATE TABLE nodes(
given_name varchar(63),
user_id integer,
register_method text,
ephemeral numeric DEFAULT false,
tags text,
auth_key_id integer,
last_seen datetime,

View File

@ -91,7 +91,7 @@ func (s *State) DebugOverview() string {
expiredCount++
}
if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
if node.IsEphemeral() {
ephemeralCount++
}
}
@ -302,7 +302,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
info.Nodes.Expired++
}
if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
if node.IsEphemeral() {
info.Nodes.Ephemeral++
}
}

View File

@ -629,8 +629,7 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] {
var ephemeralNodes []types.NodeView
for _, node := range allNodes.All() {
// Check if node is ephemeral by checking its AuthKey
if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
if node.IsEphemeral() {
ephemeralNodes = append(ephemeralNodes, node)
}
}
@ -1101,6 +1100,7 @@ type newNodeParams struct {
Endpoints []netip.AddrPort
Expiry *time.Time
RegisterMethod string
Ephemeral bool
// Optional: Pre-auth key specific fields
PreAuthKey *types.PreAuthKey
@ -1191,6 +1191,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
} else {
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
}
node.Ephemeral = params.RegEntry.Node.Ephemeral
// Track tagged status BEFORE processing tags
wasTagged := node.IsTagged()
@ -1286,6 +1287,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
Endpoints: params.Endpoints,
LastSeen: new(time.Now()),
RegisterMethod: params.RegisterMethod,
Ephemeral: params.Ephemeral,
Expiry: params.Expiry,
}
@ -1314,6 +1316,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
nodeToRegister.AuthKey = params.PreAuthKey
nodeToRegister.AuthKeyID = &params.PreAuthKey.ID
nodeToRegister.Ephemeral = params.PreAuthKey.Ephemeral
} else {
// Non-PreAuthKey registration (OIDC, CLI) - always user-owned
nodeToRegister.UserID = &params.User.ID
@ -1676,6 +1679,7 @@ func (s *State) createNewNodeFromAuth(
Endpoints: regEntry.Node.Endpoints,
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
RegisterMethod: registrationMethod,
Ephemeral: regEntry.Node.Ephemeral,
ExistingNodeForNetinfo: existingNodeForNetinfo,
})
}
@ -1820,6 +1824,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Only update AuthKey reference
node.AuthKey = pak
node.AuthKeyID = &pak.ID
node.Ephemeral = pak.Ephemeral
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
@ -1910,6 +1915,7 @@ func (s *State) HandleNodeFromPreAuthKey(
Endpoints: nil, // Endpoints not available in RegisterRequest
Expiry: &regReq.Expiry,
RegisterMethod: util.RegisterMethodAuthKey,
Ephemeral: pak.Ephemeral,
PreAuthKey: pak,
ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}),
})

View File

@ -113,6 +113,7 @@ type Node struct {
User *User `gorm:"constraint:OnDelete:CASCADE;"`
RegisterMethod string
Ephemeral bool `gorm:"default:false"`
// Tags is the definitive owner for tagged nodes.
// When non-empty, the node is "tagged" and tags define its identity.
@ -181,7 +182,7 @@ func (node *Node) IsExpired() bool {
// IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (node *Node) IsEphemeral() bool {
return node.AuthKey != nil && node.AuthKey.Ephemeral
return node.Ephemeral || (node.AuthKey != nil && node.AuthKey.Ephemeral)
}
func (node *Node) IPs() []netip.Addr {

View File

@ -96,6 +96,7 @@ var _NodeCloneNeedsRegeneration = Node(struct {
UserID *uint
User *User
RegisterMethod string
Ephemeral bool
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey

View File

@ -223,6 +223,7 @@ func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerO
func (v NodeView) User() UserView { return v.ж.User.View() }
func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod }
func (v NodeView) Ephemeral() bool { return v.ж.Ephemeral }
// Tags is the definitive owner for tagged nodes.
// When non-empty, the node is "tagged" and tags define its identity.
@ -277,6 +278,7 @@ var _NodeViewNeedsRegeneration = Node(struct {
UserID *uint
User *User
RegisterMethod string
Ephemeral bool
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey