diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index a16f0aab..459fc664 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -31,6 +31,7 @@ jobs: - TestOIDC024UserCreation - TestOIDCAuthenticationWithPKCE - TestOIDCReloginSameNodeNewUser + - TestOIDCFollowUpUrl - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 81032640..6d50ab7f 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -40,7 +40,7 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - return h.waitForFollowup(ctx, regReq) + return h.waitForFollowup(ctx, regReq, machineKey) } if regReq.Auth != nil && regReq.Auth.AuthKey != "" { @@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func (h *Headscale) waitForFollowup( ctx context.Context, regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { fu, err := url.Parse(regReq.Followup) if err != nil { @@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) case node := <-reg.Registered: if node == nil { - return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) } return nodeToRegisterResponse(node), nil } } - return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) + // if the follow-up registration isn't found anymore, instruct the client to try a new registration + return h.reqToNewRegisterResponse(regReq, machineKey) +} + +func (h *Headscale) reqToNewRegisterResponse( + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + newRegID, err := 
types.NewRegistrationID() + if err != nil { + return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) + } + + nodeToRegister := types.NewRegisterNode( + types.Node{ + Hostname: regReq.Hostinfo.Hostname, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + }, + ) + + if !regReq.Expiry.IsZero() { + nodeToRegister.Node.Expiry = &regReq.Expiry + } + + log.Info().Msgf("New followup node registration using key: %s", newRegID) + h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(newRegID), + }, nil } func (h *Headscale) handleRegisterWithAuthKey( @@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive( return nil, fmt.Errorf("generating registration ID: %w", err) } - nodeToRegister := types.RegisterNode{ - Node: types.Node{ + nodeToRegister := types.NewRegisterNode( + types.Node{ Hostname: regReq.Hostinfo.Hostname, MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), }, - Registered: make(chan *types.Node), - } + ) if !regReq.Expiry.IsZero() { nodeToRegister.Node.Expiry = &regReq.Expiry diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 6663b44a..01d3c6b3 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode( return nil, err } - newNode := types.RegisterNode{ - Node: types.Node{ + newNode := types.NewRegisterNode( + types.Node{ NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), @@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostinfo: &hostinfo, }, - Registered: make(chan *types.Node), - } + ) log.Debug(). Caller(). 
diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 55f917d7..84d00712 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( verb := "Reauthenticated" newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { + if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { + log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) + + return + } httpError(writer, err) return } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 15597706..ad7770ff 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -74,9 +74,25 @@ type State struct { // NewState creates and initializes a new State instance, setting up the database, // IP allocator, DERP map, policy manager, and loading existing users and nodes. 
func NewState(cfg *types.Config) (*State, error) { + cacheExpiration := registerCacheExpiration + if cfg.Tuning.RegisterCacheExpiration != 0 { + cacheExpiration = cfg.Tuning.RegisterCacheExpiration + } + + cacheCleanup := registerCacheCleanup + if cfg.Tuning.RegisterCacheCleanup != 0 { + cacheCleanup = cfg.Tuning.RegisterCacheCleanup + } + registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( - registerCacheExpiration, - registerCacheCleanup, + cacheExpiration, + cacheCleanup, + ) + + registrationCache.OnEvicted( + func(id types.RegistrationID, rn types.RegisterNode) { + rn.SendAndClose(nil) + }, ) db, err := hsdb.NewHeadscaleDatabase( @@ -1238,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath( s.nodeStore.PutNode(*savedNode) } + // Signal to waiting clients + regEntry.SendAndClose(savedNode) + // Delete from registration cache s.registrationCache.Delete(registrationID) - // Signal to waiting clients - select { - case regEntry.Registered <- savedNode: - default: - } - close(regEntry.Registered) - // Update policy manager nodesChange, err := s.updatePolicyManagerNodes() if err != nil { diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index a80f2ab4..a7d815bf 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "runtime" + "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/util" @@ -186,6 +187,28 @@ func (r RegistrationID) String() string { type RegisterNode struct { Node Node Registered chan *Node + closed *atomic.Bool +} + +func NewRegisterNode(node Node) RegisterNode { + return RegisterNode{ + Node: node, + Registered: make(chan *Node), + closed: &atomic.Bool{}, + } +} + +func (rn *RegisterNode) SendAndClose(node *Node) { + if rn.closed.Swap(true) { + return + } + + select { + case rn.Registered <- node: + default: + } + + close(rn.Registered) } // DefaultBatcherWorkers returns the default number of batcher workers. 
diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4a0a366e..d4a7d662 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -235,6 +235,8 @@ type Tuning struct { BatchChangeDelay time.Duration NodeMapSessionBufferedChanSize int BatcherWorkers int + RegisterCacheCleanup time.Duration + RegisterCacheExpiration time.Duration } func validatePKCEMethod(method string) error { @@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) { } return DefaultBatcherWorkers() }(), + RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), + RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"), }, }, nil } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 751a8d11..dedb9e91 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -3,6 +3,7 @@ package integration import ( "maps" "net/netip" + "net/url" "sort" "testing" "time" @@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") } +func TestOIDCFollowUpUrl(t *testing.T) { + IntegrationSkip(t) + + // Create no nodes and no users + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // smaller cache expiration time to quickly expire AuthURL + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s", + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + 
hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + assertNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + assertNoErr(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + // wait for the registration cache to expire + // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION + time.Sleep(2 * time.Minute) + + st, err := ts.Status() + assertNoErr(t, err) + assert.Equal(t, "NeedsLogin", st.BackendState) + + // get new AuthURL from daemon + newUrl, err := url.Parse(st.AuthURL) + assertNoErr(t, err) + + assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") + + _, err = doLoginURL(ts.Hostname(), newUrl) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodes, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodes, 1) +} + // assertTailscaleNodesLogout verifies that all provided Tailscale clients // are in the logged-out state (NeedsLogin). func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {