mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	fix constraints
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									5e7c3153b9
								
							
						
					
					
						commit
						281025bb16
					
				| @ -1,6 +1,7 @@ | |||||||
| package db | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| @ -257,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) { | |||||||
| func emptyCache() *zcache.Cache[string, types.Node] { | func emptyCache() *zcache.Cache[string, types.Node] { | ||||||
| 	return zcache.New[string, types.Node](time.Minute, time.Hour) | 	return zcache.New[string, types.Node](time.Minute, time.Hour) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestConstraints(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name string | ||||||
|  | 		run  func(*testing.T, *gorm.DB) | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "no-duplicate-username-if-no-oidc", | ||||||
|  | 			run: func(t *testing.T, db *gorm.DB) { | ||||||
|  | 				_, err := CreateUser(db, "user1") | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 				_, err = CreateUser(db, "user1") | ||||||
|  | 				require.Error(t, err) | ||||||
|  | 				// assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username")
 | ||||||
|  | 				require.Contains(t, err.Error(), "user already exists") | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "no-oidc-duplicate-username-and-id", | ||||||
|  | 			run: func(t *testing.T, db *gorm.DB) { | ||||||
|  | 				user := types.User{ | ||||||
|  | 					Model: gorm.Model{ID: 1}, | ||||||
|  | 					Name:  "user1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} | ||||||
|  | 
 | ||||||
|  | 				err := db.Save(&user).Error | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 				user = types.User{ | ||||||
|  | 					Model: gorm.Model{ID: 2}, | ||||||
|  | 					Name:  "user1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} | ||||||
|  | 
 | ||||||
|  | 				err = db.Save(&user).Error | ||||||
|  | 				require.Error(t, err) | ||||||
|  | 				require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "no-oidc-duplicate-id", | ||||||
|  | 			run: func(t *testing.T, db *gorm.DB) { | ||||||
|  | 				user := types.User{ | ||||||
|  | 					Model: gorm.Model{ID: 1}, | ||||||
|  | 					Name:  "user1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} | ||||||
|  | 
 | ||||||
|  | 				err := db.Save(&user).Error | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 				user = types.User{ | ||||||
|  | 					Model: gorm.Model{ID: 2}, | ||||||
|  | 					Name:  "user1.1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} | ||||||
|  | 
 | ||||||
|  | 				err = db.Save(&user).Error | ||||||
|  | 				require.Error(t, err) | ||||||
|  | 				require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "allow-duplicate-username-cli-then-oidc", | ||||||
|  | 			run: func(t *testing.T, db *gorm.DB) { | ||||||
|  | 				_, err := CreateUser(db, "user1") // Create CLI username
 | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 				user := types.User{ | ||||||
|  | 					Name: "user1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier.String = "http://test.com/user1" | ||||||
|  | 
 | ||||||
|  | 				err = db.Save(&user).Error | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "allow-duplicate-username-oidc-then-cli", | ||||||
|  | 			run: func(t *testing.T, db *gorm.DB) { | ||||||
|  | 				user := types.User{ | ||||||
|  | 					Name: "user1", | ||||||
|  | 				} | ||||||
|  | 				user.ProviderIdentifier.String = "http://test.com/user1" | ||||||
|  | 
 | ||||||
|  | 				err := db.Save(&user).Error | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 				_, err = CreateUser(db, "user1") // Create CLI username
 | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			db, err := newTestDB() | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("creating database: %s", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			tt.run(t, db.DB) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	user := types.User{} | 	user := types.User{ | ||||||
| 	if err := tx.Where("name = ?", name).First(&user).Error; err == nil { | 		Name: name, | ||||||
| 		return nil, ErrUserExists |  | ||||||
| 	} | 	} | ||||||
| 	user.Name = name |  | ||||||
| 	if err := tx.Create(&user).Error; err != nil { | 	if err := tx.Create(&user).Error; err != nil { | ||||||
| 		return nil, fmt.Errorf("creating user: %w", err) | 		return nil, fmt.Errorf("creating user: %w", err) | ||||||
| 	} | 	} | ||||||
| @ -177,6 +175,10 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if len(users) == 0 { | ||||||
|  | 		return nil, ErrUserNotFound | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if len(users) != 1 { | 	if len(users) != 1 { | ||||||
| 		return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) | 		return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -460,7 +460,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( | |||||||
| 			// This is to prevent users that have already been migrated to the new OIDC format
 | 			// This is to prevent users that have already been migrated to the new OIDC format
 | ||||||
| 			// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
 | 			// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
 | ||||||
| 			// account takeover.
 | 			// account takeover.
 | ||||||
| 			if user != nil && user.ProviderIdentifier != "" { | 			if user != nil && user.ProviderIdentifier.Valid { | ||||||
| 				log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") | 				log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") | ||||||
| 				user = &types.User{} | 				user = &types.User{} | ||||||
| 			} | 			} | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package types | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"cmp" | 	"cmp" | ||||||
|  | 	"database/sql" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| @ -26,7 +27,7 @@ type User struct { | |||||||
| 
 | 
 | ||||||
| 	// Username for the user, is used if email is empty
 | 	// Username for the user, is used if email is empty
 | ||||||
| 	// Should not be used, please use Username().
 | 	// Should not be used, please use Username().
 | ||||||
| 	Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` | 	Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"` | ||||||
| 
 | 
 | ||||||
| 	// Typically the full name of the user
 | 	// Typically the full name of the user
 | ||||||
| 	DisplayName string | 	DisplayName string | ||||||
| @ -38,7 +39,7 @@ type User struct { | |||||||
| 	// Unique identifier of the user from OIDC,
 | 	// Unique identifier of the user from OIDC,
 | ||||||
| 	// comes from `sub` claim in the OIDC token
 | 	// comes from `sub` claim in the OIDC token
 | ||||||
| 	// and is used to lookup the user.
 | 	// and is used to lookup the user.
 | ||||||
| 	ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` | 	ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"` | ||||||
| 
 | 
 | ||||||
| 	// Provider is the origin of the user account,
 | 	// Provider is the origin of the user account,
 | ||||||
| 	// same as RegistrationMethod, without authkey.
 | 	// same as RegistrationMethod, without authkey.
 | ||||||
| @ -55,7 +56,7 @@ type User struct { | |||||||
| // should be used throughout headscale, in information returned to the
 | // should be used throughout headscale, in information returned to the
 | ||||||
| // user and the Policy engine.
 | // user and the Policy engine.
 | ||||||
| func (u *User) Username() string { | func (u *User) Username() string { | ||||||
| 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) | 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) | ||||||
| 
 | 
 | ||||||
| 	// TODO(kradalby): Wire up all of this for the future
 | 	// TODO(kradalby): Wire up all of this for the future
 | ||||||
| 	// if !strings.Contains(username, "@") {
 | 	// if !strings.Contains(username, "@") {
 | ||||||
| @ -118,7 +119,7 @@ func (u *User) Proto() *v1.User { | |||||||
| 		CreatedAt:     timestamppb.New(u.CreatedAt), | 		CreatedAt:     timestamppb.New(u.CreatedAt), | ||||||
| 		DisplayName:   u.DisplayName, | 		DisplayName:   u.DisplayName, | ||||||
| 		Email:         u.Email, | 		Email:         u.Email, | ||||||
| 		ProviderId:    u.ProviderIdentifier, | 		ProviderId:    u.ProviderIdentifier.String, | ||||||
| 		Provider:      u.Provider, | 		Provider:      u.Provider, | ||||||
| 		ProfilePicUrl: u.ProfilePicURL, | 		ProfilePicUrl: u.ProfilePicURL, | ||||||
| 	} | 	} | ||||||
| @ -145,7 +146,7 @@ func (c *OIDCClaims) Identifier() string { | |||||||
| // FromClaim overrides a User from OIDC claims.
 | // FromClaim overrides a User from OIDC claims.
 | ||||||
| // All fields will be updated, except for the ID.
 | // All fields will be updated, except for the ID.
 | ||||||
| func (u *User) FromClaim(claims *OIDCClaims) { | func (u *User) FromClaim(claims *OIDCClaims) { | ||||||
| 	u.ProviderIdentifier = claims.Identifier() | 	u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} | ||||||
| 	u.DisplayName = claims.Name | 	u.DisplayName = claims.Name | ||||||
| 	if claims.EmailVerified { | 	if claims.EmailVerified { | ||||||
| 		u.Email = claims.Email | 		u.Email = claims.Email | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | |||||||
| 	scenario := AuthOIDCScenario{ | 	scenario := AuthOIDCScenario{ | ||||||
| 		Scenario: baseScenario, | 		Scenario: baseScenario, | ||||||
| 	} | 	} | ||||||
| 	defer scenario.ShutdownAssertNoPanics(t) | 	// defer scenario.ShutdownAssertNoPanics(t)
 | ||||||
| 
 | 
 | ||||||
| 	// Logins to MockOIDC is served by a queue with a strict order,
 | 	// Logins to MockOIDC is served by a queue with a strict order,
 | ||||||
| 	// if we use more than one node per user, the order of the logins
 | 	// if we use more than one node per user, the order of the logins
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user