1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-08 00:11:42 +01:00

Clean up logging and error handling in oidc

We should never expose errors via web, it gives attackers a lot of info
(Insert OWASP guide).

Also handle error that didnt separate not found gorm issue and other
errors.
This commit is contained in:
Kristoffer Dalby 2021-11-21 21:51:39 +00:00
parent fac33e46e1
commit fcd4d94927
No known key found for this signature in database
GPG Key ID: 09F62DC067465735

62
oidc.go
View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
@ -15,6 +16,7 @@ import (
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm"
) )
const ( const (
@ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil { if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) log.Error().
Err(err).
Caller().
Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err return err
} }
@ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error {
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := ctx.Param("mkey") machineKeyStr := ctx.Param("mkey")
if mKeyStr == "" { if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
@ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().
Caller().
Msg("could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return return
@ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration) h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) log.Error().
Err(err).
Caller().
Msg("failed to verify id token")
ctx.String(http.StatusBadRequest, "Failed to verify id token")
return return
} }
@ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// Extract custom claims // Extract custom claims
var claims IDTokenClaims var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil { if err = idToken.Claims(&claims); err != nil {
log.Error().
Err(err).
Caller().
Msg("Failed to decode id token claims")
ctx.String( ctx.String(
http.StatusBadRequest, http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err), fmt.Sprintf("Failed to decode id token claims"),
) )
return return
} }
// retrieve machinekey from state cache // retrieve machinekey from state cache
mKeyIf, mKeyFound := h.oidcStateCache.Get(state) machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound { if !machineKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
return return
} }
mKeyStr, mKeyOK := mKeyIf.(string) machineKey, machineKeyOK := machineKeyIf.(string)
if !mKeyOK { if !machineKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
@ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
// retrieve machine information // retrieve machine information
machine, err := h.GetMachineByMachineKey(mKeyStr) machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil { if err != nil {
log.Error().Msg("machine key not found in database") log.Error().Msg("machine key not found in database")
ctx.String( ctx.String(
@ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) {
namespace, err = h.CreateNamespace(namespaceName) namespace, err = h.CreateNamespace(namespaceName)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("could not create new namespace '%s'", claims.Email) Err(err).
Caller().
Msgf("could not create new namespace '%s'", namespaceName)
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not create new namespace", "could not create new namespace",
@ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
} else if err != nil {
log.Error().
Caller().
Err(err).
Str("namespace", namespaceName).
Msg("could not find or create namespace")
ctx.String(
http.StatusInternalServerError,
"could not find or create namespace",
)
return
} }
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not get an IP from the pool")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get an IP from the pool", "could not get an IP from the pool",
@ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
log.Error(). log.Error().
Caller().
Str("email", claims.Email). Str("email", claims.Email).
Str("username", claims.Username). Str("username", claims.Username).
Str("machine", machine.Name). Str("machine", machine.Name).