1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-10-19 11:15:48 +02:00

fix: return valid AuthUrl in followup request on expired reg id

- tailscale client gets a new AuthUrl and sets entry in the regcache
- regcache entry expires
- client doesn't know about that
- client keeps polling the followup request and gets an error

When a user clicks "Login" in the app (after cache expiry), they visit an
invalid URL and get "node not found in registration cache". Some clients,
e.g. on Windows, can't get a new AuthUrl without restarting the app.

To fix that, we issue a new registration ID and return a new valid
AuthUrl to the user.

RegisterNode is now created with NewRegisterNode(), which automatically
sets up the Registered channel and the closed flag.
This commit is contained in:
Andrey Bobelev 2025-08-29 15:55:42 +02:00 committed by nblock
parent 022098fe4e
commit c4a8c038cd
7 changed files with 196 additions and 18 deletions

View File

@ -31,6 +31,7 @@ jobs:
- TestOIDC024UserCreation - TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE - TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser - TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin - TestAuthWebFlowLogoutAndRelogin
- TestUserCommand - TestUserCommand

View File

@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
} }
if regReq.Followup != "" { if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq) return h.waitForFollowup(ctx, regReq, machineKey)
} }
if regReq.Auth != nil && regReq.Auth.AuthKey != "" { if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup( func (h *Headscale) waitForFollowup(
ctx context.Context, ctx context.Context,
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup) fu, err := url.Parse(regReq.Followup)
if err != nil { if err != nil {
@ -159,13 +160,49 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered: case node := <-reg.Registered:
if node == nil { if node == nil {
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) // registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
} }
return nodeToRegisterResponse(node), nil return nodeToRegisterResponse(node), nil
} }
} }
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) // if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
// reqToNewRegisterResponse refreshes the registration flow by creating a new
// registration ID and returning the corresponding AuthURL so the client can
// restart the authentication process.
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
}
log.Info().Msgf("New followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
} }
func (h *Headscale) handleRegisterWithAuthKey( func (h *Headscale) handleRegisterWithAuthKey(
@ -244,16 +281,15 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err) return nil, fmt.Errorf("generating registration ID: %w", err)
} }
nodeToRegister := types.RegisterNode{ nodeToRegister := types.NewRegisterNode(
Node: types.Node{ types.Node{
Hostname: regReq.Hostinfo.Hostname, Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey, MachineKey: machineKey,
NodeKey: regReq.NodeKey, NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo, Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()), LastSeen: ptr.To(time.Now()),
}, },
Registered: make(chan *types.Node), )
}
if !regReq.Expiry.IsZero() { if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry nodeToRegister.Node.Expiry = &regReq.Expiry

View File

@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err return nil, err
} }
newNode := types.RegisterNode{ newNode := types.NewRegisterNode(
Node: types.Node{ types.Node{
NodeKey: key.NewNode().Public(), NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(), MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(), Hostname: request.GetName(),
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostinfo: &hostinfo, Hostinfo: &hostinfo,
}, },
Registered: make(chan *types.Node), )
}
log.Debug(). log.Debug().
Caller(). Caller().

View File

@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb := "Reauthenticated" verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
}
httpError(writer, err) httpError(writer, err)
return return
} }

View File

