mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	oidc: add test for expiring nodes after token expiration
This commit is contained in:
		
							parent
							
								
									085912cfb4
								
							
						
					
					
						commit
						23a595c26f
					
				@ -16,10 +16,11 @@ const (
 | 
			
		||||
	errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
 | 
			
		||||
	errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
 | 
			
		||||
	errMockOidcPortNotDefined         = Error("MOCKOIDC_PORT not defined")
 | 
			
		||||
	accessTTL                         = 10 * time.Minute
 | 
			
		||||
	refreshTTL                        = 60 * time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var accessTTL = 2 * time.Minute
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	rootCmd.AddCommand(mockOidcCmd)
 | 
			
		||||
}
 | 
			
		||||
@ -54,6 +55,16 @@ func mockOIDC() error {
 | 
			
		||||
	if portStr == "" {
 | 
			
		||||
		return errMockOidcPortNotDefined
 | 
			
		||||
	}
 | 
			
		||||
	accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
 | 
			
		||||
	if accessTTLOverride != "" {
 | 
			
		||||
		newTTL, err := time.ParseDuration(accessTTLOverride)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		accessTTL = newTTL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Info().Msgf("Access token TTL: %s", accessTTL)
 | 
			
		||||
 | 
			
		||||
	port, err := strconv.Atoi(portStr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,10 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/juanfont/headscale"
 | 
			
		||||
	"github.com/juanfont/headscale/integration/dockertestutil"
 | 
			
		||||
@ -22,7 +24,7 @@ import (
 | 
			
		||||
const (
 | 
			
		||||
	dockerContextPath      = "../."
 | 
			
		||||
	hsicOIDCMockHashLength = 6
 | 
			
		||||
	oidcServerPort         = 10000
 | 
			
		||||
	defaultAccessTTL       = 10 * time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var errStatusCodeNotOK = errors.New("status code not OK")
 | 
			
		||||
@ -50,7 +52,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
			
		||||
		"namespace1": len(TailscaleVersions),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oidcConfig, err := scenario.runMockOIDC()
 | 
			
		||||
	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to run mock OIDC server: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
@ -87,19 +89,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
			
		||||
		t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	success := 0
 | 
			
		||||
 | 
			
		||||
	for _, client := range allClients {
 | 
			
		||||
		for _, ip := range allIps {
 | 
			
		||||
			err := client.Ping(ip.String())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err)
 | 
			
		||||
			} else {
 | 
			
		||||
				success++
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	success := pingAll(t, allClients, allIps)
 | 
			
		||||
	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
 | 
			
		||||
 | 
			
		||||
	err = scenario.Shutdown()
 | 
			
		||||
@ -108,6 +98,74 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOIDCExpireNodes(t *testing.T) {
 | 
			
		||||
	IntegrationSkip(t)
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
 | 
			
		||||
	shortAccessTTL := 5 * time.Minute
 | 
			
		||||
 | 
			
		||||
	baseScenario, err := NewScenario()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to create scenario: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scenario := AuthOIDCScenario{
 | 
			
		||||
		Scenario: baseScenario,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	spec := map[string]int{
 | 
			
		||||
		"namespace1": len(TailscaleVersions),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to run mock OIDC server: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oidcMap := map[string]string{
 | 
			
		||||
		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer,
 | 
			
		||||
		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID,
 | 
			
		||||
		"HEADSCALE_OIDC_CLIENT_SECRET":      oidcConfig.ClientSecret,
 | 
			
		||||
		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = scenario.CreateHeadscaleEnv(
 | 
			
		||||
		spec,
 | 
			
		||||
		hsic.WithTestName("oidcexpirenodes"),
 | 
			
		||||
		hsic.WithConfigEnv(oidcMap),
 | 
			
		||||
		hsic.WithHostnameAsServerURL(),
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to create headscale environment: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	allClients, err := scenario.ListTailscaleClients()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to get clients: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	allIps, err := scenario.ListTailscaleClientsIPs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to get clients: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = scenario.WaitForTailscaleSync()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	success := pingAll(t, allClients, allIps)
 | 
			
		||||
	t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps))
 | 
			
		||||
 | 
			
		||||
	// await all nodes being logged out after OIDC token expiry
 | 
			
		||||
	scenario.WaitForTailscaleLogout()
 | 
			
		||||
 | 
			
		||||
	err = scenario.Shutdown()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to tear down scenario: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
 | 
			
		||||
	namespaces map[string]int,
 | 
			
		||||
	opts ...hsic.Option,
 | 
			
		||||
@ -143,7 +201,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
			
		||||
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*headscale.OIDCConfig, error) {
 | 
			
		||||
	port, err := dockertestutil.RandomFreeHostPort()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("could not find an open port: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	portNotation := fmt.Sprintf("%d/tcp", port)
 | 
			
		||||
 | 
			
		||||
	hash, _ := headscale.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
 | 
			
		||||
 | 
			
		||||
	hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
 | 
			
		||||
@ -151,16 +215,17 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
			
		||||
	mockOidcOptions := &dockertest.RunOptions{
 | 
			
		||||
		Name:         hostname,
 | 
			
		||||
		Cmd:          []string{"headscale", "mockoidc"},
 | 
			
		||||
		ExposedPorts: []string{"10000/tcp"},
 | 
			
		||||
		ExposedPorts: []string{portNotation},
 | 
			
		||||
		PortBindings: map[docker.Port][]docker.PortBinding{
 | 
			
		||||
			"10000/tcp": {{HostPort: "10000"}},
 | 
			
		||||
			docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}},
 | 
			
		||||
		},
 | 
			
		||||
		Networks: []*dockertest.Network{s.Scenario.network},
 | 
			
		||||
		Env: []string{
 | 
			
		||||
			fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
 | 
			
		||||
			"MOCKOIDC_PORT=10000",
 | 
			
		||||
			fmt.Sprintf("MOCKOIDC_PORT=%d", port),
 | 
			
		||||
			"MOCKOIDC_CLIENT_ID=superclient",
 | 
			
		||||
			"MOCKOIDC_CLIENT_SECRET=supersecret",
 | 
			
		||||
			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -169,7 +234,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
			
		||||
		ContextDir: dockerContextPath,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := s.pool.RemoveContainerByName(hostname)
 | 
			
		||||
	err = s.pool.RemoveContainerByName(hostname)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@ -184,11 +249,7 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Println("Waiting for headscale mock oidc to be ready for tests")
 | 
			
		||||
	hostEndpoint := fmt.Sprintf(
 | 
			
		||||
		"%s:%s",
 | 
			
		||||
		s.mockOIDC.GetIPInNetwork(s.network),
 | 
			
		||||
		s.mockOIDC.GetPort(fmt.Sprintf("%d/tcp", oidcServerPort)),
 | 
			
		||||
	)
 | 
			
		||||
	hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port)
 | 
			
		||||
 | 
			
		||||
	if err := s.pool.Retry(func() error {
 | 
			
		||||
		oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint)
 | 
			
		||||
@ -215,11 +276,11 @@ func (s *AuthOIDCScenario) runMockOIDC() (*headscale.OIDCConfig, error) {
 | 
			
		||||
	log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
 | 
			
		||||
 | 
			
		||||
	return &headscale.OIDCConfig{
 | 
			
		||||
		Issuer: fmt.Sprintf("http://%s/oidc",
 | 
			
		||||
			net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(oidcServerPort))),
 | 
			
		||||
		Issuer:                     fmt.Sprintf("http://%s/oidc", net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port))),
 | 
			
		||||
		ClientID:                   "superclient",
 | 
			
		||||
		ClientSecret:               "supersecret",
 | 
			
		||||
		StripEmaildomain:           true,
 | 
			
		||||
		OnlyStartIfOIDCIsAvailable: true,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -292,6 +353,24 @@ func (s *AuthOIDCScenario) runTailscaleUp(
 | 
			
		||||
	return fmt.Errorf("failed to up tailscale node: %w", errNoNamespaceAvailable)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func pingAll(t *testing.T, clients []TailscaleClient, ips []netip.Addr) int {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	success := 0
 | 
			
		||||
 | 
			
		||||
	for _, client := range clients {
 | 
			
		||||
		for _, ip := range ips {
 | 
			
		||||
			err := client.Ping(ip.String())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err)
 | 
			
		||||
			} else {
 | 
			
		||||
				success++
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return success
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AuthOIDCScenario) Shutdown() error {
 | 
			
		||||
	err := s.pool.Purge(s.mockOIDC)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package dockertestutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
 | 
			
		||||
	"github.com/ory/dockertest/v3"
 | 
			
		||||
	"github.com/ory/dockertest/v3/docker"
 | 
			
		||||
@ -60,3 +61,20 @@ func AddContainerToNetwork(
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RandomFreeHostPort asks the kernel for a free open port that is ready to use.
 | 
			
		||||
// (from https://github.com/phayes/freeport)
 | 
			
		||||
func RandomFreeHostPort() (int, error) {
 | 
			
		||||
	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	listener, err := net.ListenTCP("tcp", addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	//nolint:forcetypeassert
 | 
			
		||||
	return listener.Addr().(*net.TCPAddr).Port, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user