1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-11-10 01:20:58 +01:00
This commit is contained in:
Justin Angel 2025-11-05 18:10:30 +00:00 committed by GitHub
commit c89c6f8c3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 33 deletions

View File

@ -361,6 +361,17 @@ unix_socket_permission: "0770"
# # required "openid" scope. # # required "openid" scope.
# scope: ["openid", "profile", "email"] # scope: ["openid", "profile", "email"]
# #
# # Enable this setting to accept the user's email address regardless
# # if "email_verified: true" is sent by identity provider.
# #
# # By default, "email_verified: true" must appear in claims or user info
# # before Headscale will accept the principal's email address as the user
# # account is created after successful authentication.
# #
# # This setting is useful when claims and their mapping can't be controlled,
# # such as when using Cloudflare One-time pin for authentication.
# use_unverified_email: false
#
# # Provide custom key/value pairs which get sent to the identity provider's # # Provide custom key/value pairs which get sent to the identity provider's
# # authorization endpoint. # # authorization endpoint.
# extra_params: # extra_params:

View File

@ -167,7 +167,11 @@ func (a *AuthProviderOIDC) RegisterHandler(
extras = append(extras, oauth2.S256ChallengeOption(verifier)) extras = append(extras, oauth2.S256ChallengeOption(verifier))
case types.PKCEMethodPlain: case types.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually // oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) extras = append(
extras,
oauth2.SetAuthURLParam("code_challenge_method", "plain"),
oauth2.SetAuthURLParam("code_challenge", verifier),
)
} }
} }
@ -281,12 +285,14 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
util.LogErr(err, "could not get userinfo; only using claims from id token") util.LogErr(err, "could not get userinfo; only using claims from id token")
} }
if bool(claims.EmailVerified) || a.cfg.UseUnverifiedEmail {
// The user claims are now updated from the userinfo endpoint so we can verify the user // The user claims are now updated from the userinfo endpoint so we can verify the user
// against allowed emails, email domains, and groups. // against allowed emails, email domains, and groups.
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err) httpError(writer, err)
return return
} }
}
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil { if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
httpError(writer, err) httpError(writer, err)
@ -332,8 +338,14 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") log.Debug().
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) Caller().
Str("registration_id", registrationId.String()).
Msg("registration session expired before authorization completed")
httpError(
writer,
NewHTTPError(http.StatusGone, "login session expired, try again", err),
)
return return
} }
@ -373,7 +385,11 @@ func extractCodeAndStateParamFromRequest(
state := req.URL.Query().Get("state") state := req.URL.Query().Get("state")
if code == "" || state == "" { if code == "" || state == "" {
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams) return "", "", NewHTTPError(
http.StatusBadRequest,
"missing code or state parameter",
errEmptyOIDCCallbackParams,
)
} }
return code, state, nil return code, state, nil
@ -390,7 +406,11 @@ func (a *AuthProviderOIDC) getOauth2Token(
if a.cfg.PKCE.Enabled { if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state) regInfo, ok := a.registrationCache.Get(state)
if !ok { if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) return nil, NewHTTPError(
http.StatusNotFound,
"registration not found",
errNoOIDCRegistrationInfo,
)
} }
if regInfo.Verifier != nil { if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
@ -399,7 +419,11 @@ func (a *AuthProviderOIDC) getOauth2Token(
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil { if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) return nil, NewHTTPError(
http.StatusForbidden,
"invalid code",
fmt.Errorf("could not exchange code for token: %w", err),
)
} }
return oauth2Token, err return oauth2Token, err
@ -418,7 +442,11 @@ func (a *AuthProviderOIDC) extractIDToken(
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID}) verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken) idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err)) return nil, NewHTTPError(
http.StatusForbidden,
"failed to verify id_token",
fmt.Errorf("failed to verify ID token: %w", err),
)
} }
return idToken, nil return idToken, nil
@ -433,7 +461,11 @@ func validateOIDCAllowedDomains(
if len(allowedDomains) > 0 { if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 || if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) { !slices.Contains(allowedDomains, claims.Email[at+1:]) {
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains) return NewHTTPError(
http.StatusUnauthorized,
"unauthorised domain",
errOIDCAllowedDomains,
)
} }
} }
@ -505,7 +537,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user = &types.User{} user = &types.User{}
} }
user.FromClaim(claims) user.FromClaim(claims, a.cfg.UseUnverifiedEmail)
if newUser { if newUser {
user, c, err = a.h.state.CreateUser(*user) user, c, err = a.h.state.CreateUser(*user)

View File

@ -31,9 +31,15 @@ const (
) )
var ( var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") errOidcMutuallyExclusive = errors.New(
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") "oidc_client_secret and oidc_client_secret_path are mutually exclusive",
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") )
errServerURLSuffix = errors.New(
"server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable",
)
errServerURLSame = errors.New(
"server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable",
)
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
) )
@ -179,6 +185,7 @@ type OIDCConfig struct {
ClientSecret string ClientSecret string
Scope []string Scope []string
ExtraParams map[string]string ExtraParams map[string]string
UseUnverifiedEmail bool
AllowedDomains []string AllowedDomains []string
AllowedUsers []string AllowedUsers []string
AllowedGroups []string AllowedGroups []string
@ -327,6 +334,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.pkce.enabled", false) viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256") viper.SetDefault("oidc.pkce.method", "S256")
viper.SetDefault("oidc.use_unverified_email", false)
viper.SetDefault("logtail.enabled", false) viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false) viper.SetDefault("randomize_client_port", false)
@ -385,10 +393,18 @@ func validateServerConfig() error {
} }
} }
if viper.IsSet("oidc.use_unverified_email") {
log.Warn().
Msg("unverified emails will be accepted during oidc authentication (oidc.use_unverified_email=true)")
} else {
log.Warn().Msg("only verified emails will be accepted during oidc authentication (oidc.use_unverified_email=false)")
}
depr.Log() depr.Log()
if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") { if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") {
log.Fatal().Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. Please remove one of them from your config file") log.Fatal().
Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. Please remove one of them from your config file")
} }
// Collect any validation errors and return them all at once // Collect any validation errors and return them all at once
@ -952,6 +968,7 @@ func LoadServerConfig() (*Config, error) {
ClientSecret: oidcClientSecret, ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"), Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"), ExtraParams: viper.GetStringMapString("oidc.extra_params"),
UseUnverifiedEmail: viper.GetBool("oidc.use_unverified_email"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),

View File

@ -324,7 +324,7 @@ type OIDCUserInfo struct {
// FromClaim overrides a User from OIDC claims. // FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID. // All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) { func (u *User) FromClaim(claims *OIDCClaims, useUnverifiedEmail bool) {
err := util.ValidateUsername(claims.Username) err := util.ValidateUsername(claims.Username)
if err == nil { if err == nil {
u.Name = claims.Username u.Name = claims.Username
@ -332,7 +332,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username) log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
} }
if claims.EmailVerified { if claims.EmailVerified || FlexibleBoolean(useUnverifiedEmail) {
_, err = mail.ParseAddress(claims.Email) _, err = mail.ParseAddress(claims.Email)
if err == nil { if err == nil {
u.Email = claims.Email u.Email = claims.Email

View File

@ -293,6 +293,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
jsonstr string jsonstr string
useUnverifiedEmail bool
want User want User
}{ }{
{ {
@ -348,6 +349,25 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
}, },
}, },
}, },
{
name: "use-unverified-email",
jsonstr: `
{
"sub": "test-unverified-email",
"email": "test-unverified-email@test.no",
"email_verified": "false"
}
`,
useUnverifiedEmail: true,
want: User{
Provider: util.RegisterMethodOIDC,
Email: "test-unverified-email@test.no",
ProviderIdentifier: sql.NullString{
String: "/test-unverified-email",
Valid: true,
},
},
},
{ {
// From https://github.com/juanfont/headscale/issues/2333 // From https://github.com/juanfont/headscale/issues/2333
name: "okta-oidc-claim-20250121", name: "okta-oidc-claim-20250121",
@ -458,7 +478,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
var user User var user User
user.FromClaim(&got) user.FromClaim(&got, tt.useUnverifiedEmail)
if diff := cmp.Diff(user, tt.want); diff != "" { if diff := cmp.Diff(user, tt.want); diff != "" {
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
} }