@ -89,6 +89,12 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup, cacheCleanup,
) )
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
},
)
db, err := hsdb.NewHeadscaleDatabase( db, err := hsdb.NewHeadscaleDatabase(
cfg.Database, cfg.Database,
cfg.BaseDomain, cfg.BaseDomain,
@ -1248,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
s.nodeStore.PutNode(*savedNode) s.nodeStore.PutNode(*savedNode)
} }
// Signal to waiting clients
regEntry.SendAndClose(savedNode)
// Delete from registration cache // Delete from registration cache
s.registrationCache.Delete(registrationID) s.registrationCache.Delete(registrationID)
// Signal to waiting clients
select {
case regEntry.Registered <- savedNode:
default:
}
close(regEntry.Registered)
// Update policy manager // Update policy manager
nodesChange, err := s.updatePolicyManagerNodes() nodesChange, err := s.updatePolicyManagerNodes()
if err != nil { if err != nil {

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"sync/atomic"
"time" "time"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
type RegisterNode struct { type RegisterNode struct {
Node Node Node Node
Registered chan *Node Registered chan *Node
closed *atomic.Bool
}
// NewRegisterNode returns a RegisterNode holding node, together with a newly
// made unbuffered Registered channel and a newly allocated closed flag whose
// initial value is false.
func NewRegisterNode(node Node) RegisterNode {
	var rn RegisterNode

	rn.Node = node
	rn.Registered = make(chan *Node)
	rn.closed = new(atomic.Bool)

	return rn
}
func (rn *RegisterNode) SendAndClose(node *Node) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
default:
}
close(rn.Registered)
} }
// DefaultBatcherWorkers returns the default number of batcher workers. // DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -3,6 +3,7 @@ package integration
import ( import (
"maps" "maps"
"net/netip" "net/netip"
"net/url"
"sort" "sort"
"testing" "testing"
"time" "time"
@ -688,6 +689,116 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
} }
// TestOIDCFollowUpUrl checks that a client whose pending login outlives the
// registration cache entry is handed a fresh AuthURL and can finish logging
// in with it.
//
// Setup: the registration cache is configured with a 1m30s expiration and a
// 10s cleanup interval through the HEADSCALE_TUNING_REGISTER_CACHE_* settings.
//
// Steps:
//   - the client starts a login and receives its first AuthURL;
//   - the test sleeps for two minutes, i.e. the expiration plus 30 seconds;
//   - the client status must still be NeedsLogin and report a different AuthURL;
//   - logging in through the new AuthURL must yield exactly one user and one node.
func TestOIDCFollowUpUrl(t *testing.T) {
	IntegrationSkip(t)

	// The scenario starts with no nodes and no users; only the OIDC mock user exists.
	sc, err := NewScenario(
		ScenarioSpec{
			OIDCUsers: []mockoidc.MockUser{
				oidcMockUser("user1", true),
			},
		},
	)
	assertNoErr(t, err)
	defer sc.ShutdownAssertNoPanics(t)

	env := map[string]string{
		"HEADSCALE_OIDC_ISSUER":             sc.mockOIDC.Issuer(),
		"HEADSCALE_OIDC_CLIENT_ID":          sc.mockOIDC.ClientID(),
		"CREDENTIALS_DIRECTORY_TEST":        "/tmp",
		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
		// Short-lived registration cache so the first AuthURL expires quickly.
		"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP":    "10s",
		"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
	}

	err = sc.CreateHeadscaleEnvWithLoginURL(
		nil,
		hsic.WithTestName("oidcauthrelog"),
		hsic.WithConfigEnv(env),
		hsic.WithTLS(),
		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(sc.mockOIDC.ClientSecret())),
		hsic.WithEmbeddedDERPServerOnly(),
	)
	assertNoErrHeadscaleEnv(t, err)

	hs, err := sc.Headscale()
	assertNoErr(t, err)

	// Nobody has logged in yet.
	users, err := hs.ListUsers()
	assertNoErr(t, err)
	assert.Empty(t, users)

	client, err := sc.CreateTailscaleNode(
		"unstable",
		tsic.WithNetwork(sc.networks[sc.testDefaultNetwork]),
	)
	assertNoErr(t, err)

	// Start the login; this yields the first AuthURL.
	initialURL, err := client.LoginWithURL(hs.GetEndpoint())
	assertNoErr(t, err)

	// Let the registration cache entry expire: slightly longer than
	// HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION.
	time.Sleep(2 * time.Minute)

	status, err := client.Status()
	assertNoErr(t, err)
	assert.Equal(t, "NeedsLogin", status.BackendState)

	// The daemon should now report a new AuthURL.
	followUpURL, err := url.Parse(status.AuthURL)
	assertNoErr(t, err)
	assert.NotEqual(t, initialURL.String(), status.AuthURL, "AuthURL should change")

	_, err = doLoginURL(client.Hostname(), followUpURL)
	assertNoErr(t, err)

	users, err = hs.ListUsers()
	assertNoErr(t, err)
	assert.Len(t, users, 1)

	expected := []*v1.User{
		{
			Id:         1,
			Name:       "user1",
			Email:      "user1@headscale.net",
			Provider:   "oidc",
			ProviderId: sc.mockOIDC.Issuer() + "/user1",
		},
	}

	// Order by ID before comparing.
	sort.Slice(users, func(a, b int) bool {
		return users[a].GetId() < users[b].GetId()
	})

	diff := cmp.Diff(
		expected,
		users,
		cmpopts.IgnoreUnexported(v1.User{}),
		cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
	)
	if diff != "" {
		t.Fatalf("unexpected users: %s", diff)
	}

	nodes, err := hs.ListNodes()
	assertNoErr(t, err)
	assert.Len(t, nodes, 1)
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients // assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin). // are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {