mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
fix: return valid AuthUrl in followup request on expired reg id
- tailscale client gets a new AuthUrl and sets entry in the regcache - regcache entry expires - client doesn't know about that - client always polls followup request а gets error When user clicks "Login" in the app (after cache expiry), they visit invalid URL and get "node not found in registration cache". Some clients on Windows for e.g. can't get a new AuthUrl without restart the app. To fix that we can issue a new reg id and return user a new valid AuthUrl.
This commit is contained in:
parent
2174f6d0b9
commit
75217c4579
@ -59,7 +59,7 @@ func (h *Headscale) handleRegister(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if regReq.Followup != "" {
|
if regReq.Followup != "" {
|
||||||
return h.waitForFollowup(ctx, regReq)
|
return h.waitForFollowup(ctx, regReq, machineKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||||
@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.Change(c)
|
h.Change(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodeToRegisterResponse(node), nil
|
return nodeToRegisterResponse(node), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||||
@ -147,6 +147,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
|||||||
func (h *Headscale) waitForFollowup(
|
func (h *Headscale) waitForFollowup(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
) (*tailcfg.RegisterResponse, error) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
fu, err := url.Parse(regReq.Followup)
|
fu, err := url.Parse(regReq.Followup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -164,13 +165,45 @@ func (h *Headscale) waitForFollowup(
|
|||||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||||
case node := <-reg.Registered:
|
case node := <-reg.Registered:
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
|
// registration is expired in the cache, instruct the client to try a new registration
|
||||||
|
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||||
}
|
}
|
||||||
return nodeToRegisterResponse(node), nil
|
return nodeToRegisterResponse(node), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
|
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
|
||||||
|
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) reqToNewRegisterResponse(
|
||||||
|
regReq tailcfg.RegisterRequest,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
|
newRegID, err := types.NewRegistrationID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeToRegister := types.RegisterNode{
|
||||||
|
Node: types.Node{
|
||||||
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
|
MachineKey: machineKey,
|
||||||
|
NodeKey: regReq.NodeKey,
|
||||||
|
Hostinfo: regReq.Hostinfo,
|
||||||
|
LastSeen: ptr.To(time.Now()),
|
||||||
|
},
|
||||||
|
Registered: make(chan *types.Node),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regReq.Expiry.IsZero() {
|
||||||
|
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||||
|
}
|
||||||
|
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||||
|
|
||||||
|
return &tailcfg.RegisterResponse{
|
||||||
|
AuthURL: h.authProvider.AuthURL(newRegID),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleRegisterWithAuthKey(
|
func (h *Headscale) handleRegisterWithAuthKey(
|
||||||
|
@ -82,6 +82,16 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||||||
cacheCleanup,
|
cacheCleanup,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
registrationCache.OnEvicted(
|
||||||
|
func(id types.RegistrationID, node types.RegisterNode) {
|
||||||
|
select {
|
||||||
|
case node.Registered <- nil:
|
||||||
|
// notify the followup handler that registration is not valid anymore
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
db, err := hsdb.NewHeadscaleDatabase(
|
db, err := hsdb.NewHeadscaleDatabase(
|
||||||
cfg.Database,
|
cfg.Database,
|
||||||
cfg.BaseDomain,
|
cfg.BaseDomain,
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/oauth2-proxy/mockoidc"
|
"github.com/oauth2-proxy/mockoidc"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOIDCAuthenticationPingAll(t *testing.T) {
|
func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
@ -616,6 +617,110 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||||||
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
|
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCFollowUpUrl(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
// Create no nodes and no users
|
||||||
|
scenario, err := NewScenario(
|
||||||
|
ScenarioSpec{
|
||||||
|
OIDCUsers: []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
// smaller cache expiration time to quickly expire AuthURL
|
||||||
|
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
|
||||||
|
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "30s",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||||
|
nil,
|
||||||
|
hsic.WithTestName("oidcauthrelog"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||||
|
hsic.WithEmbeddedDERPServerOnly(),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err := headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Empty(t, listUsers)
|
||||||
|
|
||||||
|
ts, err := scenario.CreateTailscaleNode(
|
||||||
|
"unstable",
|
||||||
|
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
// wait for the registration cache to expire
|
||||||
|
time.Sleep(2 * time.Minute)
|
||||||
|
|
||||||
|
st, err := ts.Status()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Equal(t, "NeedsLogin", st.BackendState)
|
||||||
|
|
||||||
|
// get new AuthURL from daemon
|
||||||
|
t.Logf("Status: %s", st.AuthURL)
|
||||||
|
newUrl, err := url.Parse(st.AuthURL)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
t.Logf("AuthUrl: %s", newUrl.String())
|
||||||
|
|
||||||
|
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
|
||||||
|
|
||||||
|
_, err = doLoginURL(ts.Hostname(), newUrl)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 1)
|
||||||
|
|
||||||
|
wantUsers := []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(
|
||||||
|
listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(
|
||||||
|
wantUsers,
|
||||||
|
listUsers,
|
||||||
|
cmpopts.IgnoreUnexported(v1.User{}),
|
||||||
|
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
|
||||||
|
); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listNodes, 1)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
|
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user