1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-30 00:09:42 +01:00

feat: add strip_email_domain to normalization of namespace

This commit is contained in:
Adrien Raffin-Caboisse 2022-02-23 14:03:07 +01:00
parent 7e4709c13f
commit 4f1f235a2e
6 changed files with 61 additions and 21 deletions

7
app.go
View File

@ -107,9 +107,10 @@ type Config struct {
} }
type OIDCConfig struct { type OIDCConfig struct {
Issuer string Issuer string
ClientID string ClientID string
ClientSecret string ClientSecret string
StripEmaildomain bool
} }
type DERPConfig struct { type DERPConfig struct {

View File

@ -63,6 +63,8 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.timeout", "5s")
viper.SetDefault("cli.insecure", false) viper.SetDefault("cli.insecure", false)
viper.SetDefault("oidc.strip_email_domain", true)
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err) return fmt.Errorf("fatal error reading config file: %w", err)
} }
@ -323,9 +325,10 @@ func getHeadscaleConfig() headscale.Config {
UnixSocketPermission: GetFileMode("unix_socket_permission"), UnixSocketPermission: GetFileMode("unix_socket_permission"),
OIDC: headscale.OIDCConfig{ OIDC: headscale.OIDCConfig{
Issuer: viper.GetString("oidc.issuer"), Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"), ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"), ClientSecret: viper.GetString("oidc.client_secret"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
}, },
CLI: headscale.CLIConfig{ CLI: headscale.CLIConfig{

View File

@ -180,3 +180,9 @@ unix_socket_permission: "0770"
# client_id: "your-oidc-client-id" # client_id: "your-oidc-client-id"
# client_secret: "your-oidc-client-secret" # client_secret: "your-oidc-client-secret"
# #
# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# This will transform `first-name.last-name@example.com` to the namespace `first-name.last-name`
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
# namespace: `first-name.last-name.example.com`
#
# strip_email_domain: true

View File

@ -268,10 +268,15 @@ func (n *Namespace) toProto() *v1.Namespace {
// NormalizeNamespaceName will replace forbidden chars in namespace // NormalizeNamespaceName will replace forbidden chars in namespace
// it can also return an error if the namespace doesn't respect RFC 952 and 1123. // it can also return an error if the namespace doesn't respect RFC 952 and 1123.
func NormalizeNamespaceName(name string) (string, error) { func NormalizeNamespaceName(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name) name = strings.ToLower(name)
name = strings.ReplaceAll(name, "@", ".")
name = strings.ReplaceAll(name, "'", "") name = strings.ReplaceAll(name, "'", "")
if stripEmailDomain {
idx := strings.Index(name, "@")
name = name[:idx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInNamespaceRegex.ReplaceAllString(name, "-") name = invalidCharsInNamespaceRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") { for _, elt := range strings.Split(name, ".") {

View File

@ -244,7 +244,8 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
func TestNormalizeNamespaceName(t *testing.T) { func TestNormalizeNamespaceName(t *testing.T) {
type args struct { type args struct {
name string name string
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -253,39 +254,63 @@ func TestNormalizeNamespaceName(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "normalize simple name", name: "normalize simple name",
args: args{name: "normalize-simple.name"}, args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name", want: "normalize-simple.name",
wantErr: false, wantErr: false,
}, },
{ {
name: "normalize an email", name: "normalize an email",
args: args{name: "foo.bar@example.com"}, args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com", want: "foo.bar.example.com",
wantErr: false, wantErr: false,
}, },
{ {
name: "normalize complex email", name: "normalize an email domain should be removed",
args: args{name: "foo.bar+complex-email@example.com"}, args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com", want: "foo.bar-complex-email.example.com",
wantErr: false, wantErr: false,
}, },
{ {
name: "namespace name with space", name: "namespace name with space",
args: args{name: "name space"}, args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space", want: "name-space",
wantErr: false, wantErr: false,
}, },
{ {
name: "namespace with quote", name: "namespace with quote",
args: args{name: "Jamie's iPhone 5"}, args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5", want: "jamies-iphone-5",
wantErr: false, wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeNamespaceName(tt.args.name) got, err := NormalizeNamespaceName(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf( t.Errorf(
"NormalizeNamespaceName() error = %v, wantErr %v", "NormalizeNamespaceName() error = %v, wantErr %v",

View File

@ -281,7 +281,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
now := time.Now().UTC() now := time.Now().UTC()
namespaceName, err := NormalizeNamespaceName(claims.Email) namespaceName, err := NormalizeNamespaceName(claims.Email, h.cfg.OIDC.StripEmaildomain)
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email") log.Error().Err(err).Caller().Msgf("couldn't normalize email")
ctx.String( ctx.String(