mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge 360d1afe19 into 9313e5b058
				
					
				
			This commit is contained in:
		
						commit
						c0722c6861
					
				
							
								
								
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							| @ -25,6 +25,7 @@ jobs: | ||||
|           - TestOIDCAuthenticationPingAll | ||||
|           - TestOIDCExpireNodesBasedOnTokenExpiry | ||||
|           - TestOIDC024UserCreation | ||||
|           - TestOIDCAuthenticationWithPKCE | ||||
|           - TestAuthWebFlowAuthenticationPingAll | ||||
|           - TestAuthWebFlowLogoutAndRelogin | ||||
|           - TestUserCommand | ||||
|  | ||||
| @ -364,6 +364,18 @@ unix_socket_permission: "0770" | ||||
| #   allowed_users: | ||||
| #     - alice@example.com | ||||
| # | ||||
| #   # Optional: PKCE (Proof Key for Code Exchange) configuration | ||||
| #   # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow | ||||
| #   # by preventing authorization code interception attacks | ||||
| #   # See https://datatracker.ietf.org/doc/html/rfc7636 | ||||
| #   pkce: | ||||
| #     # Enable or disable PKCE support (default: false) | ||||
| #     enabled: false | ||||
| #     # PKCE method to use: | ||||
| #     # - plain: Use plain code verifier | ||||
| #     # - S256: Use SHA256 hashed code verifier (default, recommended) | ||||
| #     method: S256 | ||||
| # | ||||
| #   # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users | ||||
| #   # by taking the username from the legacy user and matching it with the username | ||||
| #   # provided by the OIDC. This is useful when migrating from legacy users to OIDC | ||||
|  | ||||
| @ -45,6 +45,18 @@ oidc: | ||||
|   allowed_users: | ||||
|     - alice@example.com | ||||
| 
 | ||||
|   # Optional: PKCE (Proof Key for Code Exchange) configuration | ||||
|   # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow | ||||
|   # by preventing authorization code interception attacks | ||||
|   # See https://datatracker.ietf.org/doc/html/rfc7636 | ||||
|   pkce: | ||||
|     # Enable or disable PKCE support (default: false) | ||||
|     enabled: false | ||||
|     # PKCE method to use: | ||||
|     # - plain: Use plain code verifier | ||||
|     # - S256: Use SHA256 hashed code verifier (default, recommended) | ||||
|     method: S256 | ||||
| 
 | ||||
|   # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. | ||||
|   # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` | ||||
|   # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following | ||||
|  | ||||
| @ -28,12 +28,14 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	randomByteSize = 16 | ||||
| 	randomByteSize           = 16 | ||||
| 	defaultOAuthOptionsCount = 3 | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") | ||||
| 	errNoOIDCIDToken           = errors.New("could not extract ID Token for OIDC callback") | ||||
| 	errNoOIDCRegistrationInfo  = errors.New("could not get registration info from cache") | ||||
| 	errOIDCAllowedDomains      = errors.New( | ||||
| 		"authenticated principal does not match any allowed domain", | ||||
| 	) | ||||
| @ -47,11 +49,17 @@ var ( | ||||
| 	errOIDCNodeKeyMissing = errors.New("could not get node key from cache") | ||||
| ) | ||||
| 
 | ||||
| // RegistrationInfo contains both machine key and verifier information for OIDC validation.
 | ||||
| type RegistrationInfo struct { | ||||
| 	MachineKey key.MachinePublic | ||||
| 	Verifier   *string | ||||
| } | ||||
| 
 | ||||
| type AuthProviderOIDC struct { | ||||
| 	serverURL         string | ||||
| 	cfg               *types.OIDCConfig | ||||
| 	db                *db.HSDatabase | ||||
| 	registrationCache *zcache.Cache[string, key.MachinePublic] | ||||
| 	registrationCache *zcache.Cache[string, RegistrationInfo] | ||||
| 	notifier          *notifier.Notifier | ||||
| 	ipAlloc           *db.IPAllocator | ||||
| 	polMan            policy.PolicyManager | ||||
| @ -87,7 +95,7 @@ func NewAuthProviderOIDC( | ||||
| 		Scopes: cfg.Scope, | ||||
| 	} | ||||
| 
 | ||||
| 	registrationCache := zcache.New[string, key.MachinePublic]( | ||||
| 	registrationCache := zcache.New[string, RegistrationInfo]( | ||||
| 		registerCacheExpiration, | ||||
| 		registerCacheCleanup, | ||||
| 	) | ||||
| @ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler( | ||||
| 
 | ||||
| 	stateStr := hex.EncodeToString(randomBlob)[:32] | ||||
| 
 | ||||
| 	// place the node key into the state cache, so it can be retrieved later
 | ||||
| 	a.registrationCache.Set( | ||||
| 		stateStr, | ||||
| 		machineKey, | ||||
| 	) | ||||
| 	// Initialize registration info with machine key
 | ||||
| 	registrationInfo := RegistrationInfo{ | ||||
| 		MachineKey: machineKey, | ||||
| 	} | ||||
| 
 | ||||
| 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
 | ||||
| 	extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)) | ||||
| 	extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) | ||||
| 	// Add PKCE verification if enabled
 | ||||
