mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-30 00:09:42 +01:00
Merge pull request #126 from unreality/main
Initial work on OIDC (SSO) integration
This commit is contained in:
commit
fbdfa55629
@ -31,6 +31,7 @@ headscale implements this coordination server.
|
||||
- [x] Taildrop (File Sharing)
|
||||
- [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
|
||||
- [x] DNS (passing DNS servers to nodes)
|
||||
- [x] Single-Sign-On (via Open ID Connect)
|
||||
- [x] Share nodes between namespaces
|
||||
- [x] MagicDNS (see `docs/`)
|
||||
|
||||
@ -49,7 +50,6 @@ headscale implements this coordination server.
|
||||
|
||||
Suggestions/PRs welcomed!
|
||||
|
||||
|
||||
## Running headscale
|
||||
|
||||
Please have a look at the documentation under [`docs/`](docs/).
|
||||
|
128
api.go
128
api.go
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -64,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot parse machine key")
|
||||
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc()
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Sad!")
|
||||
return
|
||||
}
|
||||
@ -75,37 +76,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc()
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Very sad!")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
var m Machine
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
m = Machine{
|
||||
Expiry: &req.Expiry,
|
||||
MachineKey: mKey.HexString(),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
||||
LastSuccessfulUpdate: &now,
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: mKey.HexString(),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
}
|
||||
if err := h.db.Create(&m).Error; err != nil {
|
||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc()
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
|
||||
return
|
||||
}
|
||||
m = &newMachine
|
||||
}
|
||||
|
||||
if !m.Registered && req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(c, h.db, mKey, req, m)
|
||||
h.handleAuthKey(c, h.db, mKey, req, *m)
|
||||
return
|
||||
}
|
||||
|
||||
@ -113,7 +110,36 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
|
||||
// We have the updated key!
|
||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
if m.Registered {
|
||||
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||
log.Info().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("Client requested logout")
|
||||
|
||||
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||
h.db.Save(&m)
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = false
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
return
|
||||
}
|
||||
|
||||
if m.Registered && m.Expiry.UTC().After(now) {
|
||||
// The machine registration is valid, respond with redirect to /map
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
@ -122,6 +148,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
resp.Login = *m.Namespace.toLogin()
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -137,12 +165,30 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// The client has registered before, but has expired
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("Not registered and not NodeKey rotation. Sending a authurl to register")
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
h.cfg.ServerURL, mKey.HexString())
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
}
|
||||
|
||||
// When a client connects, it may request a specific expiry time in its
|
||||
// RegisterRequest (https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L634)
|
||||
// RequestedExpiry is used to store the clients requested expiry time since the authentication flow is broken
|
||||
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
||||
// retrieved again after the user has authenticated. After the authentication flow
|
||||
// completes, RequestedExpiry is copied into Expiry.
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
|
||||
h.db.Save(&m)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -158,8 +204,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration
|
||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
@ -182,35 +228,23 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// We arrive here after a client is restarted without finalizing the authentication flow or
|
||||
// when headscale is stopped in the middle of the auth process.
|
||||
if m.Registered {
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
return
|
||||
}
|
||||
|
||||
// The machine registration is new, redirect the client to the registration URL
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
h.cfg.ServerURL, mKey.HexString())
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
}
|
||||
|
||||
// save the requested expiry time for retrieval later in the authentication flow
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||
h.db.Save(&m)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
29
app.go
29
app.go
@ -14,6 +14,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
apiV1 "github.com/juanfont/headscale/gen/go/v1"
|
||||
@ -62,6 +66,18 @@ type Config struct {
|
||||
ACMEEmail string
|
||||
|
||||
DNSConfig *tailcfg.DNSConfig
|
||||
|
||||
OIDC OIDCConfig
|
||||
|
||||
MaxMachineRegistrationDuration time.Duration
|
||||
DefaultMachineRegistrationDuration time.Duration
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
MatchMap map[string]string
|
||||
}
|
||||
|
||||
type DERPConfig struct {
|
||||
@ -87,6 +103,10 @@ type Headscale struct {
|
||||
aclRules *[]tailcfg.FilterRule
|
||||
|
||||
lastStateChange sync.Map
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
oidcStateCache *cache.Cache
|
||||
}
|
||||
|
||||
// NewHeadscale returns the Headscale app.
|
||||
@ -127,6 +147,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = h.initOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
|
||||
if err != nil {
|
||||
@ -255,6 +282,8 @@ func (h *Headscale) Serve() error {
|
||||
r.GET("/register", h.RegisterWebAPI)
|
||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
r.POST("/machine/:id", h.RegistrationHandler)
|
||||
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
r.GET("/oidc/callback", h.OIDCCallback)
|
||||
r.GET("/apple", h.AppleMobileConfig)
|
||||
r.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
|
||||
|
3
cli.go
3
cli.go
@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||
return nil, errors.New("Machine not found")
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered
|
||||
|
||||
if m.isAlreadyRegistered() {
|
||||
return nil, errors.New("Machine already registered")
|
||||
}
|
||||
@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "cli"
|
||||
h.db.Save(&m)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
20
cli_test.go
20
cli_test.go
@ -1,6 +1,8 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
@ -8,14 +10,18 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
m := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
IPAddress: "10.0.0.1",
|
||||
ID: 0,
|
||||
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
IPAddress: "10.0.0.1",
|
||||
Expiry: &now,
|
||||
RequestedExpiry: &now,
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -215,6 +216,26 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// maxMachineRegistrationDuration is the maximum time headscale will allow a client to (optionally) request for
|
||||
// the machine key expiry time. RegisterRequests with Expiry times that are more than
|
||||
// maxMachineRegistrationDuration in the future will be clamped to (now + maxMachineRegistrationDuration)
|
||||
maxMachineRegistrationDuration, _ := time.ParseDuration(
|
||||
"10h",
|
||||
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
|
||||
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
|
||||
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
|
||||
}
|
||||
|
||||
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
|
||||
// specified by the tailscale client. It is the default amount of time a machine registration is valid for
|
||||
// (ie the amount of time before the user has to re-authenticate when requesting a connection)
|
||||
defaultMachineRegistrationDuration, _ := time.ParseDuration(
|
||||
"8h",
|
||||
) // use 8h here because it's the length of a standard business day
|
||||
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
|
||||
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
|
||||
}
|
||||
|
||||
dnsConfig, baseDomain := GetDNSConfig()
|
||||
derpConfig := GetDERPConfig()
|
||||
|
||||
@ -249,8 +270,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
|
||||
ACMEEmail: viper.GetString("acme_email"),
|
||||
ACMEURL: viper.GetString("acme_url"),
|
||||
|
||||
OIDC: headscale.OIDCConfig{
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: viper.GetString("oidc.client_secret"),
|
||||
},
|
||||
|
||||
MaxMachineRegistrationDuration: maxMachineRegistrationDuration,
|
||||
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration,
|
||||
}
|
||||
|
||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||
|
||||
h, err := headscale.NewHeadscale(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -312,3 +344,15 @@ func HasJsonOutputFlag() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
|
||||
// the match map is valid regex strings.
|
||||
func loadOIDCMatchMap() map[string]string {
|
||||
strMap := viper.GetStringMapString("oidc.domain_map")
|
||||
|
||||
for oidcMatcher := range strMap {
|
||||
_ = regexp.MustCompile(oidcMatcher)
|
||||
}
|
||||
|
||||
return strMap
|
||||
}
|
||||
|
@ -64,3 +64,18 @@ dns_config:
|
||||
|
||||
magic_dns: true
|
||||
base_domain: example.com
|
||||
|
||||
|
||||
# headscale supports experimental OpenID connect support,
|
||||
# it is still being tested and might have some bugs, please
|
||||
# help us test it.
|
||||
# OpenID Connect
|
||||
# oidc:
|
||||
# issuer: "https://your-oidc.issuer.com/path"
|
||||
# client_id: "your-oidc-client-id"
|
||||
# client_secret: "your-oidc-client-secret"
|
||||
#
|
||||
# # Domain map is used to map incomming users (by their email) to
|
||||
# # a namespace. The key can be a string, or regex.
|
||||
# domain_map:
|
||||
# ".*": default-namespace
|
||||
|
3
go.mod
3
go.mod
@ -7,6 +7,7 @@ require (
|
||||
github.com/Microsoft/go-winio v0.5.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.1.1 // indirect
|
||||
github.com/containerd/continuity v0.1.0 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.1.0
|
||||
github.com/docker/cli v20.10.8+incompatible // indirect
|
||||
github.com/docker/docker v20.10.8+incompatible // indirect
|
||||
github.com/efekarakus/termcolor v1.0.1
|
||||
@ -23,6 +24,7 @@ require (
|
||||
github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect
|
||||
github.com/opencontainers/runc v1.0.2 // indirect
|
||||
github.com/ory/dockertest/v3 v3.7.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/pterm/pterm v0.12.30
|
||||
github.com/rs/zerolog v1.25.0
|
||||
@ -36,6 +38,7 @@ require (
|
||||
github.com/zsais/go-gin-prometheus v0.1.0
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
|
||||
golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83
|
||||
|
9
go.sum
9
go.sum
@ -153,6 +153,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw=
|
||||
github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
@ -767,6 +769,8 @@ github.com/ory/dockertest/v3 v3.7.0 h1:Bijzonc69Ont3OU0a3TWKJ1Rzlh3TsDXP1JrTAkSm
|
||||
github.com/ory/dockertest/v3 v3.7.0/go.mod h1:PvCCgnP7AfBZeVrzwiUTjZx/IUXlGLC1zQlUQrLIlUE=
|
||||
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
|
||||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pborman/getopt v1.1.0/go.mod h1:FxXoW1Re00sQG/+KIkuSqRL/LwQgSkv7uyac+STFsbk=
|
||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
@ -1137,6 +1141,7 @@ golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
@ -1174,6 +1179,7 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
|
||||
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw=
|
||||
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -1436,6 +1442,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
|
||||
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
@ -1553,6 +1560,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU=
|
||||
gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
||||
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
|
||||
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
|
||||
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
|
||||
|
45
machine.go
45
machine.go
@ -36,6 +36,7 @@ type Machine struct {
|
||||
LastSeen *time.Time
|
||||
LastSuccessfulUpdate *time.Time
|
||||
Expiry *time.Time
|
||||
RequestedExpiry *time.Time
|
||||
|
||||
HostInfo datatypes.JSON
|
||||
Endpoints datatypes.JSON
|
||||
@ -56,6 +57,38 @@ func (m Machine) isAlreadyRegistered() bool {
|
||||
return m.Registered
|
||||
}
|
||||
|
||||
// isExpired returns whether the machine registration has expired
|
||||
func (m Machine) isExpired() bool {
|
||||
return time.Now().UTC().After(*m.Expiry)
|
||||
}
|
||||
|
||||
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
||||
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
||||
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
||||
// expiry time.
|
||||
func (h *Headscale) updateMachineExpiry(m *Machine) {
|
||||
if m.isExpired() {
|
||||
now := time.Now().UTC()
|
||||
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
|
||||
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
|
||||
|
||||
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
||||
if maxExpiry.Before(*m.RequestedExpiry) {
|
||||
log.Debug().
|
||||
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
||||
m.Expiry = &maxExpiry
|
||||
} else if m.RequestedExpiry.IsZero() {
|
||||
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
||||
m.Expiry = &defaultExpiry
|
||||
} else {
|
||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
|
||||
m.Expiry = m.RequestedExpiry
|
||||
}
|
||||
|
||||
h.db.Save(&m)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Str("func", "getDirectPeers").
|
||||
@ -326,7 +359,11 @@ func (ms MachinesP) String() string {
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
||||
|
||||
func (ms Machines) toNodes(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) ([]*tailcfg.Node, error) {
|
||||
func (ms Machines) toNodes(
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
includeRoutes bool,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
nodes := make([]*tailcfg.Node, len(ms))
|
||||
|
||||
for index, machine := range ms {
|
||||
@ -446,8 +483,10 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
||||
}
|
||||
|
||||
n := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
strconv.FormatUint(m.ID, 10),
|
||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
Name: hostname,
|
||||
User: tailcfg.UserID(m.NamespaceID),
|
||||
Key: tailcfg.NodeKey(nKey),
|
||||
|
@ -246,6 +246,17 @@ func (n *Namespace) toUser() *tailcfg.User {
|
||||
return &u
|
||||
}
|
||||
|
||||
func (n *Namespace) toLogin() *tailcfg.Login {
|
||||
l := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
}
|
||||
return &l
|
||||
}
|
||||
|
||||
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
|
||||
namespaceMap := make(map[string]Namespace)
|
||||
namespaceMap[m.Namespace.Name] = m.Namespace
|
||||
|
228
oidc.go
Normal file
228
oidc.go
Normal file
@ -0,0 +1,228 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
func (h *Headscale) initOIDC() error {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
if h.oauth2Config == nil {
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDC.ClientID,
|
||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
// init the state cache if it hasn't been already
|
||||
if h.oidcStateCache == nil {
|
||||
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey
|
||||
func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||
mKeyStr := c.Param("mkey")
|
||||
if mKeyStr == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
return
|
||||
}
|
||||
|
||||
b := make([]byte, 16)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
return
|
||||
}
|
||||
|
||||
stateStr := hex.EncodeToString(b)[:32]
|
||||
|
||||
// place the machine key into the state cache, so it can be retrieved later
|
||||
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
||||
|
||||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||
|
||||
c.Redirect(http.StatusFound, authUrl)
|
||||
}
|
||||
|
||||
// OIDCCallback handles the callback from the OIDC endpoint
|
||||
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
|
||||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||
// Listens in /oidc/callback
|
||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken)
|
||||
|
||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||
if !rawIDTokenOK {
|
||||
c.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||
return
|
||||
}
|
||||
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
|
||||
//userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
|
||||
//if err != nil {
|
||||
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
|
||||
// return
|
||||
//}
|
||||
|
||||
// Extract custom claims
|
||||
var claims IDTokenClaims
|
||||
if err = idToken.Claims(&claims); err != nil {
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve machinekey from state cache
|
||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
log.Error().Msg("requested machine state key expired before authorisation completed")
|
||||
c.String(http.StatusBadRequest, "state has expired")
|
||||
return
|
||||
}
|
||||
mKeyStr, mKeyOK := mKeyIf.(string)
|
||||
|
||||
if !mKeyOK {
|
||||
log.Error().Msg("could not get machine key from cache")
|
||||
c.String(http.StatusInternalServerError, "could not get machine key from cache")
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve machine information
|
||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
return
|
||||
}
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(m)
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
<p>
|
||||
Authenticated as %s, you can now close this window.
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
`, claims.Email)))
|
||||
|
||||
}
|
||||
|
||||
log.Error().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", m.Name).
|
||||
Msg("Email could not be mapped to a namespace")
|
||||
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
||||
}
|
||||
|
||||
// getNamespaceFromEmail passes the users email through a list of "matchers"
|
||||
// and iterates through them until it matches and returns a namespace.
|
||||
// If no match is found, an empty string will be returned.
|
||||
// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration.
|
||||
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
|
||||
for match, namespace := range h.cfg.OIDC.MatchMap {
|
||||
regex := regexp.MustCompile(match)
|
||||
if regex.MatchString(email) {
|
||||
return namespace, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
174
oidc_test.go
Normal file
174
oidc_test.go
Normal file
@ -0,0 +1,174 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||
type fields struct {
|
||||
cfg Config
|
||||
db *gorm.DB
|
||||
dbString string
|
||||
dbType string
|
||||
dbDebug bool
|
||||
publicKey *wgkey.Key
|
||||
privateKey *wgkey.Private
|
||||
aclPolicy *ACLPolicy
|
||||
aclRules *[]tailcfg.FilterRule
|
||||
lastStateChange sync.Map
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
oidcStateCache *cache.Cache
|
||||
}
|
||||
type args struct {
|
||||
email string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "match all",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*": "space",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@example.no",
|
||||
},
|
||||
want: "space",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "match user",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
"specific@user\\.no": "user-namespace",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "specific@user.no",
|
||||
},
|
||||
want: "user-namespace",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@example\\.no": "example",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@example.no",
|
||||
},
|
||||
want: "example",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "multi match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@example\\.no": "exammple",
|
||||
".*@gmail\\.com": "gmail",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "someuser@gmail.com",
|
||||
},
|
||||
want: "gmail",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "no match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@dontknow.no": "never",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@wedontknow.no",
|
||||
},
|
||||
want: "",
|
||||
want1: false,
|
||||
},
|
||||
{
|
||||
name: "multi no match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@dontknow.no": "never",
|
||||
".*@wedontknow.no": "other",
|
||||
".*\\.no": "stuffy",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "tasy@nonofthem.com",
|
||||
},
|
||||
want: "",
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
//nolint
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Headscale{
|
||||
cfg: tt.fields.cfg,
|
||||
db: tt.fields.db,
|
||||
dbString: tt.fields.dbString,
|
||||
dbType: tt.fields.dbType,
|
||||
dbDebug: tt.fields.dbDebug,
|
||||
publicKey: tt.fields.publicKey,
|
||||
privateKey: tt.fields.privateKey,
|
||||
aclPolicy: tt.fields.aclPolicy,
|
||||
aclRules: tt.fields.aclRules,
|
||||
lastStateChange: tt.fields.lastStateChange,
|
||||
oidcProvider: tt.fields.oidcProvider,
|
||||
oauth2Config: tt.fields.oauth2Config,
|
||||
oidcStateCache: tt.fields.oidcStateCache,
|
||||
}
|
||||
got, got1 := h.getNamespaceFromEmail(tt.args.email)
|
||||
if got != tt.want {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user