1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-22 00:11:47 +01:00
juanfont.headscale/cmd/headscale/cli/mockoidc.go

147 lines
3.4 KiB
Go
Raw Normal View History

2022-09-20 21:58:36 +02:00
package cli
import (
"encoding/json"
2022-09-20 21:58:36 +02:00
"fmt"
"net"
"net/http"
2022-09-20 21:58:36 +02:00
"os"
"strconv"
"time"
"github.com/oauth2-proxy/mockoidc"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)
2022-09-20 23:02:44 +02:00
// Errors returned by mockOIDC when a required environment variable is unset.
const (
	errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
	errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
	errMockOidcPortNotDefined         = Error("MOCKOIDC_PORT not defined")
)

// refreshTTL is the lifetime given to refresh tokens issued by the mock server.
const refreshTTL = time.Hour

// accessTTL is the lifetime given to access tokens issued by the mock server.
// It is a variable because mockOIDC replaces it with MOCKOIDC_ACCESS_TTL when
// that environment variable is set.
var accessTTL = 2 * time.Minute
2022-09-20 21:58:36 +02:00
// init attaches mockOidcCmd as a subcommand of rootCmd.
func init() { rootCmd.AddCommand(mockOidcCmd) }
var mockOidcCmd = &cobra.Command{
Use: "mockoidc",
Short: "Runs a mock OIDC server for testing",
Long: "This internal command runs a OpenID Connect for testing purposes",
Run: func(cmd *cobra.Command, args []string) {
err := mockOIDC()
if err != nil {
2022-09-20 23:06:43 +02:00
log.Error().Err(err).Msgf("Error running mock OIDC server")
2022-09-20 21:58:36 +02:00
os.Exit(1)
}
},
}
func mockOIDC() error {
clientID := os.Getenv("MOCKOIDC_CLIENT_ID")
if clientID == "" {
2022-09-20 23:02:44 +02:00
return errMockOidcClientIDNotDefined
2022-09-20 21:58:36 +02:00
}
clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
if clientSecret == "" {
2022-09-20 23:02:44 +02:00
return errMockOidcClientSecretNotDefined
2022-09-20 21:58:36 +02:00
}
addrStr := os.Getenv("MOCKOIDC_ADDR")
if addrStr == "" {
return errMockOidcPortNotDefined
}
2022-09-20 21:58:36 +02:00
portStr := os.Getenv("MOCKOIDC_PORT")
if portStr == "" {
2022-09-20 23:02:44 +02:00
return errMockOidcPortNotDefined
2022-09-20 21:58:36 +02:00
}
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
if accessTTLOverride != "" {
newTTL, err := time.ParseDuration(accessTTLOverride)
if err != nil {
return err
}
accessTTL = newTTL
}
userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return fmt.Errorf("MOCKOIDC_USERS not defined")
}
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
if err != nil {
return fmt.Errorf("unmarshalling users: %w", err)
}
log.Info().Interface("users", users).Msg("loading users from JSON")
log.Info().Msgf("Access token TTL: %s", accessTTL)
2022-09-20 21:58:36 +02:00
port, err := strconv.Atoi(portStr)
if err != nil {
return err
}
mock, err := getMockOIDC(clientID, clientSecret, users)
2022-09-20 21:58:36 +02:00
if err != nil {
return err
}
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
2022-09-20 21:58:36 +02:00
if err != nil {
return err
}
2022-09-20 23:02:44 +02:00
err = mock.Start(listener, nil)
if err != nil {
return err
}
log.Info().Msgf("Mock OIDC server listening on %s", listener.Addr().String())
2022-09-20 21:58:36 +02:00
log.Info().Msgf("Issuer: %s", mock.Issuer())
c := make(chan struct{})
<-c
return nil
}
func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) {
2022-09-20 21:58:36 +02:00
keypair, err := mockoidc.NewKeypair(nil)
if err != nil {
return nil, err
}
userQueue := mockoidc.UserQueue{}
for _, user := range users {
userQueue.Push(&user)
}
2022-09-20 21:58:36 +02:00
mock := mockoidc.MockOIDC{
ClientID: clientID,
ClientSecret: clientSecret,
2022-09-20 23:02:44 +02:00
AccessTTL: accessTTL,
RefreshTTL: refreshTTL,
2022-09-20 21:58:36 +02:00
CodeChallengeMethodsSupported: []string{"plain", "S256"},
Keypair: keypair,
SessionStore: mockoidc.NewSessionStore(),
UserQueue: &userQueue,
2022-09-20 21:58:36 +02:00
ErrorQueue: &mockoidc.ErrorQueue{},
}
mock.AddMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Info().Msgf("Request: %+v", r)
h.ServeHTTP(w, r)
if r.Response != nil {
log.Info().Msgf("Response: %+v", r.Response)
}
})
})
2022-09-20 21:58:36 +02:00
return &mock, nil
}