From 2dfba3d2ccfea1e65ed7ceecfd289b6cf29d2ed7 Mon Sep 17 00:00:00 2001 From: Andrey Bobelev Date: Fri, 29 Aug 2025 14:20:07 +0200 Subject: [PATCH 1/2] chore: make reg cache expiry tunable Mostly for the tests, opts: - tuning.register_cache_expiration - tuning.register_cache_cleanup --- hscontrol/state/state.go | 14 ++++++++++++-- hscontrol/types/config.go | 4 ++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 15597706..b4baf7b5 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -74,9 +74,19 @@ type State struct { // NewState creates and initializes a new State instance, setting up the database, // IP allocator, DERP map, policy manager, and loading existing users and nodes. func NewState(cfg *types.Config) (*State, error) { + cacheExpiration := registerCacheExpiration + if cfg.Tuning.RegisterCacheExpiration != 0 { + cacheExpiration = cfg.Tuning.RegisterCacheExpiration + } + + cacheCleanup := registerCacheCleanup + if cfg.Tuning.RegisterCacheCleanup != 0 { + cacheCleanup = cfg.Tuning.RegisterCacheCleanup + } + registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( - registerCacheExpiration, - registerCacheCleanup, + cacheExpiration, + cacheCleanup, ) db, err := hsdb.NewHeadscaleDatabase( diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4a0a366e..d4a7d662 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -235,6 +235,8 @@ type Tuning struct { BatchChangeDelay time.Duration NodeMapSessionBufferedChanSize int BatcherWorkers int + RegisterCacheCleanup time.Duration + RegisterCacheExpiration time.Duration } func validatePKCEMethod(method string) error { @@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) { } return DefaultBatcherWorkers() }(), + RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), + RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"), }, }, nil } From 80fe04971dec4fb433c75eff7801e4dbbd96e7f7 Mon Sep 17 00:00:00 2001 From: Andrey Bobelev Date: Fri, 29 Aug 2025 15:55:42 +0200 Subject: [PATCH 2/2] fix: return valid AuthUrl in followup request on expired reg id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tailscale client gets a new AuthUrl and sets entry in the regcache - regcache entry expires - client doesn't know about that - client always polls followup request а gets error When user clicks "Login" in the app (after cache expiry), they visit invalid URL and get "node not found in registration cache". Some clients on Windows for e.g. can't get a new AuthUrl without restart the app. To fix that we can issue a new reg id and return user a new valid AuthUrl. RegisterNode is refactored to be created with NewRegisterNode() to autocreate channel and other stuff. --- .github/workflows/test-integration.yaml | 1 + hscontrol/auth.go | 47 +++++++++-- hscontrol/grpcv1.go | 7 +- hscontrol/oidc.go | 6 ++ hscontrol/state/state.go | 16 ++-- hscontrol/types/common.go | 23 ++++++ integration/auth_oidc_test.go | 103 ++++++++++++++++++++++++ 7 files changed, 185 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index a16f0aab..459fc664 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -31,6 +31,7 @@ jobs: - TestOIDC024UserCreation - TestOIDCAuthenticationWithPKCE - TestOIDCReloginSameNodeNewUser + - TestOIDCFollowUpUrl - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 81032640..6d50ab7f 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -40,7 +40,7 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - return h.waitForFollowup(ctx, regReq) + return h.waitForFollowup(ctx, regReq, machineKey) } if regReq.Auth != nil && regReq.Auth.AuthKey != "" { @@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func (h *Headscale) waitForFollowup( ctx context.Context, regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { fu, err := url.Parse(regReq.Followup) if err != nil { @@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) case node := <-reg.Registered: if node == nil { - return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) } return nodeToRegisterResponse(node), nil } } - return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) + // if the follow-up registration isn't found anymore, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) +} + +func (h *Headscale) reqToNewRegisterResponse( + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + newRegID, err := types.NewRegistrationID() + if err != nil { + return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) + } + + nodeToRegister := types.NewRegisterNode( + types.Node{ + Hostname: regReq.Hostinfo.Hostname, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + }, + ) + + if !regReq.Expiry.IsZero() { + nodeToRegister.Node.Expiry = ®Req.Expiry + } + + log.Info().Msgf("New followup node registration using key: %s", newRegID) + h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(newRegID), + }, nil } func (h *Headscale) handleRegisterWithAuthKey( @@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive( return nil, fmt.Errorf("generating registration ID: %w", err) } - nodeToRegister := types.RegisterNode{ - Node: types.Node{ + nodeToRegister := types.NewRegisterNode( + types.Node{ Hostname: regReq.Hostinfo.Hostname, MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), }, - Registered: make(chan *types.Node), - } + ) if !regReq.Expiry.IsZero() { nodeToRegister.Node.Expiry = ®Req.Expiry diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 6663b44a..01d3c6b3 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode( return nil, err } - newNode := types.RegisterNode{ - Node: types.Node{ + newNode := types.NewRegisterNode( + types.Node{ NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), @@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostinfo: &hostinfo, }, - Registered: make(chan *types.Node), - } + ) log.Debug(). Caller(). diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 55f917d7..84d00712 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( verb := "Reauthenticated" newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { + if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { + log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) + + return + } httpError(writer, err) return } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b4baf7b5..ad7770ff 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -89,6 +89,12 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup, ) + registrationCache.OnEvicted( + func(id types.RegistrationID, rn types.RegisterNode) { + rn.SendAndClose(nil) + }, + ) + db, err := hsdb.NewHeadscaleDatabase( cfg.Database, cfg.BaseDomain, @@ -1248,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath( s.nodeStore.PutNode(*savedNode) } + // Signal to waiting clients + regEntry.SendAndClose(savedNode) + // Delete from registration cache s.registrationCache.Delete(registrationID) - // Signal to waiting clients - select { - case regEntry.Registered <- savedNode: - default: - } - close(regEntry.Registered) - // Update policy manager nodesChange, err := s.updatePolicyManagerNodes() if err != nil { diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index a80f2ab4..a7d815bf 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "runtime" + "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/util" @@ -186,6 +187,28 @@ func (r RegistrationID) String() string { type RegisterNode struct { Node Node Registered chan *Node + closed *atomic.Bool +} + +func NewRegisterNode(node Node) RegisterNode { + return RegisterNode{ + Node: node, + Registered: make(chan *Node), + closed: &atomic.Bool{}, + } +} + +func (rn *RegisterNode) SendAndClose(node *Node) { + if rn.closed.Swap(true) { + return + } + + select { + case rn.Registered <- node: + default: + } + + close(rn.Registered) } // DefaultBatcherWorkers returns the default number of batcher workers. diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 751a8d11..dedb9e91 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -3,6 +3,7 @@ package integration import ( "maps" "net/netip" + "net/url" "sort" "testing" "time" @@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") } +func TestOIDCFollowUpUrl(t *testing.T) { + IntegrationSkip(t) + + // Create no nodes and no users + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // smaller cache expiration time to quickly expire AuthURL + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s", + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + assertNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + assertNoErr(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + // wait for the registration cache to expire + // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION + time.Sleep(2 * time.Minute) + + st, err := ts.Status() + assertNoErr(t, err) + assert.Equal(t, "NeedsLogin", st.BackendState) + + // get new AuthURL from daemon + newUrl, err := url.Parse(st.AuthURL) + assertNoErr(t, err) + + assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") + + _, err = doLoginURL(ts.Hostname(), newUrl) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodes, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodes, 1) +} + // assertTailscaleNodesLogout verifies that all provided Tailscale clients // are in the logged-out state (NeedsLogin). func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {