mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	oidc: add test for expiring nodes after token expiration
This commit is contained in:
		
							parent
							
								
									085912cfb4
								
							
						
					
					
						commit
						23a595c26f
					
				@ -16,10 +16,11 @@ const (
 | 
				
			|||||||
	errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
 | 
						errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
 | 
				
			||||||
	errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
 | 
						errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
 | 
				
			||||||
	errMockOidcPortNotDefined         = Error("MOCKOIDC_PORT not defined")
 | 
						errMockOidcPortNotDefined         = Error("MOCKOIDC_PORT not defined")
 | 
				
			||||||
	accessTTL                         = 10 * time.Minute
 | 
					 | 
				
			||||||
	refreshTTL                        = 60 * time.Minute
 | 
						refreshTTL                        = 60 * time.Minute
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var accessTTL = 2 * time.Minute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	rootCmd.AddCommand(mockOidcCmd)
 | 
						rootCmd.AddCommand(mockOidcCmd)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -54,6 +55,16 @@ func mockOIDC() error {
 | 
				
			|||||||
	if portStr == "" {
 | 
						if portStr == "" {
 | 
				
			||||||
		return errMockOidcPortNotDefined
 | 
							return errMockOidcPortNotDefined
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
 | 
				
			||||||
 | 
						if accessTTLOverride != "" {
 | 
				
			||||||
 | 
							newTTL, err := time.ParseDuration(accessTTLOverride)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							accessTTL = newTTL
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Info().Msgf("Access token TTL: %s", accessTTL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	port, err := strconv.Atoi(portStr)
 | 
						port, err := strconv.Atoi(portStr)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -9,8 +9,10 @@ import (
 | 
				
			|||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/netip"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/juanfont/headscale"
 | 
						"github.com/juanfont/headscale"
 | 
				
			||||||
	"github.com/juanfont/headscale/integration/dockertestutil"
 | 
						"github.com/juanfont/headscale/integration/dockertestutil"
 | 
				
			||||||
@ -22,7 +24,7 @@ import (
 | 
				
			|||||||
const (
 | 
					const (
 | 
				
			||||||
	dockerContextPath      = "../."
 | 
						dockerContextPath      = "../."
 | 
				
			||||||
	hsicOIDCMockHashLength = 6
 | 
						hsicOIDCMockHashLength = 6
 | 
				
			||||||
	oidcServerPort         = 10000
 | 
						defaultAccessTTL       = 10 * time.Minute
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var errStatusCodeNotOK = errors.New("status code not OK")
 | 
					var errStatusCodeNotOK = errors.New("status code not OK")
 | 
				
			||||||
@ -50,7 +52,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
				
			|||||||
		"namespace1": len(TailscaleVersions),
 | 
							"namespace1": len(TailscaleVersions),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oidcConfig, err := scenario.runMockOIDC()
 | 
						oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Errorf("failed to run mock OIDC server: %s", err)
 | 
							t.Errorf("failed to run mock OIDC server: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -87,20 +89,76 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
 | 
							t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	success := 0
 | 
						success := pingAll(t, allClients, allIps)
 | 
				
			||||||
 | 
						t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, client := range allClients {
 | 
						err = scenario.Shutdown()
 | 
				
			||||||
		for _, ip := range allIps {
 | 
						if err != nil {
 | 
				
			||||||
			err := client.Ping(ip.String())
 | 
							t.Errorf("failed to tear down scenario: %s", err)
 | 
				
			||||||
			if err != nil {
 | 
						}
 | 
				
			||||||
				t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err)
 | 
					}
 | 
				
			||||||
			} else {
 | 
					
 | 
				
			||||||
				success++
 | 
					func TestOIDCExpireNodes(t *testing.T) {
 | 
				
			||||||
			}
 | 
						IntegrationSkip(t)
 | 
				
			||||||
		}
 | 
						t.Parallel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						shortAccessTTL := 5 * time.Minute
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						baseScenario, err := NewScenario()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to create scenario: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
 | 
						scenario := AuthOIDCScenario{
 | 
				
			||||||
 | 
							Scenario: baseScenario,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						spec := map[string]int{
 | 
				
			||||||
 | 
							"namespace1": len(TailscaleVersions),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oidcConfig, err := scenario.runMockOIDC(shortAccessTTL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("failed to run mock OIDC server: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oidcMap := map[string]string{
 | 
				
			||||||
 | 
							"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer,
 | 
				
			||||||
 | 
							"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID,
 | 
				
			||||||
 | 
							"HEADSCALE_OIDC_CLIENT_SECRET":      oidcConfig.ClientSecret,
 | 
				
			||||||
 | 
							"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = scenario.CreateHeadscaleEnv(
 | 
				
			||||||
 | 
							spec,
 | 
				
			||||||
 | 
							hsic.WithTestName("oidcexpirenodes"),
 | 
				
			||||||
 | 
							hsic.WithConfigEnv(oidcMap),
 | 
				
			||||||
 | 
							hsic.WithHostnameAsServerURL(),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to create headscale environment: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						allClients, err := scenario.ListTailscaleClients()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to get clients: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						allIps, err := scenario.ListTailscaleClientsIPs()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed to get clients: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = scenario.WaitForTailscaleSync()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						success := pingAll(t, allClients, allIps)
 | 
				
			||||||
 | 
						t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// await all nodes being logged out after OIDC token expiry
 | 
				
			||||||
 | 
						scenario.WaitForTailscaleLogout()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = scenario.Shutdown()
 | 
						err = scenario.Shutdown()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -143,7 +201,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
					func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*headscale.OIDCConfig, error) {
 | 
				
			||||||
 | 
						port, err := dockertestutil.RandomFreeHostPort()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Fatalf("could not find an open port: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						portNotation := fmt.Sprintf("%d/tcp", port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hash, _ := headscale.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
 | 
						hash, _ := headscale.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
 | 
						hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
 | 
				
			||||||
@ -151,16 +215,17 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
				
			|||||||
	mockOidcOptions := &dockertest.RunOptions{
 | 
						mockOidcOptions := &dockertest.RunOptions{
 | 
				
			||||||
		Name:         hostname,
 | 
							Name:         hostname,
 | 
				
			||||||
		Cmd:          []string{"headscale", "mockoidc"},
 | 
							Cmd:          []string{"headscale", "mockoidc"},
 | 
				
			||||||
		ExposedPorts: []string{"10000/tcp"},
 | 
							ExposedPorts: []string{portNotation},
 | 
				
			||||||
		PortBindings: map[docker.Port][]docker.PortBinding{
 | 
							PortBindings: map[docker.Port][]docker.PortBinding{
 | 
				
			||||||
			"10000/tcp": {{HostPort: "10000"}},
 | 
								docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Networks: []*dockertest.Network{s.Scenario.network},
 | 
							Networks: []*dockertest.Network{s.Scenario.network},
 | 
				
			||||||
		Env: []string{
 | 
							Env: []string{
 | 
				
			||||||
			fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
 | 
								fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
 | 
				
			||||||
			"MOCKOIDC_PORT=10000",
 | 
								fmt.Sprintf("MOCKOIDC_PORT=%d", port),
 | 
				
			||||||
			"MOCKOIDC_CLIENT_ID=superclient",
 | 
								"MOCKOIDC_CLIENT_ID=superclient",
 | 
				
			||||||
			"MOCKOIDC_CLIENT_SECRET=supersecret",
 | 
								"MOCKOIDC_CLIENT_SECRET=supersecret",
 | 
				
			||||||
 | 
								fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -169,7 +234,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
				
			|||||||
		ContextDir: dockerContextPath,
 | 
							ContextDir: dockerContextPath,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := s.pool.RemoveContainerByName(hostname)
 | 
						err = s.pool.RemoveContainerByName(hostname)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -184,11 +249,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Println("Waiting for headscale mock oidc to be ready for tests")
 | 
						log.Println("Waiting for headscale mock oidc to be ready for tests")
 | 
				
			||||||
	hostEndpoint := fmt.Sprintf(
 | 
						hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port)
 | 
				
			||||||
		"%s:%s",
 | 
					 | 
				
			||||||
		s.mockOIDC.GetIPInNetwork(s.network),
 | 
					 | 
				
			||||||
		s.mockOIDC.GetPort(fmt.Sprintf("%d/tcp", oidcServerPort)),
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := s.pool.Retry(func() error {
 | 
						if err := s.pool.Retry(func() error {
 | 
				
			||||||
		oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
 | 
							oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
 | 
				
			||||||
@ -215,11 +276,11 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
				
			|||||||
	log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
 | 
						log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &headscale.OIDCConfig{
 | 
						return &headscale.OIDCConfig{
 | 
				
			||||||
		Issuer: fmt.Sprintf("http://%s/oidc",
 | 
							Issuer:                     fmt.Sprintf("http://%s/oidc", net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port))),
 | 
				
			||||||
			net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(oidcServerPort))),
 | 
							ClientID:                   "superclient",
 | 
				
			||||||
		ClientID:         "superclient",
 | 
							ClientSecret:               "supersecret",
 | 
				
			||||||
		ClientSecret:     "supersecret",
 | 
							StripEmaildomain:           true,
 | 
				
			||||||
		StripEmaildomain: true,
 | 
							OnlyStartIfOIDCIsAvailable: true,
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -292,6 +353,24 @@ func (s *AuthOIDCScenario) runTailscaleUp(
 | 
				
			|||||||
	return fmt.Errorf("failed to up tailscale node: %w", errNoNamespaceAvailable)
 | 
						return fmt.Errorf("failed to up tailscale node: %w", errNoNamespaceAvailable)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func pingAll(t *testing.T, clients []TailscaleClient, ips []netip.Addr) int {
 | 
				
			||||||
 | 
						t.Helper()
 | 
				
			||||||
 | 
						success := 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, client := range clients {
 | 
				
			||||||
 | 
							for _, ip := range ips {
 | 
				
			||||||
 | 
								err := client.Ping(ip.String())
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									success++
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return success
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *AuthOIDCScenario) Shutdown() error {
 | 
					func (s *AuthOIDCScenario) Shutdown() error {
 | 
				
			||||||
	err := s.pool.Purge(s.mockOIDC)
 | 
						err := s.pool.Purge(s.mockOIDC)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package dockertestutil
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/ory/dockertest/v3"
 | 
						"github.com/ory/dockertest/v3"
 | 
				
			||||||
	"github.com/ory/dockertest/v3/docker"
 | 
						"github.com/ory/dockertest/v3/docker"
 | 
				
			||||||
@ -60,3 +61,20 @@ func AddContainerToNetwork(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RandomFreeHostPort asks the kernel for a free open port that is ready to use.
 | 
				
			||||||
 | 
					// (from https://github.com/phayes/freeport)
 | 
				
			||||||
 | 
					func RandomFreeHostPort() (int, error) {
 | 
				
			||||||
 | 
						addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						listener, err := net.ListenTCP("tcp", addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer listener.Close()
 | 
				
			||||||
 | 
						//nolint:forcetypeassert
 | 
				
			||||||
 | 
						return listener.Addr().(*net.TCPAddr).Port, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user