mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	fix oidc test, add tests for migration
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									8059d475a4
								
							
						
					
					
						commit
						69b9abaa6c
					
				
							
								
								
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							| @ -21,6 +21,7 @@ jobs: | ||||
|           - TestPolicyUpdateWhileRunningWithCLIInDatabase | ||||
|           - TestOIDCAuthenticationPingAll | ||||
|           - TestOIDCExpireNodesBasedOnTokenExpiry | ||||
|           - TestOIDC024UserCreation | ||||
|           - TestAuthWebFlowAuthenticationPingAll | ||||
|           - TestAuthWebFlowLogoutAndRelogin | ||||
|           - TestUserCommand | ||||
|  | ||||
| @ -1,8 +1,10 @@ | ||||
| package cli | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| @ -64,6 +66,19 @@ func mockOIDC() error { | ||||
| 		accessTTL = newTTL | ||||
| 	} | ||||
| 
 | ||||
| 	userStr := os.Getenv("MOCKOIDC_USERS") | ||||
| 	if userStr == "" { | ||||
| 		return fmt.Errorf("MOCKOIDC_USERS not defined") | ||||
| 	} | ||||
| 
 | ||||
| 	var users []mockoidc.MockUser | ||||
| 	err := json.Unmarshal([]byte(userStr), &users) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("unmarshalling users: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	log.Info().Interface("users", users).Msg("loading users from JSON") | ||||
| 
 | ||||
| 	log.Info().Msgf("Access token TTL: %s", accessTTL) | ||||
| 
 | ||||
| 	port, err := strconv.Atoi(portStr) | ||||
| @ -71,7 +86,7 @@ func mockOIDC() error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	mock, err := getMockOIDC(clientID, clientSecret) | ||||
| 	mock, err := getMockOIDC(clientID, clientSecret, users) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -93,12 +108,18 @@ func mockOIDC() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { | ||||
| func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { | ||||
| 	keypair, err := mockoidc.NewKeypair(nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	userQueue := mockoidc.UserQueue{} | ||||
| 
 | ||||
| 	for _, user := range users { | ||||
| 		userQueue.Push(&user) | ||||
| 	} | ||||
| 
 | ||||
| 	mock := mockoidc.MockOIDC{ | ||||
| 		ClientID:                      clientID, | ||||
| 		ClientSecret:                  clientSecret, | ||||
| @ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro | ||||
| 		CodeChallengeMethodsSupported: []string{"plain", "S256"}, | ||||
| 		Keypair:                       keypair, | ||||
| 		SessionStore:                  mockoidc.NewSessionStore(), | ||||
| 		UserQueue:                     &mockoidc.UserQueue{}, | ||||
| 		UserQueue:                     &userQueue, | ||||
| 		ErrorQueue:                    &mockoidc.ErrorQueue{}, | ||||
| 	} | ||||
| 
 | ||||
| 	mock.AddMiddleware(func(h http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			log.Info().Msgf("Request: %+v", r) | ||||
| 			h.ServeHTTP(w, r) | ||||
| 			if r.Response != nil { | ||||
| 				log.Info().Msgf("Response: %+v", r.Response) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	return &mock, nil | ||||
| } | ||||
|  | ||||
| @ -436,7 +436,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( | ||||
| ) (*types.User, error) { | ||||
| 	var user *types.User | ||||
| 	var err error | ||||
| 	user, err = a.db.GetUserByOIDCIdentifier(claims.Sub) | ||||
| 	user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier()) | ||||
| 	if err != nil && !errors.Is(err, db.ErrUserNotFound) { | ||||
| 		return nil, fmt.Errorf("creating or updating user: %w", err) | ||||
| 	} | ||||
| @ -448,10 +448,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( | ||||
| 	// TODO(kradalby): Remove when strip_email_domain and migration is removed
 | ||||
| 	// after #2170 is cleaned up.
 | ||||
| 	if a.cfg.MapLegacyUsers && user == nil { | ||||
| 		log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username") | ||||
| 		if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { | ||||
| 			log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username") | ||||
| 			user, err = a.db.GetUserByName(oldUsername) | ||||
| 			if err != nil && !errors.Is(err, db.ErrUserNotFound) { | ||||
| 				return nil, fmt.Errorf("creating or updating user: %w", err) | ||||
| 				return nil, fmt.Errorf("getting user: %w", err) | ||||
| 			} | ||||
| 
 | ||||
| 			// If the user exists, but it already has a provider identifier (OIDC sub), create a new user.
 | ||||
| @ -525,6 +527,9 @@ func getUserName( | ||||
| 	claims *types.OIDCClaims, | ||||
| 	stripEmaildomain bool, | ||||
| ) (string, error) { | ||||
| 	if !claims.EmailVerified { | ||||
| 		return "", fmt.Errorf("email not verified") | ||||
| 	} | ||||
| 	userName, err := util.NormalizeToFQDNRules( | ||||
| 		claims.Email, | ||||
| 		stripEmaildomain, | ||||
|  | ||||
| @ -905,7 +905,10 @@ func LoadServerConfig() (*Config, error) { | ||||
| 				} | ||||
| 			}(), | ||||
| 			UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), | ||||
| 			MapLegacyUsers:     viper.GetBool("oidc.map_legacy_users"), | ||||
| 			// TODO(kradalby): Remove when strip_email_domain is removed
 | ||||
| 			// after #2170 is cleaned up
 | ||||
| 			StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), | ||||
| 			MapLegacyUsers:   viper.GetBool("oidc.map_legacy_users"), | ||||
| 		}, | ||||
| 
 | ||||
| 		LogTail:             logTailConfig, | ||||
|  | ||||
| @ -3,7 +3,6 @@ package types | ||||
| import ( | ||||
| 	"cmp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| @ -39,7 +38,7 @@ type User struct { | ||||
| 	// Unique identifier of the user from OIDC,
 | ||||
| 	// comes from `sub` claim in the OIDC token
 | ||||
| 	// and is used to lookup the user.
 | ||||
| 	ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` | ||||
| 	ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` | ||||
| 
 | ||||
| 	// Provider is the origin of the user account,
 | ||||
| 	// same as RegistrationMethod, without authkey.
 | ||||
| @ -58,9 +57,10 @@ type User struct { | ||||
| // If the username does not contain an '@' it will be added to the end.
 | ||||
| func (u *User) Username() string { | ||||
| 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) | ||||
| 	if !strings.Contains(username, "@") { | ||||
| 		username = username + "@" | ||||
| 	} | ||||
| 	// TODO(kradalby): Wire up all of this for the future
 | ||||
| 	// if !strings.Contains(username, "@") {
 | ||||
| 	// 	username = username + "@"
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	return username | ||||
| } | ||||
| @ -138,10 +138,14 @@ type OIDCClaims struct { | ||||
| 	Username          string   `json:"preferred_username,omitempty"` | ||||
| } | ||||
| 
 | ||||
| func (c *OIDCClaims) Identifier() string { | ||||
| 	return c.Iss + "/" + c.Sub | ||||
| } | ||||
| 
 | ||||
| // FromClaim overrides a User from OIDC claims.
 | ||||
| // All fields will be updated, except for the ID.
 | ||||
| func (u *User) FromClaim(claims *OIDCClaims) { | ||||
| 	u.ProviderIdentifier = claims.Iss + "/" + claims.Sub | ||||
| 	u.ProviderIdentifier = claims.Identifier() | ||||
| 	u.DisplayName = claims.Name | ||||
| 	if claims.EmailVerified { | ||||
| 		u.Email = claims.Email | ||||
|  | ||||
| @ -3,6 +3,7 @@ package integration | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| @ -10,14 +11,19 @@ import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/google/go-cmp/cmp/cmpopts" | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/oauth2-proxy/mockoidc" | ||||
| 	"github.com/ory/dockertest/v3" | ||||
| 	"github.com/ory/dockertest/v3/docker" | ||||
| 	"github.com/samber/lo" | ||||
| @ -50,18 +56,32 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 	// Logins to MockOIDC is served by a queue with a strict order,
 | ||||
| 	// if we use more than one node per user, the order of the logins
 | ||||
| 	// will not be deterministic and the test will fail.
 | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 		oidcMockUser("user2", false), | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 		// TODO(kradalby): Remove when strip_email_domain is removed
 | ||||
| 		// after #2170 is cleaned up
 | ||||
| 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||
| 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 	} | ||||
| 
 | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| @ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | ||||
| 
 | ||||
| 	success := pingAllHelper(t, allClients, allAddrs) | ||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||
| 
 | ||||
| 	headscale, err := scenario.Headscale() | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	var listUsers []v1.User | ||||
| 	err = executeAndUnmarshal(headscale, | ||||
| 		[]string{ | ||||
| 			"headscale", | ||||
| 			"users", | ||||
| 			"list", | ||||
| 			"--output", | ||||
| 			"json", | ||||
| 		}, | ||||
| 		&listUsers, | ||||
| 	) | ||||
| 	assertNoErr(t, err) | ||||
| 
 | ||||
| 	want := []v1.User{ | ||||
| 		{ | ||||
| 			Id:   "1", | ||||
| 			Name: "user1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "2", | ||||
| 			Name:       "user1", | ||||
| 			Email:      "user1@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:   "3", | ||||
| 			Name: "user2", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "4", | ||||
| 			Name:       "user2", | ||||
| 			Email:      "", // Unverified
 | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user2", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Slice(listUsers, func(i, j int) bool { | ||||
| 		return listUsers[i].Id < listUsers[j].Id | ||||
| 	}) | ||||
| 
 | ||||
| 	if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||
| 		t.Fatalf("unexpected users: %s", diff) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // This test is really flaky.
 | ||||
| @ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 3, | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
| 
 | ||||
| 	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) | ||||
| 	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 		oidcMockUser("user2", false), | ||||
| 	}) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":                oidcConfig.Issuer, | ||||
| @ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { | ||||
| 	assertTailscaleNodesLogout(t, allClients) | ||||
| } | ||||
| 
 | ||||
| // TODO(kradalby):
 | ||||
| // - Test that creates a new user when one exists when migration is turned off
 | ||||
| // - Test that takes over a user when one exists when migration is turned on
 | ||||
| //   - But email is not verified
 | ||||
| //   - stripped email domain on/off
 | ||||
| func TestOIDC024UserCreation(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		config        map[string]string | ||||
| 		emailVerified bool | ||||
| 		cliUsers      []string | ||||
| 		oidcUsers     []string | ||||
| 		want          func(iss string) []v1.User | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "no-migration-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", | ||||
| 			}, | ||||
| 			emailVerified: true, | ||||
| 			cliUsers:      []string{"user1", "user2"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					{ | ||||
| 						Id:   "1", | ||||
| 						Name: "user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user1", | ||||
| 						Email:      "user1@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:   "3", | ||||
| 						Name: "user2", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "4", | ||||
| 						Name:       "user2", | ||||
| 						Email:      "user2@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "no-migration-not-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", | ||||
| 			}, | ||||
| 			emailVerified: false, | ||||
| 			cliUsers:      []string{"user1", "user2"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					{ | ||||
| 						Id:   "1", | ||||
| 						Name: "user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user1", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:   "3", | ||||
| 						Name: "user2", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "4", | ||||
| 						Name:       "user2", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "migration-strip-domains-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||
| 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", | ||||
| 			}, | ||||
| 			emailVerified: true, | ||||
| 			cliUsers:      []string{"user1", "user2"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					{ | ||||
| 						Id:         "1", | ||||
| 						Name:       "user1", | ||||
| 						Email:      "user1@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user2", | ||||
| 						Email:      "user2@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "migration-strip-domains-not-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||
| 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", | ||||
| 			}, | ||||
| 			emailVerified: false, | ||||
| 			cliUsers:      []string{"user1", "user2"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					{ | ||||
| 						Id:   "1", | ||||
| 						Name: "user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user1", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:   "3", | ||||
| 						Name: "user2", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "4", | ||||
| 						Name:       "user2", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "migration-no-strip-domains-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||
| 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 			}, | ||||
| 			emailVerified: true, | ||||
| 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					// Hmm I think we will have to overwrite the initial name here
 | ||||
| 					// createuser with "user1.headscale.net", but oidc with "user1"
 | ||||
| 					{ | ||||
| 						Id:         "1", | ||||
| 						Name:       "user1", | ||||
| 						Email:      "user1@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user2", | ||||
| 						Email:      "user2@headscale.net", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "migration-no-strip-domains-not-verified-email", | ||||
| 			config: map[string]string{ | ||||
| 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||
| 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||
| 			}, | ||||
| 			emailVerified: false, | ||||
| 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||
| 			oidcUsers:     []string{"user1", "user2"}, | ||||
| 			want: func(iss string) []v1.User { | ||||
| 				return []v1.User{ | ||||
| 					{ | ||||
| 						Id:   "1", | ||||
| 						Name: "user1.headscale.net", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "2", | ||||
| 						Name:       "user1", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user1", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:   "3", | ||||
| 						Name: "user2.headscale.net", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Id:         "4", | ||||
| 						Name:       "user2", | ||||
| 						Provider:   "oidc", | ||||
| 						ProviderId: iss + "/user2", | ||||
| 					}, | ||||
| 				} | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			assertNoErr(t, err) | ||||
| 
 | ||||
| 			scenario := AuthOIDCScenario{ | ||||
| 				Scenario: baseScenario, | ||||
| 			} | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 			spec := map[string]int{} | ||||
| 			for _, user := range tt.cliUsers { | ||||
| 				spec[user] = 1 | ||||
| 			} | ||||
| 
 | ||||
| 			var mockusers []mockoidc.MockUser | ||||
| 			for _, user := range tt.oidcUsers { | ||||
| 				mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) | ||||
| 			} | ||||
| 
 | ||||
| 			oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 			assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 			defer scenario.mockOIDC.Close() | ||||
| 
 | ||||
| 			oidcMap := map[string]string{ | ||||
| 				"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 				"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 				"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 				"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 			} | ||||
| 
 | ||||
| 			for k, v := range tt.config { | ||||
| 				oidcMap[k] = v | ||||
| 			} | ||||
| 
 | ||||
| 			err = scenario.CreateHeadscaleEnv( | ||||
| 				spec, | ||||
| 				hsic.WithTestName("oidcmigration"), | ||||
| 				hsic.WithConfigEnv(oidcMap), | ||||
| 				hsic.WithTLS(), | ||||
| 				hsic.WithHostnameAsServerURL(), | ||||
| 				hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 			) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
| 
 | ||||
| 			// Ensure that the nodes have logged in, this is what
 | ||||
| 			// triggers user creation via OIDC.
 | ||||
| 			err = scenario.WaitForTailscaleSync() | ||||
| 			assertNoErrSync(t, err) | ||||
| 
 | ||||
| 			headscale, err := scenario.Headscale() | ||||
| 			assertNoErr(t, err) | ||||
| 
 | ||||
| 			want := tt.want(oidcConfig.Issuer) | ||||
| 
 | ||||
| 			var listUsers []v1.User | ||||
| 			err = executeAndUnmarshal(headscale, | ||||
| 				[]string{ | ||||
| 					"headscale", | ||||
| 					"users", | ||||
| 					"list", | ||||
| 					"--output", | ||||
| 					"json", | ||||
| 				}, | ||||
| 				&listUsers, | ||||
| 			) | ||||
| 			assertNoErr(t, err) | ||||
| 
 | ||||
| 			sort.Slice(listUsers, func(i, j int) bool { | ||||
| 				return listUsers[i].Id < listUsers[j].Id | ||||
| 			}) | ||||
| 
 | ||||
| 			if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||
| 				t.Errorf("unexpected users: %s", diff) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	opts ...hsic.Option, | ||||
| @ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	} | ||||
| 
 | ||||
| 	for userName, clientCount := range users { | ||||
| 		if clientCount != 1 { | ||||
| 			// OIDC scenario only supports one client per user.
 | ||||
| 			// This is because the MockOIDC server can only serve login
 | ||||
| 			// requests based on a queue it has been given on startup.
 | ||||
| 			// We currently only populates it with one login request per user.
 | ||||
| 			return fmt.Errorf("client count must be 1 for OIDC scenario.") | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| @ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { | ||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { | ||||
| 	port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("could not find an open port: %s", err) | ||||
| @ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf | ||||
| 
 | ||||
| 	hostname := fmt.Sprintf("hs-oidcmock-%s", hash) | ||||
| 
 | ||||
| 	usersJSON, err := json.Marshal(users) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	mockOidcOptions := &dockertest.RunOptions{ | ||||
| 		Name:         hostname, | ||||
| 		Cmd:          []string{"headscale", "mockoidc"}, | ||||
| @ -219,6 +596,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf | ||||
| 			"MOCKOIDC_CLIENT_ID=superclient", | ||||
| 			"MOCKOIDC_CLIENT_SECRET=supersecret", | ||||
| 			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), | ||||
| 			fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| @ -310,45 +688,40 @@ func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 
 | ||||
| 				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) | ||||
| 
 | ||||
| 				if err := s.pool.Retry(func() error { | ||||
| 					log.Printf("%s logging in with url", c.Hostname()) | ||||
| 					httpClient := &http.Client{Transport: insecureTransport} | ||||
| 					ctx := context.Background() | ||||
| 					req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 					resp, err := httpClient.Do(req) | ||||
| 					if err != nil { | ||||
| 						log.Printf( | ||||
| 							"%s failed to login using url %s: %s", | ||||
| 							c.Hostname(), | ||||
| 							loginURL, | ||||
| 							err, | ||||
| 						) | ||||
| 				log.Printf("%s logging in with url", c.Hostname()) | ||||
| 				httpClient := &http.Client{Transport: insecureTransport} | ||||
| 				ctx := context.Background() | ||||
| 				req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 				resp, err := httpClient.Do(req) | ||||
| 				if err != nil { | ||||
| 					log.Printf( | ||||
| 						"%s failed to login using url %s: %s", | ||||
| 						c.Hostname(), | ||||
| 						loginURL, | ||||
| 						err, | ||||
| 					) | ||||
| 
 | ||||
| 						return err | ||||
| 					} | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 					if resp.StatusCode != http.StatusOK { | ||||
| 						log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | ||||
| 				if resp.StatusCode != http.StatusOK { | ||||
| 					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | ||||
| 					body, _ := io.ReadAll(resp.Body) | ||||
| 					log.Printf("body: %s", body) | ||||
| 
 | ||||
| 						return errStatusCodeNotOK | ||||
| 					} | ||||
| 					return errStatusCodeNotOK | ||||
| 				} | ||||
| 
 | ||||
| 					defer resp.Body.Close() | ||||
| 				defer resp.Body.Close() | ||||
| 
 | ||||
| 					_, err = io.ReadAll(resp.Body) | ||||
| 					if err != nil { | ||||
| 						log.Printf("%s failed to read response body: %s", c.Hostname(), err) | ||||
| 				_, err = io.ReadAll(resp.Body) | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to read response body: %s", c.Hostname(), err) | ||||
| 
 | ||||
| 						return err | ||||
| 					} | ||||
| 
 | ||||
| 					return nil | ||||
| 				}); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				log.Printf("Finished request for %s to join tailnet", c.Hostname()) | ||||
| 
 | ||||
| 				return nil | ||||
| 			}) | ||||
| 
 | ||||
| @ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { | ||||
| 		assert.Equal(t, "NeedsLogin", status.BackendState) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { | ||||
| 	return mockoidc.MockUser{ | ||||
| 		Subject:           username, | ||||
| 		PreferredUsername: username, | ||||
| 		Email:             fmt.Sprintf("%s@headscale.net", username), | ||||
| 		EmailVerified:     emailVerified, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -74,7 +74,7 @@ func ExecuteCommand( | ||||
| 	select { | ||||
| 	case res := <-resultChan: | ||||
| 		if res.err != nil { | ||||
| 			return stdout.String(), stderr.String(), res.err | ||||
| 			return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) | ||||
| 		} | ||||
| 
 | ||||
| 		if res.exitCode != 0 { | ||||
| @ -83,12 +83,12 @@ func ExecuteCommand( | ||||
| 			// log.Println("stdout: ", stdout.String())
 | ||||
| 			// log.Println("stderr: ", stderr.String())
 | ||||
| 
 | ||||
| 			return stdout.String(), stderr.String(), ErrDockertestCommandFailed | ||||
| 			return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) | ||||
| 		} | ||||
| 
 | ||||
| 		return stdout.String(), stderr.String(), nil | ||||
| 	case <-time.After(execConfig.timeout): | ||||
| 
 | ||||
| 		return stdout.String(), stderr.String(), ErrDockertestCommandTimeout | ||||
| 		return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user