mirror of
https://github.com/juanfont/headscale.git
synced 2025-10-05 11:19:03 +02:00
Merge 360d1afe19
into 9313e5b058
This commit is contained in:
commit
c0722c6861
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
@ -25,6 +25,7 @@ jobs:
|
||||
- TestOIDCAuthenticationPingAll
|
||||
- TestOIDCExpireNodesBasedOnTokenExpiry
|
||||
- TestOIDC024UserCreation
|
||||
- TestOIDCAuthenticationWithPKCE
|
||||
- TestAuthWebFlowAuthenticationPingAll
|
||||
- TestAuthWebFlowLogoutAndRelogin
|
||||
- TestUserCommand
|
||||
|
@ -364,6 +364,18 @@ unix_socket_permission: "0770"
|
||||
# allowed_users:
|
||||
# - alice@example.com
|
||||
#
|
||||
# # Optional: PKCE (Proof Key for Code Exchange) configuration
|
||||
# # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
|
||||
# # by preventing authorization code interception attacks
|
||||
# # See https://datatracker.ietf.org/doc/html/rfc7636
|
||||
# pkce:
|
||||
# # Enable or disable PKCE support (default: false)
|
||||
# enabled: false
|
||||
# # PKCE method to use:
|
||||
# # - plain: Use plain code verifier
|
||||
# # - S256: Use SHA256 hashed code verifier (default, recommended)
|
||||
# method: S256
|
||||
#
|
||||
# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
|
||||
# # by taking the username from the legacy user and matching it with the username
|
||||
# # provided by the OIDC. This is useful when migrating from legacy users to OIDC
|
||||
|
@ -45,6 +45,18 @@ oidc:
|
||||
allowed_users:
|
||||
- alice@example.com
|
||||
|
||||
# Optional: PKCE (Proof Key for Code Exchange) configuration
|
||||
# PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
|
||||
# by preventing authorization code interception attacks
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7636
|
||||
pkce:
|
||||
# Enable or disable PKCE support (default: false)
|
||||
enabled: false
|
||||
# PKCE method to use:
|
||||
# - plain: Use plain code verifier
|
||||
# - S256: Use SHA256 hashed code verifier (default, recommended)
|
||||
method: S256
|
||||
|
||||
# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
|
||||
# This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
|
||||
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
|
||||
|
@ -28,12 +28,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
randomByteSize = 16
|
||||
randomByteSize = 16
|
||||
defaultOAuthOptionsCount = 3
|
||||
)
|
||||
|
||||
var (
|
||||
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
||||
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
|
||||
errOIDCAllowedDomains = errors.New(
|
||||
"authenticated principal does not match any allowed domain",
|
||||
)
|
||||
@ -47,11 +49,17 @@ var (
|
||||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||
)
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
MachineKey key.MachinePublic
|
||||
Verifier *string
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
db *db.HSDatabase
|
||||
registrationCache *zcache.Cache[string, key.MachinePublic]
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
notifier *notifier.Notifier
|
||||
ipAlloc *db.IPAllocator
|
||||
polMan policy.PolicyManager
|
||||
@ -87,7 +95,7 @@ func NewAuthProviderOIDC(
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[string, key.MachinePublic](
|
||||
registrationCache := zcache.New[string, RegistrationInfo](
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
)
|
||||
@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
a.registrationCache.Set(
|
||||
stateStr,
|
||||
machineKey,
|
||||
)
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
MachineKey: machineKey,
|
||||
}
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||
// Add PKCE verification if enabled
|
||||
if a.cfg.PKCE.Enabled {
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
registrationInfo.Verifier = &verifier
|
||||
|
||||
extras = append(extras, oauth2.AccessTypeOffline)
|
||||
|
||||
switch a.cfg.PKCE.Method {
|
||||
case types.PKCEMethodS256:
|
||||
extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
||||
case types.PKCEMethodPlain:
|
||||
// oauth2 does not have a plain challenge option, so we add it manually
|
||||
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
|
||||
}
|
||||
}
|
||||
|
||||
// Add any extra parameters from configuration
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(stateStr, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := a.extractIDToken(req.Context(), code)
|
||||
idToken, err := a.extractIDToken(req.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest(
|
||||
func (a *AuthProviderOIDC) extractIDToken(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
state string,
|
||||
) (*oidc.IDToken, error) {
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
|
||||
var exchangeOpts []oauth2.AuthCodeOption
|
||||
|
||||
if a.cfg.PKCE.Enabled {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, errNoOIDCRegistrationInfo
|
||||
}
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
}
|
||||
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not exchange code for token: %w", err)
|
||||
}
|
||||
@ -394,7 +432,7 @@ func validateOIDCAllowedUsers(
|
||||
// cache. If the machine key is found, it will try retrieve the
|
||||
// node information from the database.
|
||||
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
|
||||
machineKey, ok := a.registrationCache.Get(state)
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := a.db.GetNodeByMachineKey(machineKey)
|
||||
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
|
||||
|
||||
return node, &machineKey
|
||||
return node, ®Info.MachineKey
|
||||
}
|
||||
|
||||
// reauthenticateNode updates the node expiry in the database
|
||||
|
@ -26,11 +26,14 @@ import (
|
||||
const (
|
||||
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
|
||||
maxDuration time.Duration = 1<<63 - 1
|
||||
PKCEMethodPlain string = "plain"
|
||||
PKCEMethodS256 string = "S256"
|
||||
)
|
||||
|
||||
var (
|
||||
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
||||
)
|
||||
|
||||
type IPAllocationStrategy string
|
||||
@ -162,6 +165,11 @@ type LetsEncryptConfig struct {
|
||||
ChallengeType string
|
||||
}
|
||||
|
||||
type PKCEConfig struct {
|
||||
Enabled bool
|
||||
Method string
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
OnlyStartIfOIDCIsAvailable bool
|
||||
Issuer string
|
||||
@ -176,6 +184,7 @@ type OIDCConfig struct {
|
||||
Expiry time.Duration
|
||||
UseExpiryFromToken bool
|
||||
MapLegacyUsers bool
|
||||
PKCE PKCEConfig
|
||||
}
|
||||
|
||||
type DERPConfig struct {
|
||||
@ -226,6 +235,13 @@ type Tuning struct {
|
||||
NodeMapSessionBufferedChanSize int
|
||||
}
|
||||
|
||||
func validatePKCEMethod(method string) error {
|
||||
if method != PKCEMethodPlain && method != PKCEMethodS256 {
|
||||
return errInvalidPKCEMethod
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig prepares and loads the Headscale configuration into Viper.
|
||||
// This means it sets the default values, reads the configuration file and
|
||||
// environment variables, and handles deprecated configuration options.
|
||||
@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error {
|
||||
viper.SetDefault("oidc.expiry", "180d")
|
||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||
viper.SetDefault("oidc.map_legacy_users", true)
|
||||
viper.SetDefault("oidc.pkce.enabled", false)
|
||||
viper.SetDefault("oidc.pkce.method", "S256")
|
||||
|
||||
viper.SetDefault("logtail.enabled", false)
|
||||
viper.SetDefault("randomize_client_port", false)
|
||||
@ -340,6 +358,12 @@ func validateServerConfig() error {
|
||||
// after #2170 is cleaned up
|
||||
// depr.fatal("oidc.strip_email_domain")
|
||||
|
||||
if viper.GetBool("oidc.enabled") {
|
||||
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
depr.Log()
|
||||
|
||||
for _, removed := range []string{
|
||||
@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) {
|
||||
// after #2170 is cleaned up
|
||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
|
||||
PKCE: PKCEConfig{
|
||||
Enabled: viper.GetBool("oidc.pkce.enabled"),
|
||||
Method: viper.GetString("oidc.pkce.method"),
|
||||
},
|
||||
},
|
||||
|
||||
LogTail: logTailConfig,
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -34,9 +35,13 @@ const (
|
||||
dockerContextPath = "../."
|
||||
hsicOIDCMockHashLength = 6
|
||||
defaultAccessTTL = 10 * time.Minute
|
||||
nodeStateRunning = "Running"
|
||||
)
|
||||
|
||||
var errStatusCodeNotOK = errors.New("status code not OK")
|
||||
var (
|
||||
errStatusCodeNotOK = errors.New("status code not OK")
|
||||
ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario")
|
||||
)
|
||||
|
||||
type AuthOIDCScenario struct {
|
||||
*Scenario
|
||||
@ -534,6 +539,211 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
|
||||
scenario := AuthOIDCScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// Single user with one node for testing PKCE flow
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
}
|
||||
|
||||
mockusers := []mockoidc.MockUser{
|
||||
oidcMockUser("user1", true),
|
||||
}
|
||||
|
||||
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||
defer scenario.mockOIDC.Close()
|
||||
|
||||
oidcMap := map[string]string{
|
||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
|
||||
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
spec,
|
||||
hsic.WithTestName("oidcauthpkce"),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithHostnameAsServerURL(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
// Get all clients and verify they can connect
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
// Verify PKCE was used in authentication
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
|
||||
var listUsers []v1.User
|
||||
err = executeAndUnmarshal(headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"users",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&listUsers,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
// Verify all clients are connected and authenticated
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
if status.BackendState != nodeStateRunning {
|
||||
t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type tamperVerifierTransport struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
log.Printf("RoundTrip: %s %s", req.Method, req.URL.String())
|
||||
|
||||
// For POST requests, tamper with form data
|
||||
if req.Method == http.MethodPost {
|
||||
log.Printf("Processing POST request")
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
log.Printf("Error parsing form: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if verifier := req.Form.Get("code_challenge"); verifier != "" {
|
||||
log.Printf("Found POST verifier: %s", verifier)
|
||||
// Tamper with the verifier
|
||||
req.Form.Set("code_challenge", verifier+"_tampered")
|
||||
log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge"))
|
||||
// Update request body with modified form
|
||||
req.Body = io.NopCloser(strings.NewReader(req.Form.Encode()))
|
||||
req.ContentLength = int64(len(req.Form.Encode()))
|
||||
} else {
|
||||
log.Printf("No code_challenge found in POST form data")
|
||||
}
|
||||
}
|
||||
|
||||
// For GET requests, tamper with URL query parameters
|
||||
if req.Method == http.MethodGet {
|
||||
log.Printf("Processing GET request")
|
||||
q := req.URL.Query()
|
||||
if verifier := q.Get("code_challenge"); verifier != "" {
|
||||
log.Printf("Found GET verifier: %s", verifier)
|
||||
q.Set("code_challenge", verifier+"_tampered")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
log.Printf("Modified URL to: %s", req.URL.String())
|
||||
} else {
|
||||
log.Printf("No code_challenge found in GET query params")
|
||||
}
|
||||
}
|
||||
|
||||
// Forward the request with the tampered verifier
|
||||
resp, err := t.base.RoundTrip(req)
|
||||
if err != nil {
|
||||
log.Printf("RoundTrip error: %v", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("Response status: %s", resp.Status)
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
|
||||
scenario := AuthOIDCScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// Single user with one node for testing PKCE flow
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
}
|
||||
|
||||
mockusers := []mockoidc.MockUser{
|
||||
oidcMockUser("user1", true),
|
||||
}
|
||||
|
||||
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||
defer scenario.mockOIDC.Close()
|
||||
|
||||
oidcMap := map[string]string{
|
||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
|
||||
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||
}
|
||||
|
||||
// Create a transport that modifies the PKCE verifier in transit
|
||||
baseTransport := &http.Transport{
|
||||
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
tamperTransport := &tamperVerifierTransport{
|
||||
base: baseTransport,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnvWithHTTPModifier(
|
||||
spec,
|
||||
func(cli *http.Client) {
|
||||
cli.Transport = tamperTransport
|
||||
},
|
||||
hsic.WithTestName("oidcauthpkce"),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithHostnameAsServerURL(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded")
|
||||
} else {
|
||||
log.Printf("auth got error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||
users map[string]int,
|
||||
opts ...hsic.Option,
|
||||
@ -554,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||
// This is because the MockOIDC server can only serve login
|
||||
// requests based on a queue it has been given on startup.
|
||||
// We currently only populates it with one login request per user.
|
||||
return fmt.Errorf("client count must be 1 for OIDC scenario.")
|
||||
return ErrOIDCClientCount
|
||||
}
|
||||
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||
err = s.CreateUser(userName)
|
||||
@ -576,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier(
|
||||
users map[string]int,
|
||||
httpModifier func(*http.Client),
|
||||
opts ...hsic.Option,
|
||||
) error {
|
||||
headscale, err := s.Headscale(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = headscale.WaitForRunning()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for userName, clientCount := range users {
|
||||
if clientCount != 1 {
|
||||
// OIDC scenario only supports one client per user.
|
||||
// This is because the MockOIDC server can only serve login
|
||||
// requests based on a queue it has been given on startup.
|
||||
// We currently only populates it with one login request per user.
|
||||
return ErrOIDCClientCount
|
||||
}
|
||||
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||
err = s.CreateUser(userName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
|
||||
port, err := dockertestutil.RandomFreeHostPort()
|
||||
if err != nil {
|
||||
@ -685,7 +938,7 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
||||
}
|
||||
|
||||
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
|
||||
loginURL.Host = headscale.GetIP() + ":8080"
|
||||
loginURL.Scheme = "http"
|
||||
|
||||
if len(headscale.GetCert()) > 0 {
|
||||
@ -693,6 +946,7 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
}
|
||||
|
||||
insecureTransport := &http.Transport{
|
||||
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||
}
|
||||
|
||||
@ -759,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) runTailscaleUpWithModifier(
|
||||
userStr string,
|
||||
loginServer string,
|
||||
httpClientModifier func(*http.Client),
|
||||
) error {
|
||||
headscale, err := s.Headscale()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("running tailscale up for user %s", userStr)
|
||||
if user, ok := s.users[userStr]; ok {
|
||||
for _, client := range user.Clients {
|
||||
c := client
|
||||
err := func() error {
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
log.Printf("%s failed to get status: %s", c.Hostname(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
if status.BackendState == nodeStateRunning {
|
||||
log.Printf("%s is already running", c.Hostname())
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("%s running tailscale up", c.Hostname())
|
||||
|
||||
loginURL, err := c.LoginWithURL(loginServer)
|
||||
if err != nil {
|
||||
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
loginURL.Host = headscale.GetIP() + ":8080"
|
||||
loginURL.Scheme = "http"
|
||||
|
||||
if len(headscale.GetCert()) > 0 {
|
||||
loginURL.Scheme = "https"
|
||||
}
|
||||
|
||||
insecureTransport := &http.Transport{
|
||||
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||
}
|
||||
|
||||
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
||||
|
||||
log.Printf("%s logging in with url", c.Hostname())
|
||||
httpClient := &http.Client{Transport: insecureTransport}
|
||||
|
||||
// Allow the test to modify the HTTP client
|
||||
if httpClientModifier != nil {
|
||||
httpClientModifier(httpClient)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf(
|
||||
"%s failed to login using url %s: %s",
|
||||
c.Hostname(),
|
||||
loginURL,
|
||||
err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Printf("body: %s", body)
|
||||
|
||||
return errStatusCodeNotOK
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) Shutdown() {
|
||||
err := s.pool.Purge(s.mockOIDC)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user