mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
Merge 80fe04971d
into bd35fcf338
This commit is contained in:
commit
f6d9361141
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
@ -31,6 +31,7 @@ jobs:
|
||||
- TestOIDC024UserCreation
|
||||
- TestOIDCAuthenticationWithPKCE
|
||||
- TestOIDCReloginSameNodeNewUser
|
||||
- TestOIDCFollowUpUrl
|
||||
- TestAuthWebFlowAuthenticationPingAll
|
||||
- TestAuthWebFlowLogoutAndRelogin
|
||||
- TestUserCommand
|
||||
|
@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
if regReq.Followup != "" {
|
||||
return h.waitForFollowup(ctx, regReq)
|
||||
return h.waitForFollowup(ctx, regReq, machineKey)
|
||||
}
|
||||
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
func (h *Headscale) waitForFollowup(
|
||||
ctx context.Context,
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
fu, err := url.Parse(regReq.Followup)
|
||||
if err != nil {
|
||||
@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup(
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||
case node := <-reg.Registered:
|
||||
if node == nil {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
|
||||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||
}
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
|
||||
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||
}
|
||||
|
||||
func (h *Headscale) reqToNewRegisterResponse(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
newRegID, err := types.NewRegistrationID()
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||
}
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: regReq.NodeKey,
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||
}
|
||||
|
||||
log.Info().Msgf("New followup node registration using key: %s", newRegID)
|
||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(newRegID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegisterWithAuthKey(
|
||||
@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||
}
|
||||
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: regReq.NodeKey,
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
},
|
||||
Registered: make(chan *types.Node),
|
||||
}
|
||||
)
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||
|
@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNode := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
newNode := types.NewRegisterNode(
|
||||
types.Node{
|
||||
NodeKey: key.NewNode().Public(),
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
Hostname: request.GetName(),
|
||||
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
|
||||
Hostinfo: &hostinfo,
|
||||
},
|
||||
Registered: make(chan *types.Node),
|
||||
}
|
||||
)
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
|
@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
verb := "Reauthenticated"
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||
|
||||
return
|
||||
}
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
@ -74,9 +74,25 @@ type State struct {
|
||||
// NewState creates and initializes a new State instance, setting up the database,
|
||||
// IP allocator, DERP map, policy manager, and loading existing users and nodes.
|
||||
func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheExpiration := registerCacheExpiration
|
||||
if cfg.Tuning.RegisterCacheExpiration != 0 {
|
||||
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
||||
}
|
||||
|
||||
cacheCleanup := registerCacheCleanup
|
||||
if cfg.Tuning.RegisterCacheCleanup != 0 {
|
||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
cacheExpiration,
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
registrationCache.OnEvicted(
|
||||
func(id types.RegistrationID, rn types.RegisterNode) {
|
||||
rn.SendAndClose(nil)
|
||||
},
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
@ -1238,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
// Signal to waiting clients
|
||||
regEntry.SendAndClose(savedNode)
|
||||
|
||||
// Delete from registration cache
|
||||
s.registrationCache.Delete(registrationID)
|
||||
|
||||
// Signal to waiting clients
|
||||
select {
|
||||
case regEntry.Registered <- savedNode:
|
||||
default:
|
||||
}
|
||||
close(regEntry.Registered)
|
||||
|
||||
// Update policy manager
|
||||
nodesChange, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
|
||||
type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewRegisterNode(node Node) RegisterNode {
|
||||
return RegisterNode{
|
||||
Node: node,
|
||||
Registered: make(chan *Node),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rn *RegisterNode) SendAndClose(node *Node) {
|
||||
if rn.closed.Swap(true) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case rn.Registered <- node:
|
||||
default:
|
||||
}
|
||||
|
||||
close(rn.Registered)
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
|
@ -235,6 +235,8 @@ type Tuning struct {
|
||||
BatchChangeDelay time.Duration
|
||||
NodeMapSessionBufferedChanSize int
|
||||
BatcherWorkers int
|
||||
RegisterCacheCleanup time.Duration
|
||||
RegisterCacheExpiration time.Duration
|
||||
}
|
||||
|
||||
func validatePKCEMethod(method string) error {
|
||||
@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) {
|
||||
}
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package integration
|
||||
import (
|
||||
"maps"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
|
||||
}
|
||||
|
||||
func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
// Create no nodes and no users
|
||||
scenario, err := NewScenario(
|
||||
ScenarioSpec{
|
||||
OIDCUsers: []mockoidc.MockUser{
|
||||
oidcMockUser("user1", true),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
// smaller cache expiration time to quickly expire AuthURL
|
||||
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
|
||||
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||
nil,
|
||||
hsic.WithTestName("oidcauthrelog"),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
assert.Empty(t, listUsers)
|
||||
|
||||
ts, err := scenario.CreateTailscaleNode(
|
||||
"unstable",
|
||||
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
|
||||
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
|
||||
// wait for the registration cache to expire
|
||||
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
|
||||
time.Sleep(2 * time.Minute)
|
||||
|
||||
st, err := ts.Status()
|
||||
assertNoErr(t, err)
|
||||
assert.Equal(t, "NeedsLogin", st.BackendState)
|
||||
|
||||
// get new AuthURL from daemon
|
||||
newUrl, err := url.Parse(st.AuthURL)
|
||||
assertNoErr(t, err)
|
||||
|
||||
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), newUrl)
|
||||
assertNoErr(t, err)
|
||||
|
||||
listUsers, err = headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, listUsers, 1)
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
Name: "user1",
|
||||
Email: "user1@headscale.net",
|
||||
Provider: "oidc",
|
||||
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||
},
|
||||
}
|
||||
|
||||
sort.Slice(
|
||||
listUsers, func(i, j int) bool {
|
||||
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||
},
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(
|
||||
wantUsers,
|
||||
listUsers,
|
||||
cmpopts.IgnoreUnexported(v1.User{}),
|
||||
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
|
||||
); diff != "" {
|
||||
t.Fatalf("unexpected users: %s", diff)
|
||||
}
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, listNodes, 1)
|
||||
}
|
||||
|
||||
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
||||
// are in the logged-out state (NeedsLogin).
|
||||
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
||||
|
Loading…
Reference in New Issue
Block a user