1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00
This commit is contained in:
Andrey 2025-09-22 13:54:14 +00:00 committed by GitHub
commit f6d9361141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 201 additions and 20 deletions

View File

@ -31,6 +31,7 @@ jobs:
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand

View File

@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
}
if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq)
return h.waitForFollowup(ctx, regReq, machineKey)
}
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup(
ctx context.Context,
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup)
if err != nil {
@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered:
if node == nil {
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
return nodeToRegisterResponse(node), nil
}
}
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
}
log.Info().Msgf("New followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
}
func (h *Headscale) handleRegisterWithAuthKey(
@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err)
}
nodeToRegister := types.RegisterNode{
Node: types.Node{
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
Registered: make(chan *types.Node),
}
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry

View File

@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err
}
newNode := types.RegisterNode{
Node: types.Node{
newNode := types.NewRegisterNode(
types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostinfo: &hostinfo,
},
Registered: make(chan *types.Node),
}
)
log.Debug().
Caller().

View File

@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
}
httpError(writer, err)
return
}

View File

@ -74,9 +74,25 @@ type State struct {
// NewState creates and initializes a new State instance, setting up the database,
// IP allocator, DERP map, policy manager, and loading existing users and nodes.
func NewState(cfg *types.Config) (*State, error) {
cacheExpiration := registerCacheExpiration
if cfg.Tuning.RegisterCacheExpiration != 0 {
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
}
cacheCleanup := registerCacheCleanup
if cfg.Tuning.RegisterCacheCleanup != 0 {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
cacheExpiration,
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
},
)
db, err := hsdb.NewHeadscaleDatabase(
@ -1238,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
s.nodeStore.PutNode(*savedNode)
}
// Signal to waiting clients
regEntry.SendAndClose(savedNode)
// Delete from registration cache
s.registrationCache.Delete(registrationID)
// Signal to waiting clients
select {
case regEntry.Registered <- savedNode:
default:
}
close(regEntry.Registered)
// Update policy manager
nodesChange, err := s.updatePolicyManagerNodes()
if err != nil {

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"runtime"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/util"
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
type RegisterNode struct {
Node Node
Registered chan *Node
closed *atomic.Bool
}
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
}
}
func (rn *RegisterNode) SendAndClose(node *Node) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
default:
}
close(rn.Registered)
}
// DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -235,6 +235,8 @@ type Tuning struct {
BatchChangeDelay time.Duration
NodeMapSessionBufferedChanSize int
BatcherWorkers int
RegisterCacheCleanup time.Duration
RegisterCacheExpiration time.Duration
}
func validatePKCEMethod(method string) error {
@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) {
}
return DefaultBatcherWorkers()
}(),
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
},
}, nil
}

View File

@ -3,6 +3,7 @@ package integration
import (
"maps"
"net/netip"
"net/url"
"sort"
"testing"
"time"
@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
}
func TestOIDCFollowUpUrl(t *testing.T) {
IntegrationSkip(t)
// Create no nodes and no users
scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
// smaller cache expiration time to quickly expire AuthURL
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
// wait for the registration cache to expire
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
assertNoErr(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
assertNoErr(t, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
_, err = doLoginURL(ts.Hostname(), newUrl)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)
if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodes, 1)
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {