1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00
This commit is contained in:
Andrey 2025-09-22 13:54:14 +00:00 committed by GitHub
commit f6d9361141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 201 additions and 20 deletions

View File

@ -31,6 +31,7 @@ jobs:
- TestOIDC024UserCreation - TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE - TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser - TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin - TestAuthWebFlowLogoutAndRelogin
- TestUserCommand - TestUserCommand

View File

@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
} }
if regReq.Followup != "" { if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq) return h.waitForFollowup(ctx, regReq, machineKey)
} }
if regReq.Auth != nil && regReq.Auth.AuthKey != "" { if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup( func (h *Headscale) waitForFollowup(
ctx context.Context, ctx context.Context,
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup) fu, err := url.Parse(regReq.Followup)
if err != nil { if err != nil {
@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered: case node := <-reg.Registered:
if node == nil { if node == nil {
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil) // registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
} }
return nodeToRegisterResponse(node), nil return nodeToRegisterResponse(node), nil
} }
} }
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) // if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
}
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
}
log.Info().Msgf("New followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
} }
func (h *Headscale) handleRegisterWithAuthKey( func (h *Headscale) handleRegisterWithAuthKey(
@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err) return nil, fmt.Errorf("generating registration ID: %w", err)
} }
nodeToRegister := types.RegisterNode{ nodeToRegister := types.NewRegisterNode(
Node: types.Node{ types.Node{
Hostname: regReq.Hostinfo.Hostname, Hostname: regReq.Hostinfo.Hostname,
MachineKey: machineKey, MachineKey: machineKey,
NodeKey: regReq.NodeKey, NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo, Hostinfo: regReq.Hostinfo,
LastSeen: ptr.To(time.Now()), LastSeen: ptr.To(time.Now()),
}, },
Registered: make(chan *types.Node), )
}
if !regReq.Expiry.IsZero() { if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry nodeToRegister.Node.Expiry = &regReq.Expiry

View File

@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err return nil, err
} }
newNode := types.RegisterNode{ newNode := types.NewRegisterNode(
Node: types.Node{ types.Node{
NodeKey: key.NewNode().Public(), NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(), MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(), Hostname: request.GetName(),
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostinfo: &hostinfo, Hostinfo: &hostinfo,
}, },
Registered: make(chan *types.Node), )
}
log.Debug(). log.Debug().
Caller(). Caller().

View File

@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb := "Reauthenticated" verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
}
httpError(writer, err) httpError(writer, err)
return return
} }

View File

@ -74,9 +74,25 @@ type State struct {
// NewState creates and initializes a new State instance, setting up the database, // NewState creates and initializes a new State instance, setting up the database,
// IP allocator, DERP map, policy manager, and loading existing users and nodes. // IP allocator, DERP map, policy manager, and loading existing users and nodes.
func NewState(cfg *types.Config) (*State, error) { func NewState(cfg *types.Config) (*State, error) {
cacheExpiration := registerCacheExpiration
if cfg.Tuning.RegisterCacheExpiration != 0 {
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
}
cacheCleanup := registerCacheCleanup
if cfg.Tuning.RegisterCacheCleanup != 0 {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration, cacheExpiration,
registerCacheCleanup, cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
},
) )
db, err := hsdb.NewHeadscaleDatabase( db, err := hsdb.NewHeadscaleDatabase(
@ -1238,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
s.nodeStore.PutNode(*savedNode) s.nodeStore.PutNode(*savedNode)
} }
// Signal to waiting clients
regEntry.SendAndClose(savedNode)
// Delete from registration cache // Delete from registration cache
s.registrationCache.Delete(registrationID) s.registrationCache.Delete(registrationID)
// Signal to waiting clients
select {
case regEntry.Registered <- savedNode:
default:
}
close(regEntry.Registered)
// Update policy manager // Update policy manager
nodesChange, err := s.updatePolicyManagerNodes() nodesChange, err := s.updatePolicyManagerNodes()
if err != nil { if err != nil {

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"sync/atomic"
"time" "time"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
type RegisterNode struct { type RegisterNode struct {
Node Node Node Node
Registered chan *Node Registered chan *Node
closed *atomic.Bool
}
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
}
}
func (rn *RegisterNode) SendAndClose(node *Node) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
default:
}
close(rn.Registered)
} }
// DefaultBatcherWorkers returns the default number of batcher workers. // DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -235,6 +235,8 @@ type Tuning struct {
BatchChangeDelay time.Duration BatchChangeDelay time.Duration
NodeMapSessionBufferedChanSize int NodeMapSessionBufferedChanSize int
BatcherWorkers int BatcherWorkers int
RegisterCacheCleanup time.Duration
RegisterCacheExpiration time.Duration
} }
func validatePKCEMethod(method string) error { func validatePKCEMethod(method string) error {
@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) {
} }
return DefaultBatcherWorkers() return DefaultBatcherWorkers()
}(), }(),
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
}, },
}, nil }, nil
} }

View File

@ -3,6 +3,7 @@ package integration
import ( import (
"maps" "maps"
"net/netip" "net/netip"
"net/url"
"sort" "sort"
"testing" "testing"
"time" "time"
@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
} }
func TestOIDCFollowUpUrl(t *testing.T) {
IntegrationSkip(t)
// Create no nodes and no users
scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
// smaller cache expiration time to quickly expire AuthURL
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
assertNoErr(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
// wait for the registration cache to expire
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
assertNoErr(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
assertNoErr(t, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
_, err = doLoginURL(ts.Hostname(), newUrl)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)
if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodes, 1)
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients // assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin). // are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {