From 75217c4579a6915b5ad7ada1b47c8160b2f3d561 Mon Sep 17 00:00:00 2001 From: Andrey Bobelev Date: Fri, 29 Aug 2025 15:55:42 +0200 Subject: [PATCH] fix: return valid AuthUrl in followup request on expired reg id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tailscale client gets a new AuthUrl and sets entry in the regcache - regcache entry expires - client doesn't know about that - client always polls followup request а gets error When user clicks "Login" in the app (after cache expiry), they visit invalid URL and get "node not found in registration cache". Some clients on Windows for e.g. can't get a new AuthUrl without restart the app. To fix that we can issue a new reg id and return user a new valid AuthUrl. --- hscontrol/auth.go | 43 ++++++++++++-- hscontrol/state/state.go | 10 ++++ integration/auth_oidc_test.go | 105 ++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 5 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index cb284173..928ed1f8 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -59,7 +59,7 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - return h.waitForFollowup(ctx, regReq) + return h.waitForFollowup(ctx, regReq, machineKey) } if regReq.Auth != nil && regReq.Auth.AuthKey != "" { @@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode( } h.Change(c) - } + } - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(node), nil } func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { @@ -147,6 +147,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func (h *Headscale) waitForFollowup( ctx context.Context, regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { fu, err := url.Parse(regReq.Followup) if err != nil { @@ -164,13 +165,45 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) case node := <-reg.Registered: if node == nil { - return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) } return nodeToRegisterResponse(node), nil } } - return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) + // if the follow-up registration isn't found anymore, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) +} + +func (h *Headscale) reqToNewRegisterResponse( + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + newRegID, err := types.NewRegistrationID() + if err != nil { + return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) + } + + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: regReq.Hostinfo.Hostname, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + }, + Registered: make(chan *types.Node), + } + + if !regReq.Expiry.IsZero() { + nodeToRegister.Node.Expiry = ®Req.Expiry + } + h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(newRegID), + }, nil } func (h *Headscale) handleRegisterWithAuthKey( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 4b1b4f0d..0eb91dd3 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -82,6 +82,16 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup, ) + registrationCache.OnEvicted( + func(id types.RegistrationID, node types.RegisterNode) { + select { + case node.Registered <- nil: + // notify the followup handler that registration is not valid anymore + default: + } + }, + ) + db, err := hsdb.NewHeadscaleDatabase( cfg.Database, cfg.BaseDomain, diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index d118b643..5204f046 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -15,6 +15,7 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "net/url" ) func TestOIDCAuthenticationPingAll(t *testing.T) { @@ -616,6 +617,110 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) } +func TestOIDCFollowUpUrl(t *testing.T) { + IntegrationSkip(t) + + // Create no nodes and no users + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // smaller cache expiration time to quickly expire AuthURL + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s", + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "30s", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + assertNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + assertNoErr(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + // wait for the registration cache to expire + time.Sleep(2 * time.Minute) + + st, err := ts.Status() + assertNoErr(t, err) + assert.Equal(t, "NeedsLogin", st.BackendState) + + // get new AuthURL from daemon + t.Logf("Status: %s", st.AuthURL) + newUrl, err := url.Parse(st.AuthURL) + assertNoErr(t, err) + t.Logf("AuthUrl: %s", newUrl.String()) + + assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") + + _, err = doLoginURL(ts.Hostname(), newUrl) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodes, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodes, 1) + +} + func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { t.Helper()