1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-07 20:04:00 +01:00

auth: add /auth dummy, tighten AuthRequest, generalise

This commit generalise the "Registration" pipeline to a more general auth
pipeline supporting both registrations and general auth requests.

This means we have renamed the RegistrationID to AuthID.
Fields from AuthRequest has been unexported and made read only.
Added dummy /auth endpoints to be filled.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2026-02-11 15:31:06 +01:00
parent d1364194ef
commit e45cf30867
No known key found for this signature in database
17 changed files with 403 additions and 334 deletions

View File

@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
r.Get("/health", h.HealthHandler)
r.Get("/version", h.VersionHandler)
r.Get("/key", h.KeyHandler)
r.Get("/register/{registration_id}", h.authProvider.RegisterHandler)
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
r.Get("/oidc/callback", provider.OIDCCallbackHandler)

View File

@ -20,7 +20,9 @@ import (
type AuthProvider interface {
RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthURL(regID types.RegistrationID) string
AuthHandler(w http.ResponseWriter, r *http.Request)
RegisterURL(authID types.AuthID) string
AuthURL(authID types.AuthID) string
}
func (h *Headscale) handleRegister(
@ -261,22 +263,22 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
}
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
}
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
select {
case <-ctx.Done():
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered:
if node == nil {
case node := <-reg.WaitForRegistration():
if !node.Valid() {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
return nodeToRegisterResponse(node), nil
}
}
@ -291,14 +293,14 @@ func (h *Headscale) reqToNewRegisterResponse(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
newAuthID, err := types.NewAuthID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
req.Hostinfo.View(),
machineKey.String(),
req.NodeKey.String(),
)
@ -307,25 +309,25 @@ func (h *Headscale) reqToNewRegisterResponse(
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
log.Info().Msgf("new followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
AuthURL: h.authProvider.RegisterURL(newAuthID),
}, nil
}
@ -376,13 +378,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
// Send both changes. Empty changes are ignored by Change().
h.Change(changed, routesChange)
// TODO(kradalby): I think this is covered above, but we need to validate that.
// // If policy changed due to node registration, send a separate policy change
// if policyChanged {
// policyChange := change.PolicyChange()
// h.Change(policyChange)
// }
resp := &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
@ -404,14 +399,14 @@ func (h *Headscale) handleRegisterInteractive(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
registrationId, err := types.NewRegistrationID()
authID, err := types.NewAuthID()
if err != nil {
return nil, fmt.Errorf("generating registration ID: %w", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
req.Hostinfo.View(),
machineKey.String(),
req.NodeKey.String(),
)
@ -434,28 +429,28 @@ func (h *Headscale) handleRegisterInteractive(
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
h.state.SetRegistrationCacheEntry(
registrationId,
nodeToRegister,
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
h.state.SetAuthCacheEntry(
authID,
authRegReq,
)
log.Info().Msgf("starting node registration using key: %s", registrationId)
log.Info().Msgf("starting node registration using key: %s", authID)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(registrationId),
AuthURL: h.authProvider.RegisterURL(authID),
}, nil
}

View File

@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 1: Create user-owned node WITH expiry set
clientExpiry := time.Now().Add(24 * time.Hour)
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "personal-to-tagged",
@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry,
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 2: Re-auth with tags (Personal → Tagged conversion)
nodeKey2 := key.NewNode()
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "personal-to-tagged",
@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client still sends expiry
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",
@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Create tagged node (expiry should be nil)
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "tagged-to-personal",
@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
RequestTags: []string{"tag:server"}, // Tagged node
},
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
nodeKey2 := key.NewNode()
clientExpiry := time.Now().Add(48 * time.Hour)
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "tagged-to-personal",
@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client requests expiry
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",

View File

@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_success",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-success-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-success-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Simulate successful registration - send to buffered channel
// The channel is buffered (size 1), so this can complete immediately
// and handleRegister will receive the value when it starts waiting
// Simulate successful registration
// handleRegister will receive the value when it starts waiting
go func() {
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node
nodeToRegister.FinishRegistration(node.View())
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_timeout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-timeout-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Don't send anything on channel - will timeout
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-timeout-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Don't call FinishRegistration - will timeout
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
},
@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_node_nil_response",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "nil-response-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "nil-response-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Simulate registration that returns nil (cache expired during auth)
// The channel is buffered (size 1), so this can complete immediately
// Simulate registration that returns empty NodeView (cache expired during auth)
go func() {
registered <- nil // Nil indicates cache expiry
nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
// Generate a registration ID that doesn't exist in cache
// This simulates an expired/missing cache entry
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
// Extract and validate the new registration ID exists in cache
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
newRegID, err := types.AuthIDFromString(newRegIDStr)
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
// Verify new registration entry exists in cache
_, found := app.state.GetRegistrationCacheEntry(newRegID)
_, found := app.state.GetAuthCacheEntry(newRegID)
assert.True(t, found, "new registration should exist in cache")
},
},
@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify cache entry exists
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, found, "registration cache entry should exist initially")
assert.NotNil(t, cacheEntry)
@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
// Cache entry should still exist after auth error (for retry scenarios)
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
},
},
@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
// Both cache entries should exist simultaneously
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist")
@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify both exist
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist")
@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
}
// First registration should still be in cache (not completed)
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
_, stillFound := app.state.GetAuthCacheEntry(regID1)
assert.True(t, stillFound, "first registration should still be pending")
},
},
@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
var (
initialResp *tailcfg.RegisterResponse
authURL string
registrationID types.RegistrationID
registrationID types.AuthID
finalResp *tailcfg.RegisterResponse
err error
)
@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if step.expectCacheEntry {
// Verify registration cache entry was created
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
require.True(t, found, "registration cache entry should exist")
require.NotNil(t, cacheEntry, "cache entry should not be nil")
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
}
case stepTypeAuthCompletion:
@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
// Check cache cleanup expectation for this step
if step.expectCacheEntry == false && registrationID != "" {
// Verify cache entry was cleaned up
_, found := app.state.GetRegistrationCacheEntry(registrationID)
_, found := app.state.GetAuthCacheEntry(registrationID)
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
}
}
@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
}
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
// AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/"
@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
idStr := authURL[idx+len(registerPrefix):]
return types.RegistrationIDFromString(idStr)
return types.AuthIDFromString(idStr)
}
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
nodeKey := key.NewNode()
// Simulate a registration cache entry (as would be created during web auth)
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "webauth-tags-node",
@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete the web auth - should fail because tag is unauthorized
_, _, err := app.state.HandleNodeFromAuthPath(
@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Initial registration with tags
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "reauth-untag-node",
@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{"tag:valid-owned", "tag:second"},
},
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
// Complete initial registration with tags
node, _, err := app.state.HandleNodeFromAuthPath(
@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
// Step 2: Reauth with EMPTY tags to untag
nodeKey2 := key.NewNode() // New node key for reauth
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "reauth-untag-node",
@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
nodeKey2 := key.NewNode() // New node key for reauth
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "authkey-tagged-node",
@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
// This is what happens when running: headscale nodes register --user alice --key ...
nodeKey2 := key.NewNode()
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key as the tagged node
NodeKey: nodeKey2.Public(),
Hostname: "tagged-orphan-node",
@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
RequestTags: []string{}, // Empty - transition to user-owned
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// This should NOT panic - before the fix, this would panic with:
// panic: runtime error: invalid memory address or nil pointer dereference

View File

@ -47,7 +47,7 @@ const (
type HSDatabase struct {
DB *gorm.DB
cfg *types.Config
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
}
// NewHeadscaleDatabase creates a new database connection and runs migrations.
@ -56,7 +56,7 @@ type HSDatabase struct {
//nolint:gocyclo // complex database initialization with many migrations
func NewHeadscaleDatabase(
cfg *types.Config,
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg.Database)
if err != nil {

View File

@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
}
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
}
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {

View File

@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
Str(zf.RegistrationKey, registrationKey).
Msg("registering node")
registrationId, err := types.RegistrationIDFromString(request.GetKey())
registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil {
return nil, err
}
@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: request.GetName(),
}
registrationId, err := types.RegistrationIDFromString(request.GetKey())
registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil {
return nil, err
}
newNode := types.NewRegisterNode(
types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
newNode := types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
Expiry: &time.Time{},
LastSeen: &time.Time{},
Expiry: &time.Time{},
LastSeen: &time.Time{},
Hostinfo: &hostinfo,
},
)
Hostinfo: &hostinfo,
}
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
authRegReq := types.NewRegisterAuthRequest(newNode)
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
}
func (api headscaleV1APIServer) Health(

View File

@ -11,7 +11,6 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/assets"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
registrationId.String())
authID.String())
}
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderWeb) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
}
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
registrationId, err := urlParam[types.AuthID](req, "auth_id")
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
err = registrationId.Validate()
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
return registrationId, nil
}
// RegisterHandler shows a simple message in the browser to point to the CLI
@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
registrationId, err := authIDFromRequest(req)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
httpError(writer, err)
return
}

