1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

fix: return valid AuthUrl in followup request on expired reg id

- tailscale client gets a new AuthUrl and sets entry in the regcache
- regcache entry expires
- client doesn't know about that
- client always polls followup request а gets error

When user clicks "Login" in the app (after cache expiry), they visit
invalid URL and get "node not found in registration cache". Some clients
on Windows for e.g. can't get a new AuthUrl without restart the app.

To fix that we can issue a new reg id and return user a new valid
AuthUrl.
This commit is contained in:
Andrey Bobelev 2025-08-29 15:55:42 +02:00
parent 2174f6d0b9
commit 75217c4579
No known key found for this signature in database
GPG Key ID: 8BC07FB9FEAEEF63
3 changed files with 153 additions and 5 deletions

View File

@ -59,7 +59,7 @@ func (h *Headscale) handleRegister(
}
if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq)
return h.waitForFollowup(ctx, regReq, machineKey)
}
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode(
}
h.Change(c)
}
}
return nodeToRegisterResponse(node), nil
return nodeToRegisterResponse(node), nil
}
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@ -147,6 +147,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup(
ctx context.Context,
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup)
if err != nil {
@ -164,13 +165,45 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered:
if node == nil {
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
return nodeToRegisterResponse(node), nil
}
}
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
Registered: make(chan *types.Node),
}
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
}
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
}
func (h *Headscale) handleRegisterWithAuthKey(

View File

@ -82,6 +82,16 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, node types.RegisterNode) {
select {
case node.Registered <- nil:
// notify the followup handler that registration is not valid anymore
default:
}
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg.Database,
cfg.BaseDomain,

View File

@ -15,6 +15,7 @@ import (
"github.com/oauth2-proxy/mockoidc"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"net/url"
)
func TestOIDCAuthenticationPingAll(t *testing.T) {
@ -616,6 +617,110 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}
func TestOIDCFollowUpUrl(t *testing.T) {
IntegrationSkip(t)
// Create no nodes and no users
scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
// smaller cache expiration time to quickly expire AuthURL
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "30s",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
// wait for the registration cache to expire
time.Sleep(2 * time.Minute)
st, err := ts.Status()
assertNoErr(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
t.Logf("Status: %s", st.AuthURL)
newUrl, err := url.Parse(st.AuthURL)
assertNoErr(t, err)
t.Logf("AuthUrl: %s", newUrl.String())
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
_, err = doLoginURL(ts.Hostname(), newUrl)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)
if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodes, 1)
}
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
t.Helper()