| 	if a.cfg.PKCE.Enabled { | ||||
| 		verifier := oauth2.GenerateVerifier() | ||||
| 		registrationInfo.Verifier = &verifier | ||||
| 
 | ||||
| 		extras = append(extras, oauth2.AccessTypeOffline) | ||||
| 
 | ||||
| 		switch a.cfg.PKCE.Method { | ||||
| 		case types.PKCEMethodS256: | ||||
| 			extras = append(extras, oauth2.S256ChallengeOption(verifier)) | ||||
| 		case types.PKCEMethodPlain: | ||||
| 			// oauth2 does not have a plain challenge option, so we add it manually
 | ||||
| 			extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Add any extra parameters from configuration
 | ||||
| 	for k, v := range a.cfg.ExtraParams { | ||||
| 		extras = append(extras, oauth2.SetAuthURLParam(k, v)) | ||||
| 	} | ||||
| 
 | ||||
| 	// Cache the registration info
 | ||||
| 	a.registrationCache.Set(stateStr, registrationInfo) | ||||
| 
 | ||||
| 	authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...) | ||||
| 	log.Debug().Msgf("Redirecting to %s for authentication", authURL) | ||||
| 
 | ||||
| @ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	idToken, err := a.extractIDToken(req.Context(), code) | ||||
| 	idToken, err := a.extractIDToken(req.Context(), code, state) | ||||
| 	if err != nil { | ||||
| 		http.Error(writer, err.Error(), http.StatusBadRequest) | ||||
| 		return | ||||
| @ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest( | ||||
| func (a *AuthProviderOIDC) extractIDToken( | ||||
| 	ctx context.Context, | ||||
| 	code string, | ||||
| 	state string, | ||||
| ) (*oidc.IDToken, error) { | ||||
| 	oauth2Token, err := a.oauth2Config.Exchange(ctx, code) | ||||
| 	var exchangeOpts []oauth2.AuthCodeOption | ||||
| 
 | ||||
| 	if a.cfg.PKCE.Enabled { | ||||
| 		regInfo, ok := a.registrationCache.Get(state) | ||||
| 		if !ok { | ||||
| 			return nil, errNoOIDCRegistrationInfo | ||||
| 		} | ||||
| 		if regInfo.Verifier != nil { | ||||
| 			exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not exchange code for token: %w", err) | ||||
| 	} | ||||
| @ -394,7 +432,7 @@ func validateOIDCAllowedUsers( | ||||
| // cache. If the machine key is found, it will try retrieve the
 | ||||
| // node information from the database.
 | ||||
| func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) { | ||||
| 	machineKey, ok := a.registrationCache.Get(state) | ||||
| 	regInfo, ok := a.registrationCache.Get(state) | ||||
| 	if !ok { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| @ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k | ||||
| 	// The error is not important, because if it does not
 | ||||
| 	// exist, then this is a new node and we will move
 | ||||
| 	// on to registration.
 | ||||
| 	node, _ := a.db.GetNodeByMachineKey(machineKey) | ||||
| 	node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey) | ||||
| 
 | ||||
| 	return node, &machineKey | ||||
| 	return node, ®Info.MachineKey | ||||
| } | ||||
| 
 | ||||
| // reauthenticateNode updates the node expiry in the database
 | ||||
|  | ||||
| @ -26,11 +26,14 @@ import ( | ||||
| const ( | ||||
| 	defaultOIDCExpiryTime               = 180 * 24 * time.Hour // 180 Days
 | ||||
| 	maxDuration           time.Duration = 1<<63 - 1 | ||||
| 	PKCEMethodPlain       string        = "plain" | ||||
| 	PKCEMethodS256        string        = "S256" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") | ||||
| 	errServerURLSuffix       = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") | ||||
| 	errInvalidPKCEMethod     = errors.New("pkce.method must be either 'plain' or 'S256'") | ||||
| ) | ||||
| 
 | ||||
| type IPAllocationStrategy string | ||||
| @ -162,6 +165,11 @@ type LetsEncryptConfig struct { | ||||
| 	ChallengeType string | ||||
| } | ||||
| 
 | ||||
| type PKCEConfig struct { | ||||
| 	Enabled bool | ||||
| 	Method  string | ||||
| } | ||||
| 
 | ||||
| type OIDCConfig struct { | ||||
| 	OnlyStartIfOIDCIsAvailable bool | ||||
| 	Issuer                     string | ||||
| @ -176,6 +184,7 @@ type OIDCConfig struct { | ||||
| 	Expiry                     time.Duration | ||||
| 	UseExpiryFromToken         bool | ||||
| 	MapLegacyUsers             bool | ||||
| 	PKCE                       PKCEConfig | ||||
| } | ||||
| 
 | ||||
| type DERPConfig struct { | ||||
| @ -226,6 +235,13 @@ type Tuning struct { | ||||
| 	NodeMapSessionBufferedChanSize int | ||||
| } | ||||
| 
 | ||||
| func validatePKCEMethod(method string) error { | ||||
| 	if method != PKCEMethodPlain && method != PKCEMethodS256 { | ||||
| 		return errInvalidPKCEMethod | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // LoadConfig prepares and loads the Headscale configuration into Viper.
 | ||||
| // This means it sets the default values, reads the configuration file and
 | ||||
| // environment variables, and handles deprecated configuration options.
 | ||||
| @ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error { | ||||
| 	viper.SetDefault("oidc.expiry", "180d") | ||||
| 	viper.SetDefault("oidc.use_expiry_from_token", false) | ||||
| 	viper.SetDefault("oidc.map_legacy_users", true) | ||||
| 	viper.SetDefault("oidc.pkce.enabled", false) | ||||
| 	viper.SetDefault("oidc.pkce.method", "S256") | ||||
| 
 | ||||
| 	viper.SetDefault("logtail.enabled", false) | ||||
| 	viper.SetDefault("randomize_client_port", false) | ||||
| @ -340,6 +358,12 @@ func validateServerConfig() error { | ||||
| 	// after #2170 is cleaned up
 | ||||
| 	// depr.fatal("oidc.strip_email_domain")
 | ||||
| 
 | ||||
| 	if viper.GetBool("oidc.enabled") { | ||||
| 		if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	depr.Log() | ||||
| 
 | ||||
| 	for _, removed := range []string{ | ||||
| @ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) { | ||||
| 			// after #2170 is cleaned up
 | ||||
| 			StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), | ||||
| 			MapLegacyUsers:   viper.GetBool("oidc.map_legacy_users"), | ||||
| 			PKCE: PKCEConfig{ | ||||
| 				Enabled: viper.GetBool("oidc.pkce.enabled"), | ||||
| 				Method:  viper.GetString("oidc.pkce.method"), | ||||
| 			}, | ||||
| 		}, | ||||
| 
 | ||||
| 		LogTail:             logTailConfig, | ||||
|  | ||||
| @ -13,6 +13,7 @@ import ( | ||||
| 	"net/netip" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -34,9 +35,13 @@ const ( | ||||
| 	dockerContextPath      = "../." | ||||
| 	hsicOIDCMockHashLength = 6 | ||||
| 	defaultAccessTTL       = 10 * time.Minute | ||||
| 	nodeStateRunning       = "Running" | ||||
| ) | ||||
| 
 | ||||
| var errStatusCodeNotOK = errors.New("status code not OK") | ||||
| var ( | ||||
| 	errStatusCodeNotOK = errors.New("status code not OK") | ||||
| 	ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario") | ||||
| ) | ||||
| 
 | ||||
| type AuthOIDCScenario struct { | ||||
| 	*Scenario | ||||
| @ -534,6 +539,211 @@ func TestOIDC024UserCreation(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOIDCAuthenticationWithPKCE(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 	// Single user with one node for testing PKCE flow
 | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 	} | ||||
| 
 | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_PKCE_ENABLED":       "1", // Enable PKCE
 | ||||
| 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||
| 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 	} | ||||
| 
 | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		hsic.WithTestName("oidcauthpkce"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithHostnameAsServerURL(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 	) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
| 
 | ||||
| 	// Get all clients and verify they can connect
 | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| 	assertNoErrListClients(t, err) | ||||
| 
 | ||||
| 	allIps, err := scenario.ListTailscaleClientsIPs() | ||||
| 	assertNoErrListClientIPs(t, err) | ||||
| 
 | ||||
| 	err = scenario.WaitForTailscaleSync() | ||||
| 	assertNoErrSync(t, err) | ||||
| 
 | ||||
| 	// Verify PKCE was used in authentication
 | ||||
| 	headscale, err := scenario.Headscale() | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	var listUsers []v1.User | ||||
| 	err = executeAndUnmarshal(headscale, | ||||
| 		[]string{ | ||||
| 			"headscale", | ||||
| 			"users", | ||||
| 			"list", | ||||
| 			"--output", | ||||
| 			"json", | ||||
| 		}, | ||||
| 		&listUsers, | ||||
| 	) | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { | ||||
| 		return x.String() | ||||
| 	}) | ||||
| 
 | ||||
| 	success := pingAllHelper(t, allClients, allAddrs) | ||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||
| 
 | ||||
| 	// Verify all clients are connected and authenticated
 | ||||
| 	for _, client := range allClients { | ||||
| 		status, err := client.Status() | ||||
| 		assertNoErr(t, err) | ||||
| 		if status.BackendState != nodeStateRunning { | ||||
| 			t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type tamperVerifierTransport struct { | ||||
| 	base http.RoundTripper | ||||
| } | ||||
| 
 | ||||
| func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	log.Printf("RoundTrip: %s %s", req.Method, req.URL.String()) | ||||
| 
 | ||||
| 	// For POST requests, tamper with form data
 | ||||
| 	if req.Method == http.MethodPost { | ||||
| 		log.Printf("Processing POST request") | ||||
| 		err := req.ParseForm() | ||||
| 		if err != nil { | ||||
| 			log.Printf("Error parsing form: %v", err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if verifier := req.Form.Get("code_challenge"); verifier != "" { | ||||
| 			log.Printf("Found POST verifier: %s", verifier) | ||||
| 			// Tamper with the verifier
 | ||||
| 			req.Form.Set("code_challenge", verifier+"_tampered") | ||||
| 			log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge")) | ||||
| 			// Update request body with modified form
 | ||||
| 			req.Body = io.NopCloser(strings.NewReader(req.Form.Encode())) | ||||
| 			req.ContentLength = int64(len(req.Form.Encode())) | ||||
| 		} else { | ||||
| 			log.Printf("No code_challenge found in POST form data") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// For GET requests, tamper with URL query parameters
 | ||||
| 	if req.Method == http.MethodGet { | ||||
| 		log.Printf("Processing GET request") | ||||
| 		q := req.URL.Query() | ||||
| 		if verifier := q.Get("code_challenge"); verifier != "" { | ||||
| 			log.Printf("Found GET verifier: %s", verifier) | ||||
| 			q.Set("code_challenge", verifier+"_tampered") | ||||
| 			req.URL.RawQuery = q.Encode() | ||||
| 			log.Printf("Modified URL to: %s", req.URL.String()) | ||||
| 		} else { | ||||
| 			log.Printf("No code_challenge found in GET query params") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Forward the request with the tampered verifier
 | ||||
| 	resp, err := t.base.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("RoundTrip error: %v", err) | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	log.Printf("Response status: %s", resp.Status) | ||||
| 
 | ||||
| 	return resp, err | ||||
| } | ||||
| 
 | ||||
| func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 	// Single user with one node for testing PKCE flow
 | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 	} | ||||
| 
 | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_PKCE_ENABLED":       "1", // Enable PKCE
 | ||||
| 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||
| 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 	} | ||||
| 
 | ||||
| 	// Create a transport that modifies the PKCE verifier in transit
 | ||||
| 	baseTransport := &http.Transport{ | ||||
| 		// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | ||||
| 	} | ||||
| 	tamperTransport := &tamperVerifierTransport{ | ||||
| 		base: baseTransport, | ||||
| 	} | ||||
| 
 | ||||
| 	err = scenario.CreateHeadscaleEnvWithHTTPModifier( | ||||
| 		spec, | ||||
| 		func(cli *http.Client) { | ||||
| 			cli.Transport = tamperTransport | ||||
| 		}, | ||||
| 		hsic.WithTestName("oidcauthpkce"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithHostnameAsServerURL(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 	) | ||||
| 	if err == nil { | ||||
| 		t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded") | ||||
| 	} else { | ||||
| 		log.Printf("auth got error: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	opts ...hsic.Option, | ||||
| @ -554,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 			// This is because the MockOIDC server can only serve login
 | ||||
| 			// requests based on a queue it has been given on startup.
 | ||||
| 			// We currently only populates it with one login request per user.
 | ||||
| 			return fmt.Errorf("client count must be 1 for OIDC scenario.") | ||||
| 			return ErrOIDCClientCount | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| @ -576,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier( | ||||
| 	users map[string]int, | ||||
| 	httpModifier func(*http.Client), | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale(opts...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	err = headscale.WaitForRunning() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	for userName, clientCount := range users { | ||||
| 		if clientCount != 1 { | ||||
| 			// OIDC scenario only supports one client per user.
 | ||||
| 			// This is because the MockOIDC server can only serve login
 | ||||
| 			// requests based on a queue it has been given on startup.
 | ||||
| 			// We currently only populates it with one login request per user.
 | ||||
| 			return ErrOIDCClientCount | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { | ||||
| 	port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	if err != nil { | ||||
| @ -685,7 +938,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 					log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) | ||||
| 				} | ||||
| 
 | ||||
| 				loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) | ||||
| 				loginURL.Host = headscale.GetIP() + ":8080" | ||||
| 				loginURL.Scheme = "http" | ||||
| 
 | ||||
| 				if len(headscale.GetCert()) > 0 { | ||||
| @ -693,6 +946,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 				} | ||||
| 
 | ||||
| 				insecureTransport := &http.Transport{ | ||||
| 					// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | ||||
| 				} | ||||
| 
 | ||||
| @ -759,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) runTailscaleUpWithModifier( | ||||
| 	userStr string, | ||||
| 	loginServer string, | ||||
| 	httpClientModifier func(*http.Client), | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Printf("running tailscale up for user %s", userStr) | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for _, client := range user.Clients { | ||||
| 			c := client | ||||
| 			err := func() error { | ||||
| 				status, err := c.Status() | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to get status: %s", c.Hostname(), err) | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				if status.BackendState == nodeStateRunning { | ||||
| 					log.Printf("%s is already running", c.Hostname()) | ||||
| 					return nil | ||||
| 				} | ||||
| 
 | ||||
| 				log.Printf("%s running tailscale up", c.Hostname()) | ||||
| 
 | ||||
| 				loginURL, err := c.LoginWithURL(loginServer) | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				loginURL.Host = headscale.GetIP() + ":8080" | ||||
| 				loginURL.Scheme = "http" | ||||
| 
 | ||||
| 				if len(headscale.GetCert()) > 0 { | ||||
| 					loginURL.Scheme = "https" | ||||
| 				} | ||||
| 
 | ||||
| 				insecureTransport := &http.Transport{ | ||||
| 					// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | ||||
| 				} | ||||
| 
 | ||||
| 				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) | ||||
| 
 | ||||
| 				log.Printf("%s logging in with url", c.Hostname()) | ||||
| 				httpClient := &http.Client{Transport: insecureTransport} | ||||
| 
 | ||||
| 				// Allow the test to modify the HTTP client
 | ||||
| 				if httpClientModifier != nil { | ||||
| 					httpClientModifier(httpClient) | ||||
| 				} | ||||
| 
 | ||||
| 				ctx := context.Background() | ||||
| 				req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 				resp, err := httpClient.Do(req) | ||||
| 				if err != nil { | ||||
| 					log.Printf( | ||||
| 						"%s failed to login using url %s: %s", | ||||
| 						c.Hostname(), | ||||
| 						loginURL, | ||||
| 						err, | ||||
| 					) | ||||
| 
 | ||||
| 					return err | ||||
| 				} | ||||
| 				defer resp.Body.Close() | ||||
| 
 | ||||
| 				if resp.StatusCode != http.StatusOK { | ||||
| 					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | ||||
| 					body, _ := io.ReadAll(resp.Body) | ||||
| 					log.Printf("body: %s", body) | ||||
| 
 | ||||
| 					return errStatusCodeNotOK | ||||
| 				} | ||||
| 
 | ||||
| 				return nil | ||||
| 			}() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) Shutdown() { | ||||
| 	err := s.pool.Purge(s.mockOIDC) | ||||
| 	if err != nil { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user