diff --git a/hscontrol/auth.go b/hscontrol/auth.go index cb284173..928ed1f8 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -59,7 +59,7 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - return h.waitForFollowup(ctx, regReq) + return h.waitForFollowup(ctx, regReq, machineKey) } if regReq.Auth != nil && regReq.Auth.AuthKey != "" { @@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode( } h.Change(c) - } + } - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(node), nil } func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { @@ -147,6 +147,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func (h *Headscale) waitForFollowup( ctx context.Context, regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { fu, err := url.Parse(regReq.Followup) if err != nil { @@ -164,13 +165,45 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) case node := <-reg.Registered: if node == nil { - return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) } return nodeToRegisterResponse(node), nil } } - return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) + // if the follow-up registration isn't found anymore, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) +} + +func (h *Headscale) reqToNewRegisterResponse( + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + newRegID, err := types.NewRegistrationID() + if err != nil { + return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) + } + + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: regReq.Hostinfo.Hostname, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + }, + Registered: make(chan *types.Node), + } + + if !regReq.Expiry.IsZero() { + nodeToRegister.Node.Expiry = ®Req.Expiry + } + h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(newRegID), + }, nil } func (h *Headscale) handleRegisterWithAuthKey( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 4b1b4f0d..0eb91dd3 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -82,6 +82,16 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup, ) + registrationCache.OnEvicted( + func(id types.RegistrationID, node types.RegisterNode) { + select { + case node.Registered <- nil: + // notify the followup handler that registration is not valid anymore + default: + } + }, + ) + db, err := hsdb.NewHeadscaleDatabase( cfg.Database, cfg.BaseDomain, diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index d118b643..5204f046 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -15,6 +15,7 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "net/url" ) func TestOIDCAuthenticationPingAll(t *testing.T) { @@ -616,6 +617,110 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) } +func TestOIDCFollowUpUrl(t *testing.T) { + IntegrationSkip(t) + + // Create no nodes and no users + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // smaller cache expiration time to quickly expire AuthURL + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s", + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "30s", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + assertNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + assertNoErr(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + // wait for the registration cache to expire + time.Sleep(2 * time.Minute) + + st, err := ts.Status() + assertNoErr(t, err) + assert.Equal(t, "NeedsLogin", st.BackendState) + + // get new AuthURL from daemon + t.Logf("Status: %s", st.AuthURL) + newUrl, err := url.Parse(st.AuthURL) + assertNoErr(t, err) + t.Logf("AuthUrl: %s", newUrl.String()) + + assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") + + _, err = doLoginURL(ts.Hostname(), newUrl) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodes, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodes, 1) + +} + func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { t.Helper()