mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	users: harden, test, and add cleaner of identifier (#2593)
* users: harden, test, and add cleaner of identifier Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * db: migrate badly joined provider identifiers Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									d7a503a34e
								
							
						
					
					
						commit
						2dc2f3b3f0
					
				@ -695,6 +695,29 @@ AND auth_key_id NOT IN (
 | 
			
		||||
				},
 | 
			
		||||
				Rollback: func(db *gorm.DB) error { return nil },
 | 
			
		||||
			},
 | 
			
		||||
			// Fix the provider identifier for users that have a double slash in the
 | 
			
		||||
			// provider identifier.
 | 
			
		||||
			{
 | 
			
		||||
				ID: "202505141324",
 | 
			
		||||
				Migrate: func(tx *gorm.DB) error {
 | 
			
		||||
					users, err := ListUsers(tx)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return fmt.Errorf("listing users: %w", err)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					for _, user := range users {
 | 
			
		||||
						user.ProviderIdentifier.String = types.CleanIdentifier(user.ProviderIdentifier.String)
 | 
			
		||||
 | 
			
		||||
						err := tx.Save(user).Error
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							return fmt.Errorf("saving user: %w", err)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					return nil
 | 
			
		||||
				},
 | 
			
		||||
				Rollback: func(db *gorm.DB) error { return nil },
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -194,13 +194,110 @@ type OIDCClaims struct {
 | 
			
		||||
	Username          string          `json:"preferred_username,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Identifier returns a unique identifier string combining the Iss and Sub claims.
 | 
			
		||||
// The format depends on whether Iss is a URL or not:
 | 
			
		||||
// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub")
 | 
			
		||||
// - For non-URLs: Joins with a slash (e.g., "oidc/sub")
 | 
			
		||||
// - For empty Iss: Returns just "sub"
 | 
			
		||||
// - For empty Sub: Returns just the Issuer
 | 
			
		||||
// - For both empty: Returns empty string
 | 
			
		||||
//
 | 
			
		||||
// The result is cleaned using CleanIdentifier() to ensure consistent formatting.
 | 
			
		||||
func (c *OIDCClaims) Identifier() string {
 | 
			
		||||
	if strings.HasPrefix(c.Iss, "http") {
 | 
			
		||||
		if i, err := url.JoinPath(c.Iss, c.Sub); err == nil {
 | 
			
		||||
			return i
 | 
			
		||||
	// Handle empty components special cases
 | 
			
		||||
	if c.Iss == "" && c.Sub == "" {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	if c.Iss == "" {
 | 
			
		||||
		return CleanIdentifier(c.Sub)
 | 
			
		||||
	}
 | 
			
		||||
	if c.Sub == "" {
 | 
			
		||||
		return CleanIdentifier(c.Iss)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We'll use the raw values and let CleanIdentifier handle all the whitespace
 | 
			
		||||
	issuer := c.Iss
 | 
			
		||||
	subject := c.Sub
 | 
			
		||||
 | 
			
		||||
	var result string
 | 
			
		||||
	// Try to parse as URL to handle URL joining correctly
 | 
			
		||||
	if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
 | 
			
		||||
		// For URLs, use proper URL path joining
 | 
			
		||||
		if joined, err := url.JoinPath(issuer, subject); err == nil {
 | 
			
		||||
			result = joined
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return c.Iss + "/" + c.Sub
 | 
			
		||||
 | 
			
		||||
	// If URL joining failed or issuer wasn't a URL, do simple string join
 | 
			
		||||
	if result == "" {
 | 
			
		||||
		// Default case: simple string joining with slash
 | 
			
		||||
		issuer = strings.TrimSuffix(issuer, "/")
 | 
			
		||||
		subject = strings.TrimPrefix(subject, "/")
 | 
			
		||||
		result = issuer + "/" + subject
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clean the result and return it
 | 
			
		||||
	return CleanIdentifier(result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanIdentifier cleans a potentially malformed identifier by removing double slashes
 | 
			
		||||
// while preserving protocol specifications like http://. This function will:
 | 
			
		||||
// - Trim all whitespace from the beginning and end of the identifier
 | 
			
		||||
// - Remove whitespace within path segments
 | 
			
		||||
// - Preserve the scheme (http://, https://, etc.) for URLs
 | 
			
		||||
// - Remove any duplicate slashes in the path
 | 
			
		||||
// - Remove empty path segments
 | 
			
		||||
// - For non-URL identifiers, it joins non-empty segments with a single slash
 | 
			
		||||
// - Returns empty string for identifiers with only slashes
 | 
			
		||||
// - Normalize URL schemes to lowercase
 | 
			
		||||
func CleanIdentifier(identifier string) string {
 | 
			
		||||
	if identifier == "" {
 | 
			
		||||
		return identifier
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Trim leading/trailing whitespace
 | 
			
		||||
	identifier = strings.TrimSpace(identifier)
 | 
			
		||||
 | 
			
		||||
	// Handle URLs with schemes
 | 
			
		||||
	u, err := url.Parse(identifier)
 | 
			
		||||
	if err == nil && u.Scheme != "" {
 | 
			
		||||
		// Clean path by removing empty segments and whitespace within segments
 | 
			
		||||
		parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' })
 | 
			
		||||
		for i, part := range parts {
 | 
			
		||||
			parts[i] = strings.TrimSpace(part)
 | 
			
		||||
		}
 | 
			
		||||
		// Remove empty parts after trimming
 | 
			
		||||
		cleanParts := make([]string, 0, len(parts))
 | 
			
		||||
		for _, part := range parts {
 | 
			
		||||
			if part != "" {
 | 
			
		||||
				cleanParts = append(cleanParts, part)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		
 | 
			
		||||
		if len(cleanParts) == 0 {
 | 
			
		||||
			u.Path = ""
 | 
			
		||||
		} else {
 | 
			
		||||
			u.Path = "/" + strings.Join(cleanParts, "/")
 | 
			
		||||
		}
 | 
			
		||||
		// Ensure scheme is lowercase
 | 
			
		||||
		u.Scheme = strings.ToLower(u.Scheme)
 | 
			
		||||
		return u.String()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Handle non-URL identifiers
 | 
			
		||||
	parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' })
 | 
			
		||||
	// Clean whitespace from each part
 | 
			
		||||
	cleanParts := make([]string, 0, len(parts))
 | 
			
		||||
	for _, part := range parts {
 | 
			
		||||
		trimmed := strings.TrimSpace(part)
 | 
			
		||||
		if trimmed != "" {
 | 
			
		||||
			cleanParts = append(cleanParts, trimmed)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if len(cleanParts) == 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(cleanParts, "/")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OIDCUserInfo struct {
 | 
			
		||||
@ -231,7 +328,13 @@ func (u *User) FromClaim(claims *OIDCClaims) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
 | 
			
		||||
	// Get provider identifier
 | 
			
		||||
	identifier := claims.Identifier()
 | 
			
		||||
	// Ensure provider identifier always has a leading slash for backward compatibility
 | 
			
		||||
	if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
 | 
			
		||||
		identifier = "/" + identifier
 | 
			
		||||
	}
 | 
			
		||||
	u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
 | 
			
		||||
	u.DisplayName = claims.Name
 | 
			
		||||
	u.ProfilePicURL = claims.ProfilePictureURL
 | 
			
		||||
	u.Provider = util.RegisterMethodOIDC
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,7 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/google/go-cmp/cmp"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestUnmarshallOIDCClaims(t *testing.T) {
 | 
			
		||||
@ -76,6 +77,218 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOIDCClaimsIdentifier(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		iss      string
 | 
			
		||||
		sub      string
 | 
			
		||||
		expected string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:     "standard URL with trailing slash",
 | 
			
		||||
			iss:      "https://oidc.example.com/",
 | 
			
		||||
			sub:      "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
			expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "standard URL without trailing slash",
 | 
			
		||||
			iss:      "https://oidc.example.com",
 | 
			
		||||
			sub:      "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
			expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "standard URL with uppercase protocol",
 | 
			
		||||
			iss:      "HTTPS://oidc.example.com/",
 | 
			
		||||
			sub:      "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
			expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "standard URL with path and trailing slash",
 | 
			
		||||
			iss:      "https://login.microsoftonline.com/v2.0/",
 | 
			
		||||
			sub:      "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
			expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "standard URL with path without trailing slash",
 | 
			
		||||
			iss:      "https://login.microsoftonline.com/v2.0",
 | 
			
		||||
			sub:      "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
			expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "non-URL identifier with slash",
 | 
			
		||||
			iss:      "oidc",
 | 
			
		||||
			sub:      "sub",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "non-URL identifier with trailing slash",
 | 
			
		||||
			iss:      "oidc/",
 | 
			
		||||
			sub:      "sub",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "subject with slash",
 | 
			
		||||
			iss:      "oidc/",
 | 
			
		||||
			sub:      "sub/",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "whitespace",
 | 
			
		||||
			iss:      "   oidc/   ",
 | 
			
		||||
			sub:      "   sub   ",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "newline",
 | 
			
		||||
			iss:      "\noidc/\n",
 | 
			
		||||
			sub:      "\nsub\n",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "tab",
 | 
			
		||||
			iss:      "\toidc/\t",
 | 
			
		||||
			sub:      "\tsub\t",
 | 
			
		||||
			expected: "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "empty issuer",
 | 
			
		||||
			iss:      "",
 | 
			
		||||
			sub:      "sub",
 | 
			
		||||
			expected: "sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "empty subject",
 | 
			
		||||
			iss:      "https://oidc.example.com",
 | 
			
		||||
			sub:      "",
 | 
			
		||||
			expected: "https://oidc.example.com",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "both empty",
 | 
			
		||||
			iss:      "",
 | 
			
		||||
			sub:      "",
 | 
			
		||||
			expected: "",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "URL with double slash",
 | 
			
		||||
			iss:      "https://login.microsoftonline.com//v2.0",
 | 
			
		||||
			sub:      "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
			expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "FTP URL protocol",
 | 
			
		||||
			iss:      "ftp://example.com/directory",
 | 
			
		||||
			sub:      "resource",
 | 
			
		||||
			expected: "ftp://example.com/directory/resource",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			claims := OIDCClaims{
 | 
			
		||||
				Iss: tt.iss,
 | 
			
		||||
				Sub: tt.sub,
 | 
			
		||||
			}
 | 
			
		||||
			result := claims.Identifier()
 | 
			
		||||
			assert.Equal(t, tt.expected, result)
 | 
			
		||||
			if diff := cmp.Diff(tt.expected, result); diff != "" {
 | 
			
		||||
				t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Now clean the identifier and verify it's still the same
 | 
			
		||||
			cleaned := CleanIdentifier(result)
 | 
			
		||||
 | 
			
		||||
			// Double-check with cmp.Diff for better error messages
 | 
			
		||||
			if diff := cmp.Diff(tt.expected, cleaned); diff != "" {
 | 
			
		||||
				t.Errorf("CleanIdentifier(Identifier()) mismatch (-want +got):\n%s", diff)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCleanIdentifier(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name       string
 | 
			
		||||
		identifier string
 | 
			
		||||
		expected   string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:       "empty identifier",
 | 
			
		||||
			identifier: "",
 | 
			
		||||
			expected:   "",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "simple identifier",
 | 
			
		||||
			identifier: "oidc/sub",
 | 
			
		||||
			expected:   "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "double slashes in the middle",
 | 
			
		||||
			identifier: "oidc//sub",
 | 
			
		||||
			expected:   "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "trailing slash",
 | 
			
		||||
			identifier: "oidc/sub/",
 | 
			
		||||
			expected:   "oidc/sub",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "multiple double slashes",
 | 
			
		||||
			identifier: "oidc//sub///id//",
 | 
			
		||||
			expected:   "oidc/sub/id",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "HTTP URL with proper scheme",
 | 
			
		||||
			identifier: "http://example.com/path",
 | 
			
		||||
			expected:   "http://example.com/path",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "HTTP URL with double slashes in path",
 | 
			
		||||
			identifier: "http://example.com//path///resource",
 | 
			
		||||
			expected:   "http://example.com/path/resource",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "HTTPS URL with empty segments",
 | 
			
		||||
			identifier: "https://example.com///path//",
 | 
			
		||||
			expected:   "https://example.com/path",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "URL with double slashes in domain",
 | 
			
		||||
			identifier: "https://login.microsoftonline.com//v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
			expected:   "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "FTP URL with double slashes",
 | 
			
		||||
			identifier: "ftp://example.com//resource//",
 | 
			
		||||
			expected:   "ftp://example.com/resource",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "Just slashes",
 | 
			
		||||
			identifier: "///",
 | 
			
		||||
			expected:   "",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "Leading slash without URL",
 | 
			
		||||
			identifier: "/path//to///resource",
 | 
			
		||||
			expected:   "path/to/resource",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:       "Non-standard protocol",
 | 
			
		||||
			identifier: "ldap://example.org//path//to//resource",
 | 
			
		||||
			expected:   "ldap://example.org/path/to/resource",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			result := CleanIdentifier(tt.identifier)
 | 
			
		||||
			assert.Equal(t, tt.expected, result)
 | 
			
		||||
			if diff := cmp.Diff(tt.expected, result); diff != "" {
 | 
			
		||||
				t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOIDCClaimsJSONToUser(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user