diff --git a/hscontrol/app.go b/hscontrol/app.go index 4affb6e0..87b37510 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { r.Get("/health", h.HealthHandler) r.Get("/version", h.VersionHandler) r.Get("/key", h.KeyHandler) - r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) + r.Get("/register/{auth_id}", h.authProvider.RegisterHandler) + r.Get("/auth/{auth_id}", h.authProvider.AuthHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { r.Get("/oidc/callback", provider.OIDCCallbackHandler) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..fd1b231b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -20,7 +20,9 @@ import ( type AuthProvider interface { RegisterHandler(w http.ResponseWriter, r *http.Request) - AuthURL(regID types.RegistrationID) string + AuthHandler(w http.ResponseWriter, r *http.Request) + RegisterURL(authID types.AuthID) string + AuthURL(authID types.AuthID) string } func (h *Headscale) handleRegister( @@ -261,22 +263,22 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) } - followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) if err != nil { return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) } - if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { + if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok { select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.Registered: - if node == nil { + case node := <-reg.WaitForRegistration(): + if !node.Valid() { // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } - return nodeToRegisterResponse(node.View()), nil + return nodeToRegisterResponse(node), nil } } @@ -291,14 +293,14 @@ func (h *Headscale) reqToNewRegisterResponse( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - newRegID, err := types.NewRegistrationID() + newAuthID, err := types.NewAuthID() if err != nil { return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -307,25 +309,25 @@ func (h *Headscale) reqToNewRegisterResponse( hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - log.Info().Msgf("new followup node registration using key: %s", newRegID) - h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + log.Info().Msgf("new followup node registration using key: %s", newAuthID) + h.state.SetAuthCacheEntry(newAuthID, authRegReq) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(newRegID), + AuthURL: h.authProvider.RegisterURL(newAuthID), }, nil } @@ -376,13 +378,6 @@ func (h *Headscale) handleRegisterWithAuthKey( // Send both changes. Empty changes are ignored by Change(). h.Change(changed, routesChange) - // TODO(kradalby): I think this is covered above, but we need to validate that. - // // If policy changed due to node registration, send a separate policy change - // if policyChanged { - // policyChange := change.PolicyChange() - // h.Change(policyChange) - // } - resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), @@ -404,14 +399,14 @@ func (h *Headscale) handleRegisterInteractive( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - registrationId, err := types.NewRegistrationID() + authID, err := types.NewAuthID() if err != nil { return nil, fmt.Errorf("generating registration ID: %w", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -434,28 +429,28 @@ func (h *Headscale) handleRegisterInteractive( hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - h.state.SetRegistrationCacheEntry( - registrationId, - nodeToRegister, + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + h.state.SetAuthCacheEntry( + authID, + authRegReq, ) - log.Info().Msgf("starting node registration using key: %s", registrationId) + log.Info().Msgf("starting node registration using key: %s", authID) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(registrationId), + AuthURL: h.authProvider.RegisterURL(authID), }, nil } diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go index e7b74b75..7016af31 100644 --- a/hscontrol/auth_tags_test.go +++ b/hscontrol/auth_tags_test.go @@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 1: Create user-owned node WITH expiry set clientExpiry := time.Now().Add(24 * time.Hour) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "personal-to-tagged", @@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 2: Re-auth with tags (Personal → Tagged conversion) nodeKey2 := key.NewNode() - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "personal-to-tagged", @@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client still sends expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", @@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Create tagged node (expiry should be nil) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "tagged-to-personal", @@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { RequestTags: []string{"tag:server"}, // Tagged node }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { // Step 2: Re-auth with empty tags (Tagged → Personal conversion) nodeKey2 := key.NewNode() clientExpiry := time.Now().Add(48 * time.Hour) - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "tagged-to-personal", @@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client requests expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..8215b07c 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_success", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-success-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-success-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate successful registration - send to buffered channel - // The channel is buffered (size 1), so this can complete immediately - // and handleRegister will receive the value when it starts waiting + // Simulate successful registration + // handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - registered <- node + nodeToRegister.FinishRegistration(node.View()) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_timeout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-timeout-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) - // Don't send anything on channel - will timeout + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-timeout-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) + // Don't call FinishRegistration - will timeout return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil }, @@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_node_nil_response", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "nil-response-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "nil-response-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate registration that returns nil (cache expired during auth) - // The channel is buffered (size 1), so this can complete immediately + // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - registered <- nil // Nil indicates cache expiry + nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Generate a registration ID that doesn't exist in cache // This simulates an expired/missing cache entry - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } @@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) { // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") - newRegID, err := types.RegistrationIDFromString(newRegIDStr) + newRegID, err := types.AuthIDFromString(newRegIDStr) assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure // Verify new registration entry exists in cache - _, found := app.state.GetRegistrationCacheEntry(newRegID) + _, found := app.state.GetAuthCacheEntry(newRegID) assert.True(t, found, "new registration should exist in cache") }, }, @@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify cache entry exists - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) assert.True(t, found, "registration cache entry should exist initially") assert.NotNil(t, cacheEntry) @@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern // Cache entry should still exist after auth error (for retry scenarios) - _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) + _, stillFound := app.state.GetAuthCacheEntry(registrationID) assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") }, }, @@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) { assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") // Both cache entries should exist simultaneously - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify both exist - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) { } // First registration should still be in cache (not completed) - _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + _, stillFound := app.state.GetAuthCacheEntry(regID1) assert.True(t, stillFound, "first registration should still be pending") }, }, @@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { var ( initialResp *tailcfg.RegisterResponse authURL string - registrationID types.RegistrationID + registrationID types.AuthID finalResp *tailcfg.RegisterResponse err error ) @@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if step.expectCacheEntry { // Verify registration cache entry was created - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) require.True(t, found, "registration cache entry should exist") require.NotNil(t, cacheEntry, "cache entry should not be nil") - require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key") } case stepTypeAuthCompletion: @@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { // Check cache cleanup expectation for this step if step.expectCacheEntry == false && registrationID != "" { // Verify cache entry was cleaned up - _, found := app.state.GetRegistrationCacheEntry(registrationID) + _, found := app.state.GetAuthCacheEntry(registrationID) require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) } } @@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { } // extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. -func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { +func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" @@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err idStr := authURL[idx+len(registerPrefix):] - return types.RegistrationIDFromString(idStr) + return types.AuthIDFromString(idStr) } // validateCompleteRegistrationResponse performs comprehensive validation of a registration response. @@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { nodeKey := key.NewNode() // Simulate a registration cache entry (as would be created during web auth) - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "webauth-tags-node", @@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete the web auth - should fail because tag is unauthorized _, _, err := app.state.HandleNodeFromAuthPath( @@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Initial registration with tags - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "reauth-untag-node", @@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{"tag:valid-owned", "tag:second"}, }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) // Complete initial registration with tags node, _, err := app.state.HandleNodeFromAuthPath( @@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 2: Reauth with EMPTY tags to untag nodeKey2 := key.NewNode() // New node key for reauth - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "reauth-untag-node", @@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned nodeKey2 := key.NewNode() // New node key for reauth - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "authkey-tagged-node", @@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { // Step 4: Re-register the node to alice via HandleNodeFromAuthPath // This is what happens when running: headscale nodes register --user alice --key ... nodeKey2 := key.NewNode() - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key as the tagged node NodeKey: nodeKey2.Public(), Hostname: "tagged-orphan-node", @@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { RequestTags: []string{}, // Empty - transition to user-owned }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // This should NOT panic - before the fix, this would panic with: // panic: runtime error: invalid memory address or nil pointer dereference diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 6841f446..69f71e36 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -47,7 +47,7 @@ const ( type HSDatabase struct { DB *gorm.DB cfg *types.Config - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + regCache *zcache.Cache[types.AuthID, types.AuthRequest] } // NewHeadscaleDatabase creates a new database connection and runs migrations. @@ -56,7 +56,7 @@ type HSDatabase struct { //nolint:gocyclo // complex database initialization with many migrations func NewHeadscaleDatabase( cfg *types.Config, - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], + regCache *zcache.Cache[types.AuthID, types.AuthRequest], ) (*HSDatabase, error) { dbConn, err := openDB(cfg.Database) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3c687b39..151d9966 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } } -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 073c6677..d7c192a6 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode( Str(zf.RegistrationKey, registrationKey). Msg("registering node") - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: request.GetName(), } - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } - newNode := types.NewRegisterNode( - types.Node{ - NodeKey: key.NewNode().Public(), - MachineKey: key.NewMachine().Public(), - Hostname: request.GetName(), - User: user, + newNode := types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Hostinfo: &hostinfo, - }, - ) + Hostinfo: &hostinfo, + } log.Debug(). Caller(). Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") - api.h.state.SetRegistrationCacheEntry(registrationId, newNode) + authRegReq := types.NewRegisterAuthRequest(newNode) + api.h.state.SetAuthCacheEntry(registrationId, authRegReq) - return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil } func (api headscaleV1APIServer) Health( diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7c45f1ec..b7aa8460 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/assets" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { +func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationId.String()) + authID.String()) +} + +func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderWeb) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { +} + +func authIDFromRequest(req *http.Request) (types.AuthID, error) { + registrationId, err := urlParam[types.AuthID](req, "auth_id") + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + err = registrationId.Validate() + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + return registrationId, nil } // RegisterHandler shows a simple message in the browser to point to the CLI @@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] - - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) + registrationId, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 9e544633..6f3fbccb 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{ } // emptyCache creates an empty registration cache for testing. -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } // Test configuration constants. diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 9d284921..2bc62fa9 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -12,7 +12,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -26,8 +25,8 @@ import ( const ( randomByteSize = 16 defaultOAuthOptionsCount = 3 - registerCacheExpiration = time.Minute * 15 - registerCacheCleanup = time.Minute * 20 + authCacheExpiration = time.Minute * 15 + authCacheCleanup = time.Minute * 20 ) var ( @@ -44,17 +43,21 @@ var ( errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") ) -// RegistrationInfo contains both machine key and verifier information for OIDC validation. -type RegistrationInfo struct { - RegistrationID types.RegistrationID - Verifier *string +// AuthInfo contains both auth ID and verifier information for OIDC validation. +type AuthInfo struct { + AuthID types.AuthID + Verifier *string + Registration bool } type AuthProviderOIDC struct { - h *Headscale - serverURL string - cfg *types.OIDCConfig - registrationCache *zcache.Cache[string, RegistrationInfo] + h *Headscale + serverURL string + cfg *types.OIDCConfig + + // authCache holds auth information between + // the auth and the callback steps. + authCache *zcache.Cache[string, AuthInfo] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -81,45 +84,63 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - registrationCache := zcache.New[string, RegistrationInfo]( - registerCacheExpiration, - registerCacheCleanup, + authCache := zcache.New[string, AuthInfo]( + authCacheExpiration, + authCacheCleanup, ) return &AuthProviderOIDC{ - h: h, - serverURL: serverURL, - cfg: cfg, - registrationCache: registrationCache, + h: h, + serverURL: serverURL, + cfg: cfg, + authCache: authCache, oidcProvider: oidcProvider, oauth2Config: oauth2Config, }, nil } -func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { +func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderOIDC) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + a.authHandler(writer, req, false) +} + +func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationID.String()) + authID.String()) } // RegisterHandler registers the OIDC callback handler with the given router. // It puts NodeKey in cache so the callback can retrieve it using the oidc state param. -// Listens in /register/:registration_id. +// Listens in /register/:auth_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] + a.authHandler(writer, req, true) +} - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) +// authHandler takes an incoming request that needs to be authenticated and +// validates and prepares it for the OIDC flow. +func (a *AuthProviderOIDC) authHandler( + writer http.ResponseWriter, + req *http.Request, + registration bool, +) { + authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } @@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler( return } - // Initialize registration info with machine key - registrationInfo := RegistrationInfo{ - RegistrationID: registrationId, + registrationInfo := AuthInfo{ + AuthID: authID, + Registration: registration, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler( extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.registrationCache.Set(state, registrationInfo) + a.authCache.Set(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) @@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - registrationId := a.getRegistrationIDFromState(state) + authInfo := a.getAuthInfoFromState(state) + if authInfo == nil { + log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) - // Register the node if it does not exist. - if registrationId != nil { + return + } + + // If this is a registration flow, then we need to register the node. + if authInfo.Registration { verb := "Reauthenticated" - newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) + newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed") httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) return @@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Neither node nor machine key was found in the state cache meaning - // that we could not reauth nor register the node. - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + // TODO(kradalby): handle login flow (without registration) if needed. + // We need to send an update here to whatever might be waiting for this auth flow. } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token( var exchangeOpts []oauth2.AuthCodeOption if a.cfg.PKCE.Enabled { - regInfo, ok := a.registrationCache.Get(state) + regInfo, ok := a.authCache.Get(state) if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } @@ -507,14 +533,14 @@ func doOIDCAuthorization( return nil } -// getRegistrationIDFromState retrieves the registration ID from the state. -func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { - regInfo, ok := a.registrationCache.Get(state) +// getAuthInfoFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo { + authInfo, ok := a.authCache.Get(state) if !ok { return nil } - return ®Info.RegistrationID + return &authInfo } func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( @@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) handleRegistration( user *types.User, - registrationID types.RegistrationID, + registrationID types.AuthID, expiry time.Time, ) (bool, error) { node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index f546f7a4..eb927750 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -82,8 +82,10 @@ type State struct { derpMap atomic.Pointer[tailcfg.DERPMap] // polMan handles policy evaluation and management polMan policy.PolicyManager - // registrationCache caches node registration data to reduce database load - registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + + // authCache caches any pending authentication requests, from either auth type (Web and OIDC). + authCache *zcache.Cache[types.AuthID, types.AuthRequest] + // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes } @@ -101,20 +103,20 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup = cfg.Tuning.RegisterCacheCleanup } - registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( + authCache := zcache.New[types.AuthID, types.AuthRequest]( cacheExpiration, cacheCleanup, ) - registrationCache.OnEvicted( - func(id types.RegistrationID, rn types.RegisterNode) { - rn.SendAndClose(nil) + authCache.OnEvicted( + func(id types.AuthID, rn types.AuthRequest) { + rn.FinishRegistration(types.NodeView{}) }, ) db, err := hsdb.NewHeadscaleDatabase( cfg, - registrationCache, + authCache, ) if err != nil { return nil, fmt.Errorf("initializing database: %w", err) @@ -178,12 +180,12 @@ func NewState(cfg *types.Config) (*State, error) { return &State{ cfg: cfg, - db: db, - ipAlloc: ipAlloc, - polMan: polMan, - registrationCache: registrationCache, - primaryRoutes: routes.New(), - nodeStore: nodeStore, + db: db, + ipAlloc: ipAlloc, + polMan: polMan, + authCache: authCache, + primaryRoutes: routes.New(), + nodeStore: nodeStore, }, nil } @@ -1042,9 +1044,9 @@ func (s *State) DeletePreAuthKey(id uint64) error { return s.db.DeletePreAuthKey(id) } -// GetRegistrationCacheEntry retrieves a node registration from cache. -func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { - entry, found := s.registrationCache.Get(id) +// GetAuthCacheEntry retrieves a node registration from cache. +func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) { + entry, found := s.authCache.Get(id) if !found { return nil, false } @@ -1052,26 +1054,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis return &entry, true } -// SetRegistrationCacheEntry stores a node registration in cache. -func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { - s.registrationCache.Set(id, entry) +// SetAuthCacheEntry stores a node registration in cache. +func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { + s.authCache.Set(id, entry) } // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. -func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { - if hostinfo == nil { +func logHostinfoValidation(nv types.NodeView, username, hostname string) { + if !nv.Hostinfo().Valid() { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had nil hostinfo, generated default hostname") - } else if hostinfo.Hostname == "" { + } else if nv.Hostinfo().Hostname() == "" { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had empty hostname, generated default") @@ -1113,7 +1113,7 @@ type authNodeUpdateParams struct { // Node to update; must be valid and in NodeStore. ExistingNode types.NodeView // Client data: keys, hostinfo, endpoints. - RegEntry *types.RegisterNode + RegEntry *types.AuthRequest // Pre-validated hostinfo; NetInfo preserved from ExistingNode. ValidHostinfo *tailcfg.Hostinfo // Hostname from hostinfo, or generated from keys if client omits it. @@ -1132,6 +1132,7 @@ type authNodeUpdateParams struct { // an existing node. It updates the node in NodeStore, processes RequestTags, and // persists changes to the database. func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { + regNv := params.RegEntry.Node() // Log the operation type if params.IsConvertFromTag { log.Info(). @@ -1140,16 +1141,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView Msg("Converting tagged node to user-owned node") } else { log.Info(). - EmbedObject(params.ExistingNode). - Interface("hostinfo", params.RegEntry.Node.Hostinfo). + Object("existing", params.ExistingNode). + Object("incoming", regNv). Msg("Updating existing node registration via reauth") } // Process RequestTags during reauth (#2979) // Due to json:",omitempty", we treat empty/nil as "clear tags" var requestTags []string - if params.RegEntry.Node.Hostinfo != nil { - requestTags = params.RegEntry.Node.Hostinfo.RequestTags + if regNv.Hostinfo().Valid() { + requestTags = regNv.Hostinfo().RequestTags().AsSlice() } oldTags := params.ExistingNode.Tags().AsSlice() @@ -1167,8 +1168,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // Update existing node in NodeStore - validation passed, safe to mutate updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { - node.NodeKey = params.RegEntry.Node.NodeKey - node.DiscoKey = params.RegEntry.Node.DiscoKey + node.NodeKey = regNv.NodeKey() + node.DiscoKey = regNv.DiscoKey() node.Hostname = params.Hostname // Preserve NetInfo from existing node when re-registering @@ -1179,7 +1180,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView params.ValidHostinfo, ) - node.Endpoints = params.RegEntry.Node.Endpoints + node.Endpoints = regNv.Endpoints().AsSlice() node.IsOnline = new(false) node.LastSeen = new(time.Now()) @@ -1188,7 +1189,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.IsConvertFromTag { node.RegisterMethod = params.RegisterMethod } else { - node.RegisterMethod = params.RegEntry.Node.RegisterMethod + node.RegisterMethod = regNv.RegisterMethod() } // Track tagged status BEFORE processing tags @@ -1208,7 +1209,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !wasTagged && isTagged: // Personal → Tagged: clear expiry (tagged nodes don't expire) @@ -1218,14 +1219,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !isTagged: // Personal → Personal: update expiry from client if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } } // Tagged → Tagged: keep existing expiry (nil) - no action needed @@ -1511,13 +1512,13 @@ func (s *State) processReauthTags( // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). func (s *State) HandleNodeFromAuthPath( - registrationID types.RegistrationID, + authID types.AuthID, userID types.UserID, expiry *time.Time, registrationMethod string, ) (types.NodeView, change.Change, error) { // Get the registration entry from cache - regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + regEntry, ok := s.GetAuthCacheEntry(authID) if !ok { return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache } @@ -1530,25 +1531,27 @@ func (s *State) HandleNodeFromAuthPath( // Ensure we have a valid hostname from the registration cache entry hostname := util.EnsureHostname( - regEntry.Node.Hostinfo, - regEntry.Node.MachineKey.String(), - regEntry.Node.NodeKey.String(), + regEntry.Node().Hostinfo(), + regEntry.Node().MachineKey().String(), + regEntry.Node().NodeKey().String(), ) // Ensure we have valid hostinfo - validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) - validHostinfo.Hostname = hostname + hostinfo := &tailcfg.Hostinfo{} + if regEntry.Node().Hostinfo().Valid() { + hostinfo = regEntry.Node().Hostinfo().AsStruct() + } + + hostinfo.Hostname = hostname logHostinfoValidation( - regEntry.Node.MachineKey.ShortString(), - regEntry.Node.NodeKey.String(), + regEntry.Node(), user.Name, hostname, - regEntry.Node.Hostinfo, ) // Lookup existing nodes - machineKey := regEntry.Node.MachineKey + machineKey := regEntry.Node().MachineKey() existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) @@ -1562,7 +1565,7 @@ func (s *State) HandleNodeFromAuthPath( // Create logger with common fields for all auth operations logger := log.With(). - Str(zf.RegistrationID, registrationID.String()). + Str(zf.RegistrationID, authID.String()). Str(zf.UserName, user.Name). Str(zf.MachineKey, machineKey.ShortString()). Str(zf.Method, registrationMethod). @@ -1571,7 +1574,7 @@ func (s *State) HandleNodeFromAuthPath( // Common params for update operations updateParams := authNodeUpdateParams{ RegEntry: regEntry, - ValidHostinfo: validHostinfo, + ValidHostinfo: hostinfo, Hostname: hostname, User: user, Expiry: expiry, @@ -1605,7 +1608,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Creating new node for different user (same machine key exists for another user)") finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, existingNodeAnyUser, ) if err != nil { @@ -1613,7 +1616,7 @@ func (s *State) HandleNodeFromAuthPath( } } else { finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, types.NodeView{}, ) if err != nil { @@ -1622,10 +1625,10 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.SendAndClose(finalNode.AsStruct()) + regEntry.FinishRegistration(finalNode) // Delete from registration cache - s.registrationCache.Delete(registrationID) + s.authCache.Delete(authID) // Update policy managers usersChange, err := s.updatePolicyManagerUsers() @@ -1654,7 +1657,7 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) createNewNodeFromAuth( logger zerolog.Logger, user *types.User, - regEntry *types.RegisterNode, + regEntry *types.AuthRequest, hostname string, validHostinfo *tailcfg.Hostinfo, expiry *time.Time, @@ -1667,13 +1670,13 @@ func (s *State) createNewNodeFromAuth( return s.createAndSaveNewNode(newNodeParams{ User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, + MachineKey: regEntry.Node().MachineKey(), + NodeKey: regEntry.Node().NodeKey(), + DiscoKey: regEntry.Node().DiscoKey(), Hostname: hostname, Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + Endpoints: regEntry.Node().Endpoints().AsSlice(), + Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()), RegisterMethod: registrationMethod, ExistingNodeForNetinfo: existingNodeForNetinfo, }) @@ -1759,7 +1762,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Ensure we have a valid hostname - handle nil/empty cases hostname := util.EnsureHostname( - regReq.Hostinfo, + regReq.Hostinfo.View(), machineKey.String(), regReq.NodeKey.String(), ) @@ -1768,14 +1771,6 @@ func (s *State) HandleNodeFromPreAuthKey( validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) validHostinfo.Hostname = hostname - logHostinfoValidation( - machineKey.ShortString(), - regReq.NodeKey.ShortString(), - pakUsername(), - hostname, - regReq.Hostinfo, - ) - log.Debug(). Caller(). Str(zf.NodeName, hostname). diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go index 829af7fb..cdede03b 100644 --- a/hscontrol/templates/register_web.go +++ b/hscontrol/templates/register_web.go @@ -7,7 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" ) -func RegisterWeb(registrationID types.RegistrationID) *elem.Element { +func RegisterWeb(registrationID types.AuthID) *elem.Element { return HtmlStructure( elem.Title(nil, elem.Text("Registration - Headscale")), mdTypesetBody( diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go index 369639cc..0464fb88 100644 --- a/hscontrol/templates_consistency_test.go +++ b/hscontrol/templates_consistency_test.go @@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), externalURLs: []string{}, // No external links }, { @@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d852753e..66bbf619 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -22,8 +22,8 @@ const ( // Common errors. var ( - ErrCannotParsePrefix = errors.New("cannot parse prefix") - ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidAuthIDLength = errors.New("registration ID has invalid length") ) type StateUpdateType int @@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -const RegistrationIDLength = 24 +const AuthIDLength = 24 -type RegistrationID string +type AuthID string -func NewRegistrationID() (RegistrationID, error) { - rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) +func NewAuthID() (AuthID, error) { + rid, err := util.GenerateRandomStringURLSafe(AuthIDLength) if err != nil { return "", err } - return RegistrationID(rid), nil + return AuthID(rid), nil } -func MustRegistrationID() RegistrationID { - rid, err := NewRegistrationID() +func MustAuthID() AuthID { + rid, err := NewAuthID() if err != nil { panic(err) } @@ -181,43 +181,87 @@ func MustRegistrationID() RegistrationID { return rid } -func RegistrationIDFromString(str string) (RegistrationID, error) { - if len(str) != RegistrationIDLength { - return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str)) +func AuthIDFromString(str string) (AuthID, error) { + r := AuthID(str) + + err := r.Validate() + if err != nil { + return "", err } - return RegistrationID(str), nil + return r, nil } -func (r RegistrationID) String() string { +func (r AuthID) String() string { return string(r) } -type RegisterNode struct { - Node Node - Registered chan *Node - closed *atomic.Bool +func (r AuthID) Validate() error { + if len(r) != AuthIDLength { + return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r)) + } + + return nil } -func NewRegisterNode(node Node) RegisterNode { - return RegisterNode{ - Node: node, - Registered: make(chan *Node), - closed: &atomic.Bool{}, +// AuthRequest represent a pending authentication request from a user or a node. +// If it is a registration request, the node field will be populate with the node that is trying to register. +// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel. +// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. +type AuthRequest struct { + node *Node + finished chan NodeView + closed *atomic.Bool +} + +func NewRegisterAuthRequest(node Node) AuthRequest { + return AuthRequest{ + node: &node, + finished: make(chan NodeView), + closed: &atomic.Bool{}, } } -func (rn *RegisterNode) SendAndClose(node *Node) { +// Node returns the node that is trying to register. +// It will panic if the AuthRequest is not a registration request. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) Node() NodeView { + if rn.node == nil { + panic("Node can only be used in registration requests") + } + + return rn.node.View() +} + +func (rn *AuthRequest) FinishAuth() { + rn.FinishRegistration(NodeView{}) +} + +func (rn *AuthRequest) FinishRegistration(node NodeView) { if rn.closed.Swap(true) { return } - select { - case rn.Registered <- node: - default: + if node.Valid() { + select { + case rn.finished <- node: + default: + } } - close(rn.Registered) + close(rn.finished) +} + +// WaitForRegistration waits for the authentication process to finish +// and returns the authenticated node. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) WaitForRegistration() <-chan NodeView { + return rn.finished +} + +// WaitForAuth waits until a authentication request has been finished. +func (rn *AuthRequest) WaitForAuth() { + <-rn.WaitForRegistration() } // DefaultBatcherWorkers returns the default number of batcher workers. diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index cbce663b..034779b5 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -295,8 +295,8 @@ func IsCI() bool { // 3. If normalisation fails → generate invalid- replacement // // Returns the guaranteed-valid hostname to use. -func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { - if hostinfo == nil || hostinfo.Hostname == "" { +func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string { + if !hostinfo.Valid() || hostinfo.Hostname() == "" { key := cmp.Or(machineKey, nodeKey) if key == "" { return "unknown-node" @@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri return "node-" + keyPrefix } - lowercased := strings.ToLower(hostinfo.Hostname) + lowercased := strings.ToLower(hostinfo.Hostname()) err := ValidateHostname(lowercased) if err == nil { diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 5cca4990..6e7a0630 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { if !strings.HasPrefix(got, "invalid-") { @@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { if !strings.HasPrefix(gotHostname, "invalid-") { @@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} - result := EnsureHostname(hostinfo, "mkey", "nkey") + result := EnsureHostname(hostinfo.View(), "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) } @@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) { OS: "linux", } - hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") - hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") + hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") + hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") if hostname1 != hostname2 { t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) diff --git a/integration/cli_test.go b/integration/cli_test.go index a1174277..c46361d4 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-5", listAll[4].GetName()) otherUserRegIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) @@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs))