mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Clean up logging and error handling in oidc
We should never expose errors via web, it gives attackers a lot of info (Insert OWASP guide). Also handle error that didnt separate not found gorm issue and other errors.
This commit is contained in:
		
							parent
							
								
									fac33e46e1
								
							
						
					
					
						commit
						fcd4d94927
					
				
							
								
								
									
										62
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								oidc.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| @ -15,6 +16,7 @@ import ( | |||||||
| 	"github.com/patrickmn/go-cache" | 	"github.com/patrickmn/go-cache" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
|  | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error { | |||||||
| 		h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) | 		h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) | ||||||
| 
 | 
 | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) | 			log.Error(). | ||||||
|  | 				Err(err). | ||||||
|  | 				Caller(). | ||||||
|  | 				Msgf("Could not retrieve OIDC Config: %s", err.Error()) | ||||||
| 
 | 
 | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error { | |||||||
| // Puts machine key in cache so the callback can retrieve it using the oidc state param
 | // Puts machine key in cache so the callback can retrieve it using the oidc state param
 | ||||||
| // Listens in /oidc/register/:mKey.
 | // Listens in /oidc/register/:mKey.
 | ||||||
| func (h *Headscale) RegisterOIDC(ctx *gin.Context) { | func (h *Headscale) RegisterOIDC(ctx *gin.Context) { | ||||||
| 	mKeyStr := ctx.Param("mkey") | 	machineKeyStr := ctx.Param("mkey") | ||||||
| 	if mKeyStr == "" { | 	if machineKeyStr == "" { | ||||||
| 		ctx.String(http.StatusBadRequest, "Wrong params") | 		ctx.String(http.StatusBadRequest, "Wrong params") | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| @ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { | |||||||
| 
 | 
 | ||||||
| 	randomBlob := make([]byte, randomByteSize) | 	randomBlob := make([]byte, randomByteSize) | ||||||
| 	if _, err := rand.Read(randomBlob); err != nil { | 	if _, err := rand.Read(randomBlob); err != nil { | ||||||
| 		log.Error().Msg("could not read 16 bytes from rand") | 		log.Error(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Msg("could not read 16 bytes from rand") | ||||||
| 		ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") | 		ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| @ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { | |||||||
| 	stateStr := hex.EncodeToString(randomBlob)[:32] | 	stateStr := hex.EncodeToString(randomBlob)[:32] | ||||||
| 
 | 
 | ||||||
| 	// place the machine key into the state cache, so it can be retrieved later
 | 	// place the machine key into the state cache, so it can be retrieved later
 | ||||||
| 	h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration) | 	h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) | ||||||
| 
 | 
 | ||||||
| 	authURL := h.oauth2Config.AuthCodeURL(stateStr) | 	authURL := h.oauth2Config.AuthCodeURL(stateStr) | ||||||
| 	log.Debug().Msgf("Redirecting to %s for authentication", authURL) | 	log.Debug().Msgf("Redirecting to %s for authentication", authURL) | ||||||
| @ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 
 | 
 | ||||||
| 	idToken, err := verifier.Verify(context.Background(), rawIDToken) | 	idToken, err := verifier.Verify(context.Background(), rawIDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) | 		log.Error(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Caller(). | ||||||
|  | 			Msg("failed to verify id token") | ||||||
|  | 		ctx.String(http.StatusBadRequest, "Failed to verify id token") | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 	// Extract custom claims
 | 	// Extract custom claims
 | ||||||
| 	var claims IDTokenClaims | 	var claims IDTokenClaims | ||||||
| 	if err = idToken.Claims(&claims); err != nil { | 	if err = idToken.Claims(&claims); err != nil { | ||||||
|  | 		log.Error(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Caller(). | ||||||
|  | 			Msg("Failed to decode id token claims") | ||||||
| 		ctx.String( | 		ctx.String( | ||||||
| 			http.StatusBadRequest, | 			http.StatusBadRequest, | ||||||
| 			fmt.Sprintf("Failed to decode id token claims: %s", err), | 			fmt.Sprintf("Failed to decode id token claims"), | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// retrieve machinekey from state cache
 | 	// retrieve machinekey from state cache
 | ||||||
| 	mKeyIf, mKeyFound := h.oidcStateCache.Get(state) | 	machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) | ||||||
| 
 | 
 | ||||||
| 	if !mKeyFound { | 	if !machineKeyFound { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Msg("requested machine state key expired before authorisation completed") | 			Msg("requested machine state key expired before authorisation completed") | ||||||
| 		ctx.String(http.StatusBadRequest, "state has expired") | 		ctx.String(http.StatusBadRequest, "state has expired") | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	mKeyStr, mKeyOK := mKeyIf.(string) | 	machineKey, machineKeyOK := machineKeyIf.(string) | ||||||
| 
 | 
 | ||||||
| 	if !mKeyOK { | 	if !machineKeyOK { | ||||||
| 		log.Error().Msg("could not get machine key from cache") | 		log.Error().Msg("could not get machine key from cache") | ||||||
| 		ctx.String( | 		ctx.String( | ||||||
| 			http.StatusInternalServerError, | 			http.StatusInternalServerError, | ||||||
| @ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// retrieve machine information
 | 	// retrieve machine information
 | ||||||
| 	machine, err := h.GetMachineByMachineKey(mKeyStr) | 	machine, err := h.GetMachineByMachineKey(machineKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Msg("machine key not found in database") | 		log.Error().Msg("machine key not found in database") | ||||||
| 		ctx.String( | 		ctx.String( | ||||||
| @ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 			log.Debug().Msg("Registering new machine after successful callback") | 			log.Debug().Msg("Registering new machine after successful callback") | ||||||
| 
 | 
 | ||||||
| 			namespace, err := h.GetNamespace(namespaceName) | 			namespace, err := h.GetNamespace(namespaceName) | ||||||
| 			if err != nil { | 			if errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 				namespace, err = h.CreateNamespace(namespaceName) | 				namespace, err = h.CreateNamespace(namespaceName) | ||||||
| 
 | 
 | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Error(). | 					log.Error(). | ||||||
| 						Msgf("could not create new namespace '%s'", claims.Email) | 						Err(err). | ||||||
|  | 						Caller(). | ||||||
|  | 						Msgf("could not create new namespace '%s'", namespaceName) | ||||||
| 					ctx.String( | 					ctx.String( | ||||||
| 						http.StatusInternalServerError, | 						http.StatusInternalServerError, | ||||||
| 						"could not create new namespace", | 						"could not create new namespace", | ||||||
| @ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 
 | 
 | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
|  | 			} else if err != nil { | ||||||
|  | 				log.Error(). | ||||||
|  | 					Caller(). | ||||||
|  | 					Err(err). | ||||||
|  | 					Str("namespace", namespaceName). | ||||||
|  | 					Msg("could not find or create namespace") | ||||||
|  | 				ctx.String( | ||||||
|  | 					http.StatusInternalServerError, | ||||||
|  | 					"could not find or create namespace", | ||||||
|  | 				) | ||||||
|  | 
 | ||||||
|  | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			ip, err := h.getAvailableIP() | 			ip, err := h.getAvailableIP() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
|  | 				log.Error(). | ||||||
|  | 					Caller(). | ||||||
|  | 					Err(err). | ||||||
|  | 					Msg("could not get an IP from the pool") | ||||||
| 				ctx.String( | 				ctx.String( | ||||||
| 					http.StatusInternalServerError, | 					http.StatusInternalServerError, | ||||||
| 					"could not get an IP from the pool", | 					"could not get an IP from the pool", | ||||||
| @ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Error(). | 	log.Error(). | ||||||
|  | 		Caller(). | ||||||
| 		Str("email", claims.Email). | 		Str("email", claims.Email). | ||||||
| 		Str("username", claims.Username). | 		Str("username", claims.Username). | ||||||
| 		Str("machine", machine.Name). | 		Str("machine", machine.Name). | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user