View File

@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
}
// emptyCache creates an empty registration cache for testing.
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
}
// Test configuration constants.

View File

@ -12,7 +12,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@ -26,8 +25,8 @@ import (
const (
randomByteSize = 16
defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
authCacheExpiration = time.Minute * 15
authCacheCleanup = time.Minute * 20
)
var (
@ -44,17 +43,21 @@ var (
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
)
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
RegistrationID types.RegistrationID
Verifier *string
// AuthInfo contains both auth ID and verifier information for OIDC validation.
type AuthInfo struct {
AuthID types.AuthID
Verifier *string
Registration bool
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
registrationCache *zcache.Cache[string, RegistrationInfo]
h *Headscale
serverURL string
cfg *types.OIDCConfig
// authCache holds auth information between
// the auth and the callback steps.
authCache *zcache.Cache[string, AuthInfo]
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
authCache := zcache.New[string, AuthInfo](
authCacheExpiration,
authCacheCleanup,
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
registrationCache: registrationCache,
h: h,
serverURL: serverURL,
cfg: cfg,
authCache: authCache,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderOIDC) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, false)
}
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
registrationID.String())
authID.String())
}
// RegisterHandler registers the OIDC callback handler with the given router.
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
// Listens in /register/:registration_id.
// Listens in /register/:auth_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
a.authHandler(writer, req, true)
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
// authHandler takes an incoming request that needs to be authenticated and
// validates and prepares it for the OIDC flow.
func (a *AuthProviderOIDC) authHandler(
writer http.ResponseWriter,
req *http.Request,
registration bool,
) {
authID, err := authIDFromRequest(req)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
httpError(writer, err)
return
}
@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
return
}
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
RegistrationID: registrationId,
registrationInfo := AuthInfo{
AuthID: authID,
Registration: registration,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
a.registrationCache.Set(state, registrationInfo)
a.authCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
registrationId := a.getRegistrationIDFromState(state)
authInfo := a.getAuthInfoFromState(state)
if authInfo == nil {
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// Register the node if it does not exist.
if registrationId != nil {
return
}
// If this is a registration flow, then we need to register the node.
if authInfo.Registration {
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// TODO(kradalby): handle login flow (without registration) if needed.
// We need to send an update here to whatever might be waiting for this auth flow.
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
regInfo, ok := a.authCache.Get(state)
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
@ -507,14 +533,14 @@ func doOIDCAuthorization(
return nil
}
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
// getAuthInfoFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
authInfo, ok := a.authCache.Get(state)
if !ok {
return nil
}
return &regInfo.RegistrationID
return &authInfo
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
registrationID types.AuthID,
expiry time.Time,
) (bool, error) {
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(

View File

@ -82,8 +82,10 @@ type State struct {
derpMap atomic.Pointer[tailcfg.DERPMap]
// polMan handles policy evaluation and management
polMan policy.PolicyManager
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
}
@ -101,20 +103,20 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
authCache := zcache.New[types.AuthID, types.AuthRequest](
cacheExpiration,
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
authCache.OnEvicted(
func(id types.AuthID, rn types.AuthRequest) {
rn.FinishRegistration(types.NodeView{})
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg,
registrationCache,
authCache,
)
if err != nil {
return nil, fmt.Errorf("initializing database: %w", err)
@ -178,12 +180,12 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
authCache: authCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
}, nil
}
@ -1042,9 +1044,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
entry, found := s.registrationCache.Get(id)
// GetAuthCacheEntry retrieves a node registration from cache.
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
entry, found := s.authCache.Get(id)
if !found {
return nil, false
}
@ -1052,26 +1054,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
return &entry, true
}
// SetRegistrationCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
s.registrationCache.Set(id, entry)
// SetAuthCacheEntry stores a node registration in cache.
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
s.authCache.Set(id, entry)
}
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
if hostinfo == nil {
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
if !nv.Hostinfo().Valid() {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had nil hostinfo, generated default hostname")
} else if hostinfo.Hostname == "" {
} else if nv.Hostinfo().Hostname() == "" {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had empty hostname, generated default")
@ -1113,7 +1113,7 @@ type authNodeUpdateParams struct {
// Node to update; must be valid and in NodeStore.
ExistingNode types.NodeView
// Client data: keys, hostinfo, endpoints.
RegEntry *types.RegisterNode
RegEntry *types.AuthRequest
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
ValidHostinfo *tailcfg.Hostinfo
// Hostname from hostinfo, or generated from keys if client omits it.
@ -1132,6 +1132,7 @@ type authNodeUpdateParams struct {
// an existing node. It updates the node in NodeStore, processes RequestTags, and
// persists changes to the database.
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
regNv := params.RegEntry.Node()
// Log the operation type
if params.IsConvertFromTag {
log.Info().
@ -1140,16 +1141,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
Msg("Converting tagged node to user-owned node")
} else {
log.Info().
EmbedObject(params.ExistingNode).
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
Object("existing", params.ExistingNode).
Object("incoming", regNv).
Msg("Updating existing node registration via reauth")
}
// Process RequestTags during reauth (#2979)
// Due to json:",omitempty", we treat empty/nil as "clear tags"
var requestTags []string
if params.RegEntry.Node.Hostinfo != nil {
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
if regNv.Hostinfo().Valid() {
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
}
oldTags := params.ExistingNode.Tags().AsSlice()
@ -1167,8 +1168,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
// Update existing node in NodeStore - validation passed, safe to mutate
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
node.NodeKey = params.RegEntry.Node.NodeKey
node.DiscoKey = params.RegEntry.Node.DiscoKey
node.NodeKey = regNv.NodeKey()
node.DiscoKey = regNv.DiscoKey()
node.Hostname = params.Hostname
// Preserve NetInfo from existing node when re-registering
@ -1179,7 +1180,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
params.ValidHostinfo,
)
node.Endpoints = params.RegEntry.Node.Endpoints
node.Endpoints = regNv.Endpoints().AsSlice()
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
@ -1188,7 +1189,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.IsConvertFromTag {
node.RegisterMethod = params.RegisterMethod
} else {
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
node.RegisterMethod = regNv.RegisterMethod()
}
// Track tagged status BEFORE processing tags
@ -1208,7 +1209,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !wasTagged && isTagged:
// Personal → Tagged: clear expiry (tagged nodes don't expire)
@ -1218,14 +1219,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !isTagged:
// Personal → Personal: update expiry from client
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
}
// Tagged → Tagged: keep existing expiry (nil) - no action needed
@ -1511,13 +1512,13 @@ func (s *State) processReauthTags(
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
authID types.AuthID,
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (types.NodeView, change.Change, error) {
// Get the registration entry from cache
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
regEntry, ok := s.GetAuthCacheEntry(authID)
if !ok {
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
}
@ -1530,25 +1531,27 @@ func (s *State) HandleNodeFromAuthPath(
// Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname(
regEntry.Node.Hostinfo,
regEntry.Node.MachineKey.String(),
regEntry.Node.NodeKey.String(),
regEntry.Node().Hostinfo(),
regEntry.Node().MachineKey().String(),
regEntry.Node().NodeKey().String(),
)
// Ensure we have valid hostinfo
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
hostinfo := &tailcfg.Hostinfo{}
if regEntry.Node().Hostinfo().Valid() {
hostinfo = regEntry.Node().Hostinfo().AsStruct()
}
hostinfo.Hostname = hostname
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
regEntry.Node(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
// Lookup existing nodes
machineKey := regEntry.Node.MachineKey
machineKey := regEntry.Node().MachineKey()
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
@ -1562,7 +1565,7 @@ func (s *State) HandleNodeFromAuthPath(
// Create logger with common fields for all auth operations
logger := log.With().
Str(zf.RegistrationID, registrationID.String()).
Str(zf.RegistrationID, authID.String()).
Str(zf.UserName, user.Name).
Str(zf.MachineKey, machineKey.ShortString()).
Str(zf.Method, registrationMethod).
@ -1571,7 +1574,7 @@ func (s *State) HandleNodeFromAuthPath(
// Common params for update operations
updateParams := authNodeUpdateParams{
RegEntry: regEntry,
ValidHostinfo: validHostinfo,
ValidHostinfo: hostinfo,
Hostname: hostname,
User: user,
Expiry: expiry,
@ -1605,7 +1608,7 @@ func (s *State) HandleNodeFromAuthPath(
Msg("Creating new node for different user (same machine key exists for another user)")
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, existingNodeAnyUser,
)
if err != nil {
@ -1613,7 +1616,7 @@ func (s *State) HandleNodeFromAuthPath(
}
} else {
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, types.NodeView{},
)
if err != nil {
@ -1622,10 +1625,10 @@ func (s *State) HandleNodeFromAuthPath(
}
// Signal to waiting clients
regEntry.SendAndClose(finalNode.AsStruct())
regEntry.FinishRegistration(finalNode)
// Delete from registration cache
s.registrationCache.Delete(registrationID)
s.authCache.Delete(authID)
// Update policy managers
usersChange, err := s.updatePolicyManagerUsers()
@ -1654,7 +1657,7 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) createNewNodeFromAuth(
logger zerolog.Logger,
user *types.User,
regEntry *types.RegisterNode,
regEntry *types.AuthRequest,
hostname string,
validHostinfo *tailcfg.Hostinfo,
expiry *time.Time,
@ -1667,13 +1670,13 @@ func (s *State) createNewNodeFromAuth(
return s.createAndSaveNewNode(newNodeParams{
User: *user,
MachineKey: regEntry.Node.MachineKey,
NodeKey: regEntry.Node.NodeKey,
DiscoKey: regEntry.Node.DiscoKey,
MachineKey: regEntry.Node().MachineKey(),
NodeKey: regEntry.Node().NodeKey(),
DiscoKey: regEntry.Node().DiscoKey(),
Hostname: hostname,
Hostinfo: validHostinfo,
Endpoints: regEntry.Node.Endpoints,
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
Endpoints: regEntry.Node().Endpoints().AsSlice(),
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
RegisterMethod: registrationMethod,
ExistingNodeForNetinfo: existingNodeForNetinfo,
})
@ -1759,7 +1762,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname(
regReq.Hostinfo,
regReq.Hostinfo.View(),
machineKey.String(),
regReq.NodeKey.String(),
)
@ -1768,14 +1771,6 @@ func (s *State) HandleNodeFromPreAuthKey(
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),
pakUsername(),
hostname,
regReq.Hostinfo,
)
log.Debug().
Caller().
Str(zf.NodeName, hostname).

View File

@ -7,7 +7,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
)
func RegisterWeb(registrationID types.RegistrationID) *elem.Element {
func RegisterWeb(registrationID types.AuthID) *elem.Element {
return HtmlStructure(
elem.Title(nil, elem.Text("Registration - Headscale")),
mdTypesetBody(

View File

@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) {
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
},
{
name: "Windows Config",
@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
},
{
name: "Windows Config",
@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
externalURLs: []string{}, // No external links
},
{
@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
},
{
name: "Windows Config",

View File

@ -22,8 +22,8 @@ const (
// Common errors.
var (
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
)
type StateUpdateType int
@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
const RegistrationIDLength = 24
const AuthIDLength = 24
type RegistrationID string
type AuthID string
func NewRegistrationID() (RegistrationID, error) {
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
func NewAuthID() (AuthID, error) {
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
if err != nil {
return "", err
}
return RegistrationID(rid), nil
return AuthID(rid), nil
}
func MustRegistrationID() RegistrationID {
rid, err := NewRegistrationID()
func MustAuthID() AuthID {
rid, err := NewAuthID()
if err != nil {
panic(err)
}
@ -181,43 +181,87 @@ func MustRegistrationID() RegistrationID {
return rid
}
func RegistrationIDFromString(str string) (RegistrationID, error) {
if len(str) != RegistrationIDLength {
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
func AuthIDFromString(str string) (AuthID, error) {
r := AuthID(str)
err := r.Validate()
if err != nil {
return "", err
}
return RegistrationID(str), nil
return r, nil
}
func (r RegistrationID) String() string {
func (r AuthID) String() string {
return string(r)
}
type RegisterNode struct {
Node Node
Registered chan *Node
closed *atomic.Bool
func (r AuthID) Validate() error {
if len(r) != AuthIDLength {
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
}
return nil
}
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
// AuthRequest represent a pending authentication request from a user or a node.
// If it is a registration request, the node field will be populate with the node that is trying to register.
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
type AuthRequest struct {
node *Node
finished chan NodeView
closed *atomic.Bool
}
func NewRegisterAuthRequest(node Node) AuthRequest {
return AuthRequest{
node: &node,
finished: make(chan NodeView),
closed: &atomic.Bool{},
}
}
func (rn *RegisterNode) SendAndClose(node *Node) {
// Node returns the node that is trying to register.
// It will panic if the AuthRequest is not a registration request.
// Can _only_ be used in the registration path.
func (rn *AuthRequest) Node() NodeView {
if rn.node == nil {
panic("Node can only be used in registration requests")
}
return rn.node.View()
}
func (rn *AuthRequest) FinishAuth() {
rn.FinishRegistration(NodeView{})
}
func (rn *AuthRequest) FinishRegistration(node NodeView) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
default:
if node.Valid() {
select {
case rn.finished <- node:
default:
}
}
close(rn.Registered)
close(rn.finished)
}
// WaitForRegistration waits for the authentication process to finish
// and returns the authenticated node.
// Can _only_ be used in the registration path.
func (rn *AuthRequest) WaitForRegistration() <-chan NodeView {
return rn.finished
}
// WaitForAuth waits until a authentication request has been finished.
func (rn *AuthRequest) WaitForAuth() {
<-rn.WaitForRegistration()
}
// DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -295,8 +295,8 @@ func IsCI() bool {
// 3. If normalisation fails → generate invalid-<random> replacement
//
// Returns the guaranteed-valid hostname to use.
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
key := cmp.Or(machineKey, nodeKey)
if key == "" {
return "unknown-node"
@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
return "node-" + keyPrefix
}
lowercased := strings.ToLower(hostinfo.Hostname)
lowercased := strings.ToLower(hostinfo.Hostname())
err := ValidateHostname(lowercased)
if err == nil {

View File

@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") {
if !strings.HasPrefix(got, "invalid-") {
@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") {
if !strings.HasPrefix(gotHostname, "invalid-") {
@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo, "mkey", "nkey")
result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
}
@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
OS: "linux",
}
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)

View File

@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
assert.Equal(t, "node-5", listAll[4].GetName())
otherUserRegIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))