1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

fix: return valid AuthUrl in followup request on expired reg id

- tailscale client gets a new AuthUrl and sets entry in the regcache
- regcache entry expires
- client doesn't know about that
- client always polls followup request а gets error

When user clicks "Login" in the app (after cache expiry), they visit
invalid URL and get "node not found in registration cache". Some clients
on Windows for e.g. can't get a new AuthUrl without restart the app.

To fix that we can issue a new reg id and return user a new valid
AuthUrl.

RegisterNode is refactored to be created with NewRegisterNode() to
autocreate channel and other stuff.
This commit is contained in:
Andrey Bobelev 2025-08-29 15:55:42 +02:00
parent 2dfba3d2cc
commit 80fe04971d
No known key found for this signature in database
GPG Key ID: 8BC07FB9FEAEEF63
7 changed files with 185 additions and 18 deletions

View File

@ -31,6 +31,7 @@ jobs:
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand

View File

@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
}
if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq)
return h.waitForFollowup(ctx, regReq, machineKey)
}
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup(
ctx context.Context,
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup)
if err != nil {
@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered:
if node == nil {
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
return nodeToRegisterResponse(node), nil
}
}
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
}
log.Info().Msgf("New followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
}
func (h *Headscale) handleRegisterWithAuthKey(
@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err)
}
nodeToRegister := types.RegisterNode{
Node: types.Node{
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
Registered: make(chan *types.Node),
}
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry

View File

@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err
}
newNode := types.RegisterNode{
Node: types.Node{
newNode := types.NewRegisterNode(
types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostinfo: &hostinfo,
},
Registered: make(chan *types.Node),
}
)
log.Debug().
Caller().

View File

@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
}
httpError(writer, err)
return
}

View File

@ -89,6 +89,12 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg.Database,
cfg.BaseDomain,
@ -1248,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
s.nodeStore.PutNode(*savedNode)
}
// Signal to waiting clients
regEntry.SendAndClose(savedNode)
// Delete from registration cache
s.registrationCache.Delete(registrationID)
// Signal to waiting clients
select {
case regEntry.Registered <- savedNode:
default:
}
close(regEntry.Registered)
// Update policy manager
nodesChange, err := s.updatePolicyManagerNodes()
if err != nil {

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"runtime"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/util"
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
type RegisterNode struct {
Node Node
Registered chan *Node
closed *atomic.Bool
}
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
}
}
func (rn *RegisterNode) SendAndClose(node *Node) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
default:
}
close(rn.Registered)
}
// DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -3,6 +3,7 @@ package integration
import (
"maps"
"net/netip"
"net/url"
"sort"
"testing"
"time"
@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
}
func TestOIDCFollowUpUrl(t *testing.T) {
IntegrationSkip(t)
// Create no nodes and no users
scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
// smaller cache expiration time to quickly expire AuthURL
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
// wait for the registration cache to expire
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
assertNoErr(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
assertNoErr(t, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
_, err = doLoginURL(ts.Hostname(), newUrl)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)
if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodes, 1)
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {