mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
Merge 80fe04971d
into bd35fcf338
This commit is contained in:
commit
f6d9361141
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
@ -31,6 +31,7 @@ jobs:
|
|||||||
- TestOIDC024UserCreation
|
- TestOIDC024UserCreation
|
||||||
- TestOIDCAuthenticationWithPKCE
|
- TestOIDCAuthenticationWithPKCE
|
||||||
- TestOIDCReloginSameNodeNewUser
|
- TestOIDCReloginSameNodeNewUser
|
||||||
|
- TestOIDCFollowUpUrl
|
||||||
- TestAuthWebFlowAuthenticationPingAll
|
- TestAuthWebFlowAuthenticationPingAll
|
||||||
- TestAuthWebFlowLogoutAndRelogin
|
- TestAuthWebFlowLogoutAndRelogin
|
||||||
- TestUserCommand
|
- TestUserCommand
|
||||||
|
@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if regReq.Followup != "" {
|
if regReq.Followup != "" {
|
||||||
return h.waitForFollowup(ctx, regReq)
|
return h.waitForFollowup(ctx, regReq, machineKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||||
@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
|||||||
func (h *Headscale) waitForFollowup(
|
func (h *Headscale) waitForFollowup(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
) (*tailcfg.RegisterResponse, error) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
fu, err := url.Parse(regReq.Followup)
|
fu, err := url.Parse(regReq.Followup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -159,13 +160,46 @@ func (h *Headscale) waitForFollowup(
|
|||||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||||
case node := <-reg.Registered:
|
case node := <-reg.Registered:
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
|
// registration is expired in the cache, instruct the client to try a new registration
|
||||||
|
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||||
}
|
}
|
||||||
return nodeToRegisterResponse(node), nil
|
return nodeToRegisterResponse(node), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
|
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
|
||||||
|
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) reqToNewRegisterResponse(
|
||||||
|
regReq tailcfg.RegisterRequest,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
|
newRegID, err := types.NewRegistrationID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeToRegister := types.NewRegisterNode(
|
||||||
|
types.Node{
|
||||||
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
|
MachineKey: machineKey,
|
||||||
|
NodeKey: regReq.NodeKey,
|
||||||
|
Hostinfo: regReq.Hostinfo,
|
||||||
|
LastSeen: ptr.To(time.Now()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if !regReq.Expiry.IsZero() {
|
||||||
|
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msgf("New followup node registration using key: %s", newRegID)
|
||||||
|
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||||
|
|
||||||
|
return &tailcfg.RegisterResponse{
|
||||||
|
AuthURL: h.authProvider.AuthURL(newRegID),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleRegisterWithAuthKey(
|
func (h *Headscale) handleRegisterWithAuthKey(
|
||||||
@ -244,16 +278,15 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister := types.RegisterNode{
|
nodeToRegister := types.NewRegisterNode(
|
||||||
Node: types.Node{
|
types.Node{
|
||||||
Hostname: regReq.Hostinfo.Hostname,
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
MachineKey: machineKey,
|
MachineKey: machineKey,
|
||||||
NodeKey: regReq.NodeKey,
|
NodeKey: regReq.NodeKey,
|
||||||
Hostinfo: regReq.Hostinfo,
|
Hostinfo: regReq.Hostinfo,
|
||||||
LastSeen: ptr.To(time.Now()),
|
LastSeen: ptr.To(time.Now()),
|
||||||
},
|
},
|
||||||
Registered: make(chan *types.Node),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
if !regReq.Expiry.IsZero() {
|
if !regReq.Expiry.IsZero() {
|
||||||
nodeToRegister.Node.Expiry = ®Req.Expiry
|
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||||
|
@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newNode := types.RegisterNode{
|
newNode := types.NewRegisterNode(
|
||||||
Node: types.Node{
|
types.Node{
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
Hostname: request.GetName(),
|
Hostname: request.GetName(),
|
||||||
@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||||||
|
|
||||||
Hostinfo: &hostinfo,
|
Hostinfo: &hostinfo,
|
||||||
},
|
},
|
||||||
Registered: make(chan *types.Node),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
verb := "Reauthenticated"
|
verb := "Reauthenticated"
|
||||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||||
|
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
||||||
|
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -74,9 +74,25 @@ type State struct {
|
|||||||
// NewState creates and initializes a new State instance, setting up the database,
|
// NewState creates and initializes a new State instance, setting up the database,
|
||||||
// IP allocator, DERP map, policy manager, and loading existing users and nodes.
|
// IP allocator, DERP map, policy manager, and loading existing users and nodes.
|
||||||
func NewState(cfg *types.Config) (*State, error) {
|
func NewState(cfg *types.Config) (*State, error) {
|
||||||
|
cacheExpiration := registerCacheExpiration
|
||||||
|
if cfg.Tuning.RegisterCacheExpiration != 0 {
|
||||||
|
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheCleanup := registerCacheCleanup
|
||||||
|
if cfg.Tuning.RegisterCacheCleanup != 0 {
|
||||||
|
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||||
|
}
|
||||||
|
|
||||||
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
||||||
registerCacheExpiration,
|
cacheExpiration,
|
||||||
registerCacheCleanup,
|
cacheCleanup,
|
||||||
|
)
|
||||||
|
|
||||||
|
registrationCache.OnEvicted(
|
||||||
|
func(id types.RegistrationID, rn types.RegisterNode) {
|
||||||
|
rn.SendAndClose(nil)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
db, err := hsdb.NewHeadscaleDatabase(
|
db, err := hsdb.NewHeadscaleDatabase(
|
||||||
@ -1238,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
s.nodeStore.PutNode(*savedNode)
|
s.nodeStore.PutNode(*savedNode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Signal to waiting clients
|
||||||
|
regEntry.SendAndClose(savedNode)
|
||||||
|
|
||||||
// Delete from registration cache
|
// Delete from registration cache
|
||||||
s.registrationCache.Delete(registrationID)
|
s.registrationCache.Delete(registrationID)
|
||||||
|
|
||||||
// Signal to waiting clients
|
|
||||||
select {
|
|
||||||
case regEntry.Registered <- savedNode:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
close(regEntry.Registered)
|
|
||||||
|
|
||||||
// Update policy manager
|
// Update policy manager
|
||||||
nodesChange, err := s.updatePolicyManagerNodes()
|
nodesChange, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
|
|||||||
type RegisterNode struct {
|
type RegisterNode struct {
|
||||||
Node Node
|
Node Node
|
||||||
Registered chan *Node
|
Registered chan *Node
|
||||||
|
closed *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegisterNode(node Node) RegisterNode {
|
||||||
|
return RegisterNode{
|
||||||
|
Node: node,
|
||||||
|
Registered: make(chan *Node),
|
||||||
|
closed: &atomic.Bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rn *RegisterNode) SendAndClose(node *Node) {
|
||||||
|
if rn.closed.Swap(true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case rn.Registered <- node:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
close(rn.Registered)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||||
|
@ -235,6 +235,8 @@ type Tuning struct {
|
|||||||
BatchChangeDelay time.Duration
|
BatchChangeDelay time.Duration
|
||||||
NodeMapSessionBufferedChanSize int
|
NodeMapSessionBufferedChanSize int
|
||||||
BatcherWorkers int
|
BatcherWorkers int
|
||||||
|
RegisterCacheCleanup time.Duration
|
||||||
|
RegisterCacheExpiration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func validatePKCEMethod(method string) error {
|
func validatePKCEMethod(method string) error {
|
||||||
@ -1002,6 +1004,8 @@ func LoadServerConfig() (*Config, error) {
|
|||||||
}
|
}
|
||||||
return DefaultBatcherWorkers()
|
return DefaultBatcherWorkers()
|
||||||
}(),
|
}(),
|
||||||
|
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||||
|
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package integration
|
|||||||
import (
|
import (
|
||||||
"maps"
|
"maps"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -688,6 +689,108 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||||||
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
|
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCFollowUpUrl(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
// Create no nodes and no users
|
||||||
|
scenario, err := NewScenario(
|
||||||
|
ScenarioSpec{
|
||||||
|
OIDCUsers: []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
// smaller cache expiration time to quickly expire AuthURL
|
||||||
|
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s",
|
||||||
|
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||||
|
nil,
|
||||||
|
hsic.WithTestName("oidcauthrelog"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||||
|
hsic.WithEmbeddedDERPServerOnly(),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err := headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Empty(t, listUsers)
|
||||||
|
|
||||||
|
ts, err := scenario.CreateTailscaleNode(
|
||||||
|
"unstable",
|
||||||
|
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
// wait for the registration cache to expire
|
||||||
|
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
|
||||||
|
time.Sleep(2 * time.Minute)
|
||||||
|
|
||||||
|
st, err := ts.Status()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Equal(t, "NeedsLogin", st.BackendState)
|
||||||
|
|
||||||
|
// get new AuthURL from daemon
|
||||||
|
newUrl, err := url.Parse(st.AuthURL)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
|
||||||
|
|
||||||
|
_, err = doLoginURL(ts.Hostname(), newUrl)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 1)
|
||||||
|
|
||||||
|
wantUsers := []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(
|
||||||
|
listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(
|
||||||
|
wantUsers,
|
||||||
|
listUsers,
|
||||||
|
cmpopts.IgnoreUnexported(v1.User{}),
|
||||||
|
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
|
||||||
|
); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listNodes, 1)
|
||||||
|
}
|
||||||
|
|
||||||
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
||||||
// are in the logged-out state (NeedsLogin).
|
// are in the logged-out state (NeedsLogin).
|
||||||
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
||||||
|
Loading…
Reference in New Issue
Block a user