mirror of https://github.com/juanfont/headscale.git synced 2025-08-14 13:51:01 +02:00

integration: Eventually, debug output, lint and format

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby <kristoffer@tailscale.com> 2025-08-06 08:37:02 +02:00
commit da641c42d8 (parent f6c1348835)
No known key found for this signature in database
14 changed files with 1675 additions and 582 deletions

View File

@@ -5,6 +5,7 @@ import (
     "net/netip"
     "strings"
     "testing"
+    "time"
 
     "github.com/google/go-cmp/cmp"
     "github.com/google/go-cmp/cmp/cmpopts"
@@ -13,6 +14,7 @@ import (
     "github.com/juanfont/headscale/integration/hsic"
     "github.com/juanfont/headscale/integration/integrationutil"
    "github.com/juanfont/headscale/integration/tsic"
+    "github.com/ory/dockertest/v3"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
     "tailscale.com/tailcfg"
@@ -1271,57 +1273,262 @@ func TestACLAutogroupMember(t *testing.T) {
 func TestACLAutogroupTagged(t *testing.T) {
     IntegrationSkip(t)
 
-    scenario := aclScenario(t,
-        &policyv2.Policy{
-            ACLs: []policyv2.ACL{
-                {
-                    Action:  "accept",
-                    Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)},
-                    Destinations: []policyv2.AliasWithPorts{
-                        aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny),
-                    },
-                },
-            },
-        },
-        2,
-    )
-    defer scenario.ShutdownAssertNoPanics(t)
+    // Create a custom scenario for testing autogroup:tagged
+    spec := ScenarioSpec{
+        NodesPerUser: 2, // 2 nodes per user - one tagged, one untagged
+        Users:        []string{"user1", "user2"},
+    }
+
+    scenario, err := NewScenario(spec)
+    require.NoError(t, err)
+    defer scenario.ShutdownAssertNoPanics(t)
+
+    policy := &policyv2.Policy{
+        TagOwners: policyv2.TagOwners{
+            "tag:test": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
+        },
+        ACLs: []policyv2.ACL{
+            {
+                Action:  "accept",
+                Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)},
+                Destinations: []policyv2.AliasWithPorts{
+                    aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny),
+                },
+            },
+        },
+    }
+
+    // Create only the headscale server (not the full environment with users/nodes)
+    headscale, err := scenario.Headscale(
+        hsic.WithACLPolicy(policy),
+        hsic.WithTestName("acl-autogroup-tagged"),
+        hsic.WithEmbeddedDERPServerOnly(),
+        hsic.WithTLS(),
+    )
+    require.NoError(t, err)
+
+    // Create users and nodes manually with specific tags
+    for _, userStr := range spec.Users {
+        user, err := scenario.CreateUser(userStr)
+        require.NoError(t, err)
+
+        // Create a single pre-auth key per user
+        authKey, err := scenario.CreatePreAuthKey(user.GetId(), true, false)
+        require.NoError(t, err)
+
+        // Create nodes with proper naming
+        for i := 0; i < spec.NodesPerUser; i++ {
+            var tags []string
+            var version string
+            if i == 0 {
+                // First node is tagged
+                tags = []string{"tag:test"}
+                version = "head"
+                t.Logf("Creating tagged node for %s", userStr)
+            } else {
+                // Second node is untagged
+                tags = nil
+                version = "unstable"
+                t.Logf("Creating untagged node for %s", userStr)
+            }
+
+            // Get the network for this scenario
+            networks := scenario.Networks()
+            var network *dockertest.Network
+            if len(networks) > 0 {
+                network = networks[0]
+            }
+
+            // Create the tailscale node with appropriate options
+            opts := []tsic.Option{
+                tsic.WithCACert(headscale.GetCert()),
+                tsic.WithHeadscaleName(headscale.GetHostname()),
+                tsic.WithNetwork(network),
+                tsic.WithNetfilter("off"),
+                tsic.WithDockerEntrypoint([]string{
+                    "/bin/sh",
+                    "-c",
+                    "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
+                }),
+                tsic.WithDockerWorkdir("/"),
+            }
+
+            // Add tags if this is a tagged node
+            if len(tags) > 0 {
+                opts = append(opts, tsic.WithTags(tags))
+            }
+
+            tsClient, err := tsic.New(
+                scenario.Pool(),
+                version,
+                opts...,
+            )
+            require.NoError(t, err)
+
+            err = tsClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout())
+            require.NoError(t, err)
+
+            // Login with the auth key
+            err = tsClient.Login(headscale.GetEndpoint(), authKey.GetKey())
+            require.NoError(t, err)
+
+            err = tsClient.WaitForRunning(integrationutil.PeerSyncTimeout())
+            require.NoError(t, err)
+
+            // Add client to user
+            userObj := scenario.GetOrCreateUser(userStr)
+            userObj.Clients[tsClient.Hostname()] = tsClient
+        }
+    }
 
     allClients, err := scenario.ListTailscaleClients()
     require.NoError(t, err)
+    require.Len(t, allClients, 4) // 2 users * 2 nodes each
 
-    err = scenario.WaitForTailscaleSync()
-    require.NoError(t, err)
+    // Wait for nodes to see only their allowed peers:
+    // tagged nodes should see each other (2 tagged nodes total),
+    // untagged nodes should see no one.
+    var taggedClients []TailscaleClient
+    var untaggedClients []TailscaleClient
 
-    // Test that tagged nodes can access each other
+    // First, categorize nodes by checking their tags
     for _, client := range allClients {
-        status, err := client.Status()
-        require.NoError(t, err)
-        if status.Self.Tags == nil || status.Self.Tags.Len() == 0 {
-            continue
-        }
+        hostname := client.Hostname()
+        assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+            status, err := client.Status()
+            assert.NoError(ct, err)
 
-        for _, peer := range allClients {
-            if client.Hostname() == peer.Hostname() {
-                continue
-            }
+            if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
+                // This is a tagged node
+                assert.Len(ct, status.Peers(), 1, "tagged node %s should see exactly 1 peer", hostname)
+
+                // Add to tagged list only once we've verified it
+                found := false
+                for _, tc := range taggedClients {
+                    if tc.Hostname() == hostname {
+                        found = true
+                        break
+                    }
+                }
+                if !found {
+                    taggedClients = append(taggedClients, client)
+                }
+            } else {
+                // This is an untagged node
+                assert.Len(ct, status.Peers(), 0, "untagged node %s should see 0 peers", hostname)
+
+                // Add to untagged list only once we've verified it
+                found := false
+                for _, uc := range untaggedClients {
+                    if uc.Hostname() == hostname {
+                        found = true
+                        break
+                    }
+                }
+                if !found {
+                    untaggedClients = append(untaggedClients, client)
+                }
+            }
+        }, 30*time.Second, 1*time.Second, fmt.Sprintf("verifying peer visibility for node %s", hostname))
+    }
 
-            status, err := peer.Status()
-            require.NoError(t, err)
-            if status.Self.Tags == nil || status.Self.Tags.Len() == 0 {
-                continue
-            }
+    // Verify we have the expected number of tagged and untagged nodes
+    require.Len(t, taggedClients, 2, "should have exactly 2 tagged nodes")
+    require.Len(t, untaggedClients, 2, "should have exactly 2 untagged nodes")
+
+    // Explicitly verify tags on tagged nodes
+    for _, client := range taggedClients {
+        status, err := client.Status()
+        require.NoError(t, err)
+        require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
+        require.Greater(t, status.Self.Tags.Len(), 0, "tagged node %s should have at least one tag", client.Hostname())
+        t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags)
+    }
+
+    // Verify untagged nodes have no tags
+    for _, client := range untaggedClients {
+        status, err := client.Status()
+        require.NoError(t, err)
+        if status.Self.Tags != nil {
+            require.Equal(t, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
+        }
+        t.Logf("Untagged node %s has no tags", client.Hostname())
+    }
+
+    // Test that tagged nodes can communicate with each other
+    for _, client := range taggedClients {
+        for _, peer := range taggedClients {
+            if client.Hostname() == peer.Hostname() {
+                continue
+            }
 
             fqdn, err := peer.FQDN()
             require.NoError(t, err)
 
             url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
-            t.Logf("url from %s to %s", client.Hostname(), url)
+            t.Logf("Testing connection from tagged node %s to tagged node %s", client.Hostname(), peer.Hostname())
 
-            result, err := client.Curl(url)
-            assert.Len(t, result, 13)
-            require.NoError(t, err)
+            assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+                result, err := client.Curl(url)
+                assert.NoError(ct, err)
+                assert.Len(ct, result, 13)
+            }, 15*time.Second, 500*time.Millisecond, "tagged nodes should be able to communicate")
+        }
+    }
+
+    // Test that untagged nodes cannot communicate with anyone
+    for _, client := range untaggedClients {
+        // Try to reach tagged nodes (should fail)
+        for _, peer := range taggedClients {
+            fqdn, err := peer.FQDN()
+            require.NoError(t, err)
+
+            url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
+            t.Logf("Testing connection from untagged node %s to tagged node %s (should fail)", client.Hostname(), peer.Hostname())
+
+            assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+                result, err := client.CurlFailFast(url)
+                assert.Empty(ct, result)
+                assert.Error(ct, err)
+            }, 5*time.Second, 200*time.Millisecond, "untagged nodes should not be able to reach tagged nodes")
+        }
+
+        // Try to reach other untagged nodes (should also fail)
+        for _, peer := range untaggedClients {
+            if client.Hostname() == peer.Hostname() {
+                continue
+            }
+
+            fqdn, err := peer.FQDN()
+            require.NoError(t, err)
+
+            url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
+            t.Logf("Testing connection from untagged node %s to untagged node %s (should fail)", client.Hostname(), peer.Hostname())
+
+            assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+                result, err := client.CurlFailFast(url)
+                assert.Empty(ct, result)
+                assert.Error(ct, err)
+            }, 5*time.Second, 200*time.Millisecond, "untagged nodes should not be able to reach other untagged nodes")
+        }
+    }
+
+    // Test that tagged nodes cannot reach untagged nodes
+    for _, client := range taggedClients {
+        for _, peer := range untaggedClients {
+            fqdn, err := peer.FQDN()
+            require.NoError(t, err)
+
+            url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
+            t.Logf("Testing connection from tagged node %s to untagged node %s (should fail)", client.Hostname(), peer.Hostname())
+
+            assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+                result, err := client.CurlFailFast(url)
+                assert.Empty(ct, result)
+                assert.Error(ct, err)
+            }, 5*time.Second, 200*time.Millisecond, "tagged nodes should not be able to reach untagged nodes")
         }
     }
 }
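
The retry loops above all follow testify's assert.EventuallyWithT pattern: assertions inside the callback are recorded on the *assert.CollectT collector, so a failed attempt is retried on the next tick instead of aborting the test. A minimal, self-contained sketch of the pattern (the condition and timings are illustrative, not taken from this commit):

package example

import (
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
)

func TestEventuallyPattern(t *testing.T) {
    start := time.Now()

    assert.EventuallyWithT(t, func(ct *assert.CollectT) {
        // Record failures on ct, not t: a failure on ct marks this
        // attempt as failed and schedules a retry; failing t would
        // end the test immediately.
        assert.Greater(ct, time.Since(start), 100*time.Millisecond, "not ready yet")
    }, 5*time.Second, 50*time.Millisecond, "condition never became true")
}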

View File

@@ -30,7 +30,11 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
     assertNoErr(t, err)
     defer scenario.ShutdownAssertNoPanics(t)
 
-    opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
+    opts := []hsic.Option{
+        hsic.WithTestName("pingallbyip"),
+        hsic.WithEmbeddedDERPServerOnly(),
+        hsic.WithDERPAsIP(),
+    }
     if https {
         opts = append(opts, []hsic.Option{
             hsic.WithTLS(),
@@ -130,6 +134,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
         assertLastSeenSet(t, node)
     }
 
+    err = scenario.WaitForTailscaleSync()
+    assertNoErrSync(t, err)
+
     allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
         return x.String()
     })
@@ -193,6 +200,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
     err = scenario.CreateHeadscaleEnv([]tsic.Option{},
         hsic.WithTestName("keyrelognewuser"),
         hsic.WithTLS(),
+        hsic.WithDERPAsIP(),
     )
     assertNoErrHeadscaleEnv(t, err)
@@ -282,7 +290,10 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
     assertNoErr(t, err)
     defer scenario.ShutdownAssertNoPanics(t)
 
-    opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
+    opts := []hsic.Option{
+        hsic.WithTestName("pingallbyip"),
+        hsic.WithDERPAsIP(),
+    }
     if https {
         opts = append(opts, []hsic.Option{
             hsic.WithTLS(),

View File

@@ -113,7 +113,18 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
     }
 }
 
-// This test is really flaky.
+// TestOIDCExpireNodesBasedOnTokenExpiry validates that nodes correctly transition to NeedsLogin
+// state when their OIDC tokens expire. This test uses a short token TTL to validate the
+// expiration behavior without waiting for production-length timeouts.
+//
+// The test verifies:
+// - Nodes can successfully authenticate via OIDC and establish connectivity
+// - When OIDC tokens expire, nodes transition to NeedsLogin state
+// - The expiration is based on individual token issue times, not a global timer
+//
+// Known timing considerations:
+// - Nodes may expire at different times due to sequential login processing
+// - The test must account for login time spread between first and last node
 func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
     IntegrationSkip(t)
@@ -153,8 +164,12 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
     allIps, err := scenario.ListTailscaleClientsIPs()
     assertNoErrListClientIPs(t, err)
 
+    // Record when sync completes to better estimate token expiry timing
+    syncCompleteTime := time.Now()
     err = scenario.WaitForTailscaleSync()
     assertNoErrSync(t, err)
+    loginDuration := time.Since(syncCompleteTime)
+    t.Logf("Login and sync completed in %v", loginDuration)
 
     // assertClientsState(t, allClients)
@@ -165,19 +180,49 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
     success := pingAllHelper(t, allClients, allAddrs)
     t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps))
 
-    // This is not great, but this sadly is a time dependent test, so the
-    // safe thing to do is wait out the whole TTL time (and a bit more out
-    // of safety reasons) before checking if the clients have logged out.
-    // The Wait function can't do it itself as it has an upper bound of 1
-    // min.
+    // Wait for OIDC token expiry and verify all nodes transition to NeedsLogin.
+    // We add extra time to account for:
+    // - Sequential login processing causing different token issue times
+    // - Network and processing delays
+    // - Safety margin for test reliability
+    loginTimeSpread := 1 * time.Minute // Account for sequential login delays
+    safetyBuffer := 30 * time.Second   // Additional safety margin
+    totalWaitTime := shortAccessTTL + loginTimeSpread + safetyBuffer
+
+    t.Logf("Waiting %v for OIDC tokens to expire (TTL: %v, spread: %v, buffer: %v)",
+        totalWaitTime, shortAccessTTL, loginTimeSpread, safetyBuffer)
+
+    // EventuallyWithT retries the test function until it passes or times out.
+    // IMPORTANT: Use 'ct' (CollectT) for all assertions inside the function, not 't'.
+    // Using 't' would cause immediate test failure without retries, defeating the purpose
+    // of EventuallyWithT which is designed to handle timing-dependent conditions.
     assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+        // Check each client's status individually to provide better diagnostics
+        expiredCount := 0
         for _, client := range allClients {
             status, err := client.Status()
-            assert.NoError(ct, err)
-            assert.Equal(ct, "NeedsLogin", status.BackendState)
+            if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) {
+                if status.BackendState == "NeedsLogin" {
+                    expiredCount++
+                }
+            }
         }
 
-        assertTailscaleNodesLogout(t, allClients)
-    }, shortAccessTTL+10*time.Second, 5*time.Second)
+        // Log progress for debugging
+        if expiredCount < len(allClients) {
+            t.Logf("Token expiry progress: %d/%d clients in NeedsLogin state", expiredCount, len(allClients))
+        }
+
+        // All clients must be in NeedsLogin state
+        assert.Equal(ct, len(allClients), expiredCount,
+            "expected all %d clients to be in NeedsLogin state, but only %d are",
+            len(allClients), expiredCount)
+
+        // Only check detailed logout state if all clients are expired
+        if expiredCount == len(allClients) {
+            assertTailscaleNodesLogout(ct, allClients)
+        }
+    }, totalWaitTime, 5*time.Second)
 }
 
 func TestOIDC024UserCreation(t *testing.T) {
@@ -429,6 +474,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
         hsic.WithTLS(),
         hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
         hsic.WithEmbeddedDERPServerOnly(),
+        hsic.WithDERPAsIP(),
     )
     assertNoErrHeadscaleEnv(t, err)
@@ -617,14 +663,18 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
     assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
 }
 
-func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
-    t.Helper()
+// assertTailscaleNodesLogout verifies that all provided Tailscale clients
+// are in the logged-out state (NeedsLogin).
+func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
+    if h, ok := t.(interface{ Helper() }); ok {
+        h.Helper()
+    }
 
     for _, client := range clients {
         status, err := client.Status()
-        assertNoErr(t, err)
-
-        assert.Equal(t, "NeedsLogin", status.BackendState)
+        assert.NoError(t, err, "failed to get status for client %s", client.Hostname())
+        assert.Equal(t, "NeedsLogin", status.BackendState,
+            "client %s should be logged out", client.Hostname())
     }
 }
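
The new assertTailscaleNodesLogout takes assert.TestingT instead of *testing.T so it can be called both directly and from inside EventuallyWithT with an *assert.CollectT; the type assertion only calls Helper() when the concrete implementation provides one. A reduced sketch of the same pattern (names are illustrative):

package example

import (
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
)

// assertPositive accepts assert.TestingT, which both *testing.T and
// *assert.CollectT satisfy (each implements Errorf).
func assertPositive(t assert.TestingT, v int) {
    // Only *testing.T provides Helper(), so upgrade conditionally.
    if h, ok := t.(interface{ Helper() }); ok {
        h.Helper()
    }
    assert.Greater(t, v, 0, "value %d should be positive", v)
}

func TestHelperPattern(t *testing.T) {
    assertPositive(t, 1) // direct call

    assert.EventuallyWithT(t, func(ct *assert.CollectT) {
        assertPositive(ct, 1) // same helper inside the retry loop
    }, time.Second, 100*time.Millisecond)
}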

View File

@@ -30,6 +30,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
         nil,
         hsic.WithTestName("webauthping"),
         hsic.WithEmbeddedDERPServerOnly(),
+        hsic.WithDERPAsIP(),
         hsic.WithTLS(),
     )
     assertNoErrHeadscaleEnv(t, err)
@@ -68,6 +69,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
     err = scenario.CreateHeadscaleEnvWithLoginURL(
         nil,
         hsic.WithTestName("weblogout"),
+        hsic.WithDERPAsIP(),
         hsic.WithTLS(),
     )
     assertNoErrHeadscaleEnv(t, err)

View File

@@ -5,6 +5,7 @@ import (
     v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
     policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
+    "github.com/juanfont/headscale/hscontrol/routes"
     "github.com/ory/dockertest/v3"
 )
@@ -28,5 +29,7 @@ type ControlServer interface {
     ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)
     GetCert() []byte
     GetHostname() string
+    GetIPInNetwork(network *dockertest.Network) string
     SetPolicy(*policyv2.Policy) error
+    PrimaryRoutes() (*routes.DebugRoutes, error)
 }

View File

@@ -10,7 +10,7 @@ import (
     "github.com/ory/dockertest/v3"
 )
 
-const dockerExecuteTimeout = time.Second * 30
+const dockerExecuteTimeout = time.Second * 10
 
 var (
     ErrDockertestCommandFailed = errors.New("dockertest command failed")

View File

@@ -96,7 +96,7 @@ func CleanUnreferencedNetworks(pool *dockertest.Pool) error {
     }
 
     for _, network := range networks {
-        if network.Network.Containers == nil || len(network.Network.Containers) == 0 {
+        if len(network.Network.Containers) == 0 {
             err := pool.RemoveNetwork(&network)
             if err != nil {
                 log.Printf("removing network %s: %s", network.Network.Name, err)
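
This simplification is behavior-preserving because Go defines len of a nil map as zero, so the explicit nil check was redundant. A one-line demonstration:

package main

import "fmt"

func main() {
    var m map[string]string // nil map
    // len(m) == 0 already covers the nil case, so
    // `m == nil || len(m) == 0` collapses to `len(m) == 0`.
    fmt.Println(m == nil, len(m)) // true 0
}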

View File

@@ -945,6 +945,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
     []tsic.Option{},
     hsic.WithTestName("pingallbyipmany"),
     hsic.WithEmbeddedDERPServerOnly(),
+    hsic.WithDERPAsIP(),
     hsic.WithTLS(),
 )
 assertNoErrHeadscaleEnv(t, err)

View File

@@ -23,6 +23,7 @@ import (
     "github.com/davecgh/go-spew/spew"
     v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
     policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
+    "github.com/juanfont/headscale/hscontrol/routes"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/juanfont/headscale/integration/dockertestutil"
@@ -272,6 +273,14 @@ func WithTimezone(timezone string) Option {
     }
 }
 
+// WithDERPAsIP enables using IP address instead of hostname for DERP server.
+// This is useful for integration tests where DNS resolution may be unreliable.
+func WithDERPAsIP() Option {
+    return func(hsic *HeadscaleInContainer) {
+        hsic.env["HEADSCALE_DEBUG_DERP_USE_IP"] = "1"
+    }
+}
+
 // WithDebugPort sets the debug port for delve debugging.
 func WithDebugPort(port int) Option {
     return func(hsic *HeadscaleInContainer) {
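
The new option is wired into several test setups elsewhere in this commit; a usage sketch mirroring those call sites (the test name here is hypothetical):

// Sketch: enable IP-based DERP when building a headscale test env.
err = scenario.CreateHeadscaleEnv(
    []tsic.Option{},
    hsic.WithTestName("derpbyip"), // hypothetical name
    hsic.WithEmbeddedDERPServerOnly(),
    hsic.WithDERPAsIP(),
    hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)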
@@ -829,9 +838,25 @@ func (t *HeadscaleInContainer) GetHealthEndpoint() string {
 
 // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer.
 func (t *HeadscaleInContainer) GetEndpoint() string {
-    hostEndpoint := fmt.Sprintf("%s:%d",
-        t.GetHostname(),
-        t.port)
+    return t.getEndpoint(false)
+}
+
+// GetIPEndpoint returns the Headscale endpoint using IP address instead of hostname.
+func (t *HeadscaleInContainer) GetIPEndpoint() string {
+    return t.getEndpoint(true)
+}
+
+// getEndpoint returns the Headscale endpoint, optionally using IP address instead of hostname.
+func (t *HeadscaleInContainer) getEndpoint(useIP bool) string {
+    var host string
+    if useIP && len(t.networks) > 0 {
+        // Use IP address from the first network
+        host = t.GetIPInNetwork(t.networks[0])
+    } else {
+        host = t.GetHostname()
+    }
+
+    hostEndpoint := fmt.Sprintf("%s:%d", host, t.port)
 
     if t.hasTLS() {
         return "https://" + hostEndpoint
@@ -850,6 +875,11 @@ func (t *HeadscaleInContainer) GetHostname() string {
     return t.hostname
 }
 
+// GetIPInNetwork returns the IP address of the HeadscaleInContainer in the given network.
+func (t *HeadscaleInContainer) GetIPInNetwork(network *dockertest.Network) string {
+    return t.container.GetIPInNetwork(network)
+}
+
 // WaitForRunning blocks until the Headscale instance is ready to
 // serve clients.
 func (t *HeadscaleInContainer) WaitForRunning() error {
@@ -1243,3 +1273,23 @@ func (t *HeadscaleInContainer) SendInterrupt() error {
 
     return nil
 }
+
+// PrimaryRoutes fetches the primary routes from the debug endpoint.
+func (t *HeadscaleInContainer) PrimaryRoutes() (*routes.DebugRoutes, error) {
+    // Execute curl inside the container to access the debug endpoint locally
+    command := []string{
+        "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/routes",
+    }
+
+    result, err := t.Execute(command)
+    if err != nil {
+        return nil, fmt.Errorf("fetching routes from debug endpoint: %w", err)
+    }
+
+    var debugRoutes routes.DebugRoutes
+    if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil {
+        return nil, fmt.Errorf("decoding routes response: %w", err)
+    }
+
+    return &debugRoutes, nil
+}
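
A hedged sketch of consuming the new accessor from a test; this diff does not show the fields of routes.DebugRoutes, so the sketch only fetches the value and dumps it for debugging:

// Sketch: fetch and log primary routes via the debug endpoint.
pr, err := headscale.PrimaryRoutes()
require.NoError(t, err)
t.Logf("primary routes (debug): %s", spew.Sdump(pr))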

File diff suppressed because it is too large

View File

@@ -327,6 +327,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
         return true
     })
 
+    s.mu.Lock()
     for userName, user := range s.users {
         for _, client := range user.Clients {
             log.Printf("removing client %s in user %s", client.Hostname(), userName)
@@ -346,6 +347,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
             }
         }
     }
+    s.mu.Unlock()
 
     for _, derp := range s.derpServers {
         err := derp.Shutdown()
@@ -429,6 +431,27 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) {
     return headscale, nil
 }
 
+// Pool returns the dockertest pool for the scenario.
+func (s *Scenario) Pool() *dockertest.Pool {
+    return s.pool
+}
+
+// GetOrCreateUser gets or creates a user in the scenario.
+func (s *Scenario) GetOrCreateUser(userStr string) *User {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    if user, ok := s.users[userStr]; ok {
+        return user
+    }
+
+    user := &User{
+        Clients: make(map[string]TailscaleClient),
+    }
+    s.users[userStr] = user
+
+    return user
+}
+
 // CreatePreAuthKey creates a "pre authentorised key" to be created in the
 // Headscale instance on behalf of the Scenario.
 func (s *Scenario) CreatePreAuthKey(
@@ -457,9 +480,11 @@ func (s *Scenario) CreateUser(user string) (*v1.User, error) {
         return nil, fmt.Errorf("failed to create user: %w", err)
     }
 
+    s.mu.Lock()
     s.users[user] = &User{
         Clients: make(map[string]TailscaleClient),
     }
+    s.mu.Unlock()
 
     return u, nil
 }
@@ -541,11 +566,25 @@ func (s *Scenario) CreateTailscaleNodesInUser(
     cert := headscale.GetCert()
     hostname := headscale.GetHostname()
 
+    // Determine which network this tailscale client will be in
+    var network *dockertest.Network
+    if s.userToNetwork != nil && s.userToNetwork[userStr] != nil {
+        network = s.userToNetwork[userStr]
+    } else {
+        network = s.networks[s.testDefaultNetwork]
+    }
+
+    // Get headscale IP in this network for /etc/hosts fallback DNS
+    headscaleIP := headscale.GetIPInNetwork(network)
+    extraHosts := []string{hostname + ":" + headscaleIP}
+
     s.mu.Lock()
     opts = append(opts,
         tsic.WithCACert(cert),
         tsic.WithHeadscaleName(hostname),
+        tsic.WithExtraHosts(extraHosts),
     )
     s.mu.Unlock()
 
     user.createWaitGroup.Go(func() error {
@@ -673,6 +712,7 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error {
     if len(allErrors) > 0 {
         return multierr.New(allErrors...)
     }
+
     return nil
 }
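
The extraHosts entries use Docker's host:IP format, which ends up as lines in the container's /etc/hosts and gives clients a DNS-independent way to resolve the headscale hostname. A tiny illustration with hypothetical values:

// Sketch: Docker extra-hosts entries are "hostname:IP" strings.
hostname := "headscale-test" // hypothetical
headscaleIP := "172.18.0.2"  // hypothetical
extraHosts := []string{hostname + ":" + headscaleIP}
fmt.Println(extraHosts[0]) // headscale-test:172.18.0.2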

View File

@@ -409,7 +409,7 @@ func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient,
 
             // For all other errors, assert no error to trigger retry
             assert.NoError(ct, err)
-        }, 10*time.Second, 1*time.Second)
+        }, 10*time.Second, 200*time.Millisecond)
     } else {
         // For failure cases, just execute once
         result, stderr, err = client.Execute(command)

View File

@@ -32,6 +32,7 @@ type TailscaleClient interface {
     Down() error
     IPs() ([]netip.Addr, error)
     MustIPs() []netip.Addr
+    IPv4() (netip.Addr, error)
     MustIPv4() netip.Addr
     MustIPv6() netip.Addr
     FQDN() (string, error)
@@ -46,6 +47,7 @@ type TailscaleClient interface {
     WaitForPeers(expected int, timeout, retryInterval time.Duration) error
     Ping(hostnameOrIP string, opts ...tsic.PingOption) error
     Curl(url string, opts ...tsic.CurlOption) (string, error)
+    CurlFailFast(url string) (string, error)
     Traceroute(netip.Addr) (util.Traceroute, error)
     ContainerID() string
     MustID() types.NodeID

View File

@@ -36,8 +36,8 @@ import (
 const (
     tsicHashLength       = 6
-    defaultPingTimeout   = 300 * time.Millisecond
-    defaultPingCount     = 10
+    defaultPingTimeout   = 200 * time.Millisecond
+    defaultPingCount     = 5
     dockerContextPath    = "../."
     caCertRoot           = "/usr/local/share/ca-certificates"
     dockerExecuteTimeout = 60 * time.Second
@@ -573,7 +573,7 @@ func (t *TailscaleInContainer) Down() error {
 
 // IPs returns the netip.Addr of the Tailscale instance.
 func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
-    if t.ips != nil && len(t.ips) != 0 {
+    if len(t.ips) != 0 {
         return t.ips, nil
     }
@@ -589,7 +589,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
         return []netip.Addr{}, fmt.Errorf("%s failed to join tailscale client: %w", t.hostname, err)
     }
 
-    for _, address := range strings.Split(result, "\n") {
+    for address := range strings.SplitSeq(result, "\n") {
         address = strings.TrimSuffix(address, "\n")
         if len(address) < 1 {
             continue
@@ -613,6 +613,22 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr {
     return ips
 }
 
+// IPv4 returns the IPv4 address of the Tailscale instance.
+func (t *TailscaleInContainer) IPv4() (netip.Addr, error) {
+    ips, err := t.IPs()
+    if err != nil {
+        return netip.Addr{}, err
+    }
+
+    for _, ip := range ips {
+        if ip.Is4() {
+            return ip, nil
+        }
+    }
+
+    return netip.Addr{}, fmt.Errorf("no IPv4 address found")
+}
+
 func (t *TailscaleInContainer) MustIPv4() netip.Addr {
     for _, ip := range t.MustIPs() {
         if ip.Is4() {
@@ -984,6 +1000,7 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval time.Duration) error {
                 expected,
                 len(peers),
             )}
+
             continue
         }
@@ -1149,11 +1166,11 @@ func WithCurlRetry(ret int) CurlOption {
 }
 
 const (
-    defaultConnectionTimeout = 3 * time.Second
-    defaultMaxTime           = 10 * time.Second
-    defaultRetry             = 5
-    defaultRetryDelay        = 0 * time.Second
-    defaultRetryMaxTime      = 50 * time.Second
+    defaultConnectionTimeout = 1 * time.Second
+    defaultMaxTime           = 3 * time.Second
+    defaultRetry             = 3
+    defaultRetryDelay        = 200 * time.Millisecond
+    defaultRetryMaxTime      = 5 * time.Second
 )
 
 // Curl executes the Tailscale curl command and curls a hostname
@@ -1198,6 +1215,17 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, error) {
     return result, nil
 }
 
+// CurlFailFast executes the Tailscale curl command with aggressive timeouts
+// optimized for testing expected connection failures. It uses minimal timeouts
+// to quickly detect blocked connections without waiting for multiple retries.
+func (t *TailscaleInContainer) CurlFailFast(url string) (string, error) {
+    // Use aggressive timeouts for fast failure detection
+    return t.Curl(url,
+        WithCurlConnectionTimeout(1*time.Second),
+        WithCurlMaxTime(2*time.Second),
+        WithCurlRetry(1))
+}
+
 func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) {
     command := []string{
         "traceroute",