mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-10 01:20:58 +01:00
Merge 32da525b8d into 2024219bd1
This commit is contained in:
commit
c89c6f8c3b
@ -361,6 +361,17 @@ unix_socket_permission: "0770"
|
|||||||
# # required "openid" scope.
|
# # required "openid" scope.
|
||||||
# scope: ["openid", "profile", "email"]
|
# scope: ["openid", "profile", "email"]
|
||||||
#
|
#
|
||||||
|
# # Enable this setting to accept the user's email address regardless
|
||||||
|
# # if "email_verified: true" is sent by identity provider.
|
||||||
|
# #
|
||||||
|
# # By default, "email_verified: true" must appear in claims or user info
|
||||||
|
# # before Headscale will accept the principal's email address as the user
|
||||||
|
# # account is created after successful authentication.
|
||||||
|
# #
|
||||||
|
# # This setting is useful when claims and their mapping can't be controlled,
|
||||||
|
# # such as when using Cloudflare One-time pin for authentication.
|
||||||
|
# use_unverified_email: false
|
||||||
|
#
|
||||||
# # Provide custom key/value pairs which get sent to the identity provider's
|
# # Provide custom key/value pairs which get sent to the identity provider's
|
||||||
# # authorization endpoint.
|
# # authorization endpoint.
|
||||||
# extra_params:
|
# extra_params:
|
||||||
|
|||||||
@ -167,7 +167,11 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||||||
extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
||||||
case types.PKCEMethodPlain:
|
case types.PKCEMethodPlain:
|
||||||
// oauth2 does not have a plain challenge option, so we add it manually
|
// oauth2 does not have a plain challenge option, so we add it manually
|
||||||
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
|
extras = append(
|
||||||
|
extras,
|
||||||
|
oauth2.SetAuthURLParam("code_challenge_method", "plain"),
|
||||||
|
oauth2.SetAuthURLParam("code_challenge", verifier),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -281,12 +285,14 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bool(claims.EmailVerified) || a.cfg.UseUnverifiedEmail {
|
||||||
// The user claims are now updated from the userinfo endpoint so we can verify the user
|
// The user claims are now updated from the userinfo endpoint so we can verify the user
|
||||||
// against allowed emails, email domains, and groups.
|
// against allowed emails, email domains, and groups.
|
||||||
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
|
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
@ -332,8 +338,14 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
log.Debug().
|
||||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
Caller().
|
||||||
|
Str("registration_id", registrationId.String()).
|
||||||
|
Msg("registration session expired before authorization completed")
|
||||||
|
httpError(
|
||||||
|
writer,
|
||||||
|
NewHTTPError(http.StatusGone, "login session expired, try again", err),
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -373,7 +385,11 @@ func extractCodeAndStateParamFromRequest(
|
|||||||
state := req.URL.Query().Get("state")
|
state := req.URL.Query().Get("state")
|
||||||
|
|
||||||
if code == "" || state == "" {
|
if code == "" || state == "" {
|
||||||
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams)
|
return "", "", NewHTTPError(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"missing code or state parameter",
|
||||||
|
errEmptyOIDCCallbackParams,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return code, state, nil
|
return code, state, nil
|
||||||
@ -390,7 +406,11 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
if a.cfg.PKCE.Enabled {
|
if a.cfg.PKCE.Enabled {
|
||||||
regInfo, ok := a.registrationCache.Get(state)
|
regInfo, ok := a.registrationCache.Get(state)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
return nil, NewHTTPError(
|
||||||
|
http.StatusNotFound,
|
||||||
|
"registration not found",
|
||||||
|
errNoOIDCRegistrationInfo,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if regInfo.Verifier != nil {
|
if regInfo.Verifier != nil {
|
||||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||||
@ -399,7 +419,11 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
|
|
||||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
|
return nil, NewHTTPError(
|
||||||
|
http.StatusForbidden,
|
||||||
|
"invalid code",
|
||||||
|
fmt.Errorf("could not exchange code for token: %w", err),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return oauth2Token, err
|
return oauth2Token, err
|
||||||
@ -418,7 +442,11 @@ func (a *AuthProviderOIDC) extractIDToken(
|
|||||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
|
return nil, NewHTTPError(
|
||||||
|
http.StatusForbidden,
|
||||||
|
"failed to verify id_token",
|
||||||
|
fmt.Errorf("failed to verify ID token: %w", err),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return idToken, nil
|
return idToken, nil
|
||||||
@ -433,7 +461,11 @@ func validateOIDCAllowedDomains(
|
|||||||
if len(allowedDomains) > 0 {
|
if len(allowedDomains) > 0 {
|
||||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
||||||
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains)
|
return NewHTTPError(
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
"unauthorised domain",
|
||||||
|
errOIDCAllowedDomains,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -505,7 +537,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||||||
user = &types.User{}
|
user = &types.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
user.FromClaim(claims)
|
user.FromClaim(claims, a.cfg.UseUnverifiedEmail)
|
||||||
|
|
||||||
if newUser {
|
if newUser {
|
||||||
user, c, err = a.h.state.CreateUser(*user)
|
user, c, err = a.h.state.CreateUser(*user)
|
||||||
|
|||||||
@ -31,9 +31,15 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
errOidcMutuallyExclusive = errors.New(
|
||||||
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
|
||||||
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
|
)
|
||||||
|
errServerURLSuffix = errors.New(
|
||||||
|
"server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable",
|
||||||
|
)
|
||||||
|
errServerURLSame = errors.New(
|
||||||
|
"server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable",
|
||||||
|
)
|
||||||
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,6 +185,7 @@ type OIDCConfig struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
Scope []string
|
Scope []string
|
||||||
ExtraParams map[string]string
|
ExtraParams map[string]string
|
||||||
|
UseUnverifiedEmail bool
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
@ -327,6 +334,7 @@ func LoadConfig(path string, isFile bool) error {
|
|||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
viper.SetDefault("oidc.pkce.enabled", false)
|
viper.SetDefault("oidc.pkce.enabled", false)
|
||||||
viper.SetDefault("oidc.pkce.method", "S256")
|
viper.SetDefault("oidc.pkce.method", "S256")
|
||||||
|
viper.SetDefault("oidc.use_unverified_email", false)
|
||||||
|
|
||||||
viper.SetDefault("logtail.enabled", false)
|
viper.SetDefault("logtail.enabled", false)
|
||||||
viper.SetDefault("randomize_client_port", false)
|
viper.SetDefault("randomize_client_port", false)
|
||||||
@ -385,10 +393,18 @@ func validateServerConfig() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if viper.IsSet("oidc.use_unverified_email") {
|
||||||
|
log.Warn().
|
||||||
|
Msg("unverified emails will be accepted during oidc authentication (oidc.use_unverified_email=true)")
|
||||||
|
} else {
|
||||||
|
log.Warn().Msg("only verified emails will be accepted during oidc authentication (oidc.use_unverified_email=false)")
|
||||||
|
}
|
||||||
|
|
||||||
depr.Log()
|
depr.Log()
|
||||||
|
|
||||||
if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") {
|
if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") {
|
||||||
log.Fatal().Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. Please remove one of them from your config file")
|
log.Fatal().
|
||||||
|
Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. Please remove one of them from your config file")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect any validation errors and return them all at once
|
// Collect any validation errors and return them all at once
|
||||||
@ -952,6 +968,7 @@ func LoadServerConfig() (*Config, error) {
|
|||||||
ClientSecret: oidcClientSecret,
|
ClientSecret: oidcClientSecret,
|
||||||
Scope: viper.GetStringSlice("oidc.scope"),
|
Scope: viper.GetStringSlice("oidc.scope"),
|
||||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||||
|
UseUnverifiedEmail: viper.GetBool("oidc.use_unverified_email"),
|
||||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||||
|
|||||||
@ -324,7 +324,7 @@ type OIDCUserInfo struct {
|
|||||||
|
|
||||||
// FromClaim overrides a User from OIDC claims.
|
// FromClaim overrides a User from OIDC claims.
|
||||||
// All fields will be updated, except for the ID.
|
// All fields will be updated, except for the ID.
|
||||||
func (u *User) FromClaim(claims *OIDCClaims) {
|
func (u *User) FromClaim(claims *OIDCClaims, useUnverifiedEmail bool) {
|
||||||
err := util.ValidateUsername(claims.Username)
|
err := util.ValidateUsername(claims.Username)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
u.Name = claims.Username
|
u.Name = claims.Username
|
||||||
@ -332,7 +332,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
|
|||||||
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
|
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.EmailVerified {
|
if claims.EmailVerified || FlexibleBoolean(useUnverifiedEmail) {
|
||||||
_, err = mail.ParseAddress(claims.Email)
|
_, err = mail.ParseAddress(claims.Email)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
u.Email = claims.Email
|
u.Email = claims.Email
|
||||||
|
|||||||
@ -293,6 +293,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
jsonstr string
|
jsonstr string
|
||||||
|
useUnverifiedEmail bool
|
||||||
want User
|
want User
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@ -348,6 +349,25 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "use-unverified-email",
|
||||||
|
jsonstr: `
|
||||||
|
{
|
||||||
|
"sub": "test-unverified-email",
|
||||||
|
"email": "test-unverified-email@test.no",
|
||||||
|
"email_verified": "false"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
useUnverifiedEmail: true,
|
||||||
|
want: User{
|
||||||
|
Provider: util.RegisterMethodOIDC,
|
||||||
|
Email: "test-unverified-email@test.no",
|
||||||
|
ProviderIdentifier: sql.NullString{
|
||||||
|
String: "/test-unverified-email",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
// From https://github.com/juanfont/headscale/issues/2333
|
// From https://github.com/juanfont/headscale/issues/2333
|
||||||
name: "okta-oidc-claim-20250121",
|
name: "okta-oidc-claim-20250121",
|
||||||
@ -458,7 +478,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||||||
|
|
||||||
var user User
|
var user User
|
||||||
|
|
||||||
user.FromClaim(&got)
|
user.FromClaim(&got, tt.useUnverifiedEmail)
|
||||||
if diff := cmp.Diff(user, tt.want); diff != "" {
|
if diff := cmp.Diff(user, tt.want); diff != "" {
|
||||||
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
|
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user