mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	feat: add tampered request test for pkce feature
This commit is contained in:
		
							parent
							
								
									f356d08ec9
								
							
						
					
					
						commit
						360d1afe19
					
				| @ -13,6 +13,7 @@ import ( | ||||
| 	"net/netip" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -34,9 +35,13 @@ const ( | ||||
| 	dockerContextPath      = "../." | ||||
| 	hsicOIDCMockHashLength = 6 | ||||
| 	defaultAccessTTL       = 10 * time.Minute | ||||
| 	nodeStateRunning       = "Running" | ||||
| ) | ||||
| 
 | ||||
| var errStatusCodeNotOK = errors.New("status code not OK") | ||||
| var ( | ||||
| 	errStatusCodeNotOK = errors.New("status code not OK") | ||||
| 	ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario") | ||||
| ) | ||||
| 
 | ||||
| type AuthOIDCScenario struct { | ||||
| 	*Scenario | ||||
| @ -617,12 +622,128 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { | ||||
| 	for _, client := range allClients { | ||||
| 		status, err := client.Status() | ||||
| 		assertNoErr(t, err) | ||||
| 		if status.BackendState != "Running" { | ||||
| 		if status.BackendState != nodeStateRunning { | ||||
| 			t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type tamperVerifierTransport struct { | ||||
| 	base http.RoundTripper | ||||
| } | ||||
| 
 | ||||
| func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	log.Printf("RoundTrip: %s %s", req.Method, req.URL.String()) | ||||
| 
 | ||||
| 	// For POST requests, tamper with form data
 | ||||
| 	if req.Method == http.MethodPost { | ||||
| 		log.Printf("Processing POST request") | ||||
| 		err := req.ParseForm() | ||||
| 		if err != nil { | ||||
| 			log.Printf("Error parsing form: %v", err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if verifier := req.Form.Get("code_challenge"); verifier != "" { | ||||
| 			log.Printf("Found POST verifier: %s", verifier) | ||||
| 			// Tamper with the verifier
 | ||||
| 			req.Form.Set("code_challenge", verifier+"_tampered") | ||||
| 			log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge")) | ||||
| 			// Update request body with modified form
 | ||||
| 			req.Body = io.NopCloser(strings.NewReader(req.Form.Encode())) | ||||
| 			req.ContentLength = int64(len(req.Form.Encode())) | ||||
| 		} else { | ||||
| 			log.Printf("No code_challenge found in POST form data") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// For GET requests, tamper with URL query parameters
 | ||||
| 	if req.Method == http.MethodGet { | ||||
| 		log.Printf("Processing GET request") | ||||
| 		q := req.URL.Query() | ||||
| 		if verifier := q.Get("code_challenge"); verifier != "" { | ||||
| 			log.Printf("Found GET verifier: %s", verifier) | ||||
| 			q.Set("code_challenge", verifier+"_tampered") | ||||
| 			req.URL.RawQuery = q.Encode() | ||||
| 			log.Printf("Modified URL to: %s", req.URL.String()) | ||||
| 		} else { | ||||
| 			log.Printf("No code_challenge found in GET query params") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Forward the request with the tampered verifier
 | ||||
| 	resp, err := t.base.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("RoundTrip error: %v", err) | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	log.Printf("Response status: %s", resp.Status) | ||||
| 
 | ||||
| 	return resp, err | ||||
| } | ||||
| 
 | ||||
| func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 	// Single user with one node for testing PKCE flow
 | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 	} | ||||
| 
 | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_PKCE_ENABLED":       "1", // Enable PKCE
 | ||||
| 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||
| 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 	} | ||||
| 
 | ||||
| 	// Create a transport that modifies the PKCE verifier in transit
 | ||||
| 	baseTransport := &http.Transport{ | ||||
| 		// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | ||||
| 	} | ||||
| 	tamperTransport := &tamperVerifierTransport{ | ||||
| 		base: baseTransport, | ||||
| 	} | ||||
| 
 | ||||
| 	err = scenario.CreateHeadscaleEnvWithHTTPModifier( | ||||
| 		spec, | ||||
| 		func(cli *http.Client) { | ||||
| 			cli.Transport = tamperTransport | ||||
| 		}, | ||||
| 		hsic.WithTestName("oidcauthpkce"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithHostnameAsServerURL(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 	) | ||||
| 	if err == nil { | ||||
| 		t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded") | ||||
| 	} else { | ||||
| 		log.Printf("auth got error: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	opts ...hsic.Option, | ||||
| @ -643,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 			// This is because the MockOIDC server can only serve login
 | ||||
| 			// requests based on a queue it has been given on startup.
 | ||||
| 			// We currently only populates it with one login request per user.
 | ||||
| 			return fmt.Errorf("client count must be 1 for OIDC scenario.") | ||||
| 			return ErrOIDCClientCount | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| @ -665,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier( | ||||
| 	users map[string]int, | ||||
| 	httpModifier func(*http.Client), | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale(opts...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	err = headscale.WaitForRunning() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	for userName, clientCount := range users { | ||||
| 		if clientCount != 1 { | ||||
| 			// OIDC scenario only supports one client per user.
 | ||||
| 			// This is because the MockOIDC server can only serve login
 | ||||
| 			// requests based on a queue it has been given on startup.
 | ||||
| 			// We currently only populates it with one login request per user.
 | ||||
| 			return ErrOIDCClientCount | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { | ||||
| 	port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	if err != nil { | ||||
| @ -774,7 +938,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 					log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) | ||||
| 				} | ||||
| 
 | ||||
| 				loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) | ||||
| 				loginURL.Host = headscale.GetIP() + ":8080" | ||||
| 				loginURL.Scheme = "http" | ||||
| 
 | ||||
| 				if len(headscale.GetCert()) > 0 { | ||||
| @ -782,6 +946,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 				} | ||||
| 
 | ||||
| 				insecureTransport := &http.Transport{ | ||||
| 					// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | ||||
| 				} | ||||
| 
 | ||||
| @ -848,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) runTailscaleUpWithModifier( | ||||
| 	userStr string, | ||||
| 	loginServer string, | ||||
| 	httpClientModifier func(*http.Client), | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Printf("running tailscale up for user %s", userStr) | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for _, client := range user.Clients { | ||||
| 			c := client | ||||
| 			err := func() error { | ||||
| 				status, err := c.Status() | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to get status: %s", c.Hostname(), err) | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				if status.BackendState == nodeStateRunning { | ||||
| 					log.Printf("%s is already running", c.Hostname()) | ||||
| 					return nil | ||||
| 				} | ||||
| 
 | ||||
| 				log.Printf("%s running tailscale up", c.Hostname()) | ||||
| 
 | ||||
| 				loginURL, err := c.LoginWithURL(loginServer) | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				loginURL.Host = headscale.GetIP() + ":8080" | ||||
| 				loginURL.Scheme = "http" | ||||
| 
 | ||||
| 				if len(headscale.GetCert()) > 0 { | ||||
| 					loginURL.Scheme = "https" | ||||
| 				} | ||||
| 
 | ||||
| 				insecureTransport := &http.Transport{ | ||||
| 					// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
 | ||||
| 					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | ||||
| 				} | ||||
| 
 | ||||
| 				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) | ||||
| 
 | ||||
| 				log.Printf("%s logging in with url", c.Hostname()) | ||||
| 				httpClient := &http.Client{Transport: insecureTransport} | ||||
| 
 | ||||
| 				// Allow the test to modify the HTTP client
 | ||||
| 				if httpClientModifier != nil { | ||||
| 					httpClientModifier(httpClient) | ||||
| 				} | ||||
| 
 | ||||
| 				ctx := context.Background() | ||||
| 				req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 				resp, err := httpClient.Do(req) | ||||
| 				if err != nil { | ||||
| 					log.Printf( | ||||
| 						"%s failed to login using url %s: %s", | ||||
| 						c.Hostname(), | ||||
| 						loginURL, | ||||
| 						err, | ||||
| 					) | ||||
| 
 | ||||
| 					return err | ||||
| 				} | ||||
| 				defer resp.Body.Close() | ||||
| 
 | ||||
| 				if resp.StatusCode != http.StatusOK { | ||||
| 					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | ||||
| 					body, _ := io.ReadAll(resp.Body) | ||||
| 					log.Printf("body: %s", body) | ||||
| 
 | ||||
| 					return errStatusCodeNotOK | ||||
| 				} | ||||
| 
 | ||||
| 				return nil | ||||
| 			}() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) Shutdown() { | ||||
| 	err := s.pool.Purge(s.mockOIDC) | ||||
| 	if err != nil { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user