mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-23 13:50:36 +01:00
Go style recommends that log messages and error strings not be capitalized (unless they begin with a proper noun or acronym) and not end with punctuation. This change normalizes all zerolog .Msg() and .Msgf() calls to start with a lowercase letter, following Go conventions and making logs more consistent across the codebase.
150 lines
3.5 KiB
Go
150 lines
3.5 KiB
Go
package cli
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
|
"github.com/oauth2-proxy/mockoidc"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
const (
|
|
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
|
|
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
|
|
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
|
|
refreshTTL = 60 * time.Minute
|
|
)
|
|
|
|
var accessTTL = 2 * time.Minute
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(mockOidcCmd)
|
|
}
|
|
|
|
var mockOidcCmd = &cobra.Command{
|
|
Use: "mockoidc",
|
|
Short: "Runs a mock OIDC server for testing",
|
|
Long: "This internal command runs a OpenID Connect for testing purposes",
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
err := mockOIDC()
|
|
if err != nil {
|
|
log.Error().Err(err).Msgf("error running mock OIDC server")
|
|
os.Exit(1)
|
|
}
|
|
},
|
|
}
|
|
|
|
func mockOIDC() error {
|
|
clientID := os.Getenv("MOCKOIDC_CLIENT_ID")
|
|
if clientID == "" {
|
|
return errMockOidcClientIDNotDefined
|
|
}
|
|
clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
|
|
if clientSecret == "" {
|
|
return errMockOidcClientSecretNotDefined
|
|
}
|
|
addrStr := os.Getenv("MOCKOIDC_ADDR")
|
|
if addrStr == "" {
|
|
return errMockOidcPortNotDefined
|
|
}
|
|
portStr := os.Getenv("MOCKOIDC_PORT")
|
|
if portStr == "" {
|
|
return errMockOidcPortNotDefined
|
|
}
|
|
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
|
|
if accessTTLOverride != "" {
|
|
newTTL, err := time.ParseDuration(accessTTLOverride)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
accessTTL = newTTL
|
|
}
|
|
|
|
userStr := os.Getenv("MOCKOIDC_USERS")
|
|
if userStr == "" {
|
|
return errors.New("MOCKOIDC_USERS not defined")
|
|
}
|
|
|
|
var users []mockoidc.MockUser
|
|
err := json.Unmarshal([]byte(userStr), &users)
|
|
if err != nil {
|
|
return fmt.Errorf("unmarshalling users: %w", err)
|
|
}
|
|
|
|
log.Info().Interface(zf.Users, users).Msg("loading users from JSON")
|
|
|
|
log.Info().Msgf("access token TTL: %s", accessTTL)
|
|
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
mock, err := getMockOIDC(clientID, clientSecret, users)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = mock.Start(listener, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String())
|
|
log.Info().Msgf("issuer: %s", mock.Issuer())
|
|
c := make(chan struct{})
|
|
<-c
|
|
|
|
return nil
|
|
}
|
|
|
|
func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) {
|
|
keypair, err := mockoidc.NewKeypair(nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userQueue := mockoidc.UserQueue{}
|
|
|
|
for _, user := range users {
|
|
userQueue.Push(&user)
|
|
}
|
|
|
|
mock := mockoidc.MockOIDC{
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
AccessTTL: accessTTL,
|
|
RefreshTTL: refreshTTL,
|
|
CodeChallengeMethodsSupported: []string{"plain", "S256"},
|
|
Keypair: keypair,
|
|
SessionStore: mockoidc.NewSessionStore(),
|
|
UserQueue: &userQueue,
|
|
ErrorQueue: &mockoidc.ErrorQueue{},
|
|
}
|
|
|
|
mock.AddMiddleware(func(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
log.Info().Msgf("request: %+v", r)
|
|
h.ServeHTTP(w, r)
|
|
if r.Response != nil {
|
|
log.Info().Msgf("response: %+v", r.Response)
|
|
}
|
|
})
|
|
})
|
|
|
|
return &mock, nil
|
|
}
|