From 91d5c1879ac21e3d5765f175d0b641fefc99f6a9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 16 Sep 2025 18:57:37 +0200 Subject: [PATCH] integration: consistency changes to *Relogin* tests Signed-off-by: Kristoffer Dalby --- hscontrol/state/debug.go | 10 +- hscontrol/state/node_store.go | 2 +- hscontrol/state/state.go | 4 + integration/auth_key_test.go | 189 +++-- integration/auth_oidc_test.go | 326 +++++--- integration/auth_web_flow_test.go | 51 +- integration/cli_test.go | 138 ++-- integration/derp_verify_endpoint_test.go | 15 +- integration/dns_test.go | 43 +- integration/embedded_derp_test.go | 17 +- integration/general_test.go | 312 ++------ integration/helpers.go | 899 +++++++++++++++++++++++ integration/hsic/hsic.go | 17 + integration/route_test.go | 84 +-- integration/scenario_test.go | 5 +- integration/ssh_test.go | 51 +- integration/utils.go | 533 -------------- 17 files changed, 1519 insertions(+), 1177 deletions(-) create mode 100644 integration/helpers.go delete mode 100644 integration/utils.go diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 7c60128f..dbe790fa 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -60,8 +60,8 @@ type DebugStringInfo struct { // DebugOverview returns a comprehensive overview of the current state for debugging. func (s *State) DebugOverview() string { - s.mu.RLock() - defer s.mu.RUnlock() + s.mu.Lock() + defer s.mu.Unlock() allNodes := s.nodeStore.ListNodes() users, _ := s.ListAllUsers() @@ -265,14 +265,14 @@ func (s *State) DebugPolicyManager() string { // PolicyDebugString returns a debug representation of the current policy. func (s *State) PolicyDebugString() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.polMan.DebugString() } // DebugOverviewJSON returns a structured overview of the current state for debugging. func (s *State) DebugOverviewJSON() DebugOverviewInfo { - s.mu.RLock() - defer s.mu.RUnlock() - allNodes := s.nodeStore.ListNodes() users, _ := s.ListAllUsers() diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 555766d1..b27a2945 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -15,7 +15,7 @@ import ( ) const ( - batchSize = 10 + batchSize = 100 batchTimeout = 500 * time.Millisecond ) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 15597706..6315d8f2 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -50,6 +50,7 @@ var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") type State struct { // mu protects all in-memory data structures from concurrent access mu deadlock.RWMutex + // cfg holds the current Headscale configuration cfg *types.Config @@ -201,6 +202,9 @@ func (s *State) DERPMap() tailcfg.DERPMapView { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. 
func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + pol, err := policyBytes(s.db, s.cfg) if err != nil { return nil, fmt.Errorf("loading policy: %w", err) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 90034434..01a0f7f3 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -28,7 +28,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) opts := []hsic.Option{ @@ -43,31 +43,25 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) - expectedNodes := make([]types.NodeID, 0, len(allClients)) - for _, client := range allClients { - status := client.MustStatus() - nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) - assertNoErr(t, err) - expectedNodes = append(expectedNodes, types.NodeID(nodeID)) - } - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second) + expectedNodes := collectExpectedNodeIDs(t, allClients) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second) // Validate that all nodes have NetInfo and DERP servers before logout - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 1*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute) // assertClientsState(t, allClients) @@ -97,19 +91,20 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) // After taking down all nodes, verify all systems show nodes offline requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second) t.Logf("all clients logged out") + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes after logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node persistence after logout (nodes should remain in database)") for _, node := range listNodes { assertLastSeenSet(t, node) @@ -125,7 +120,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) for _, userName := range spec.Users { 
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) @@ -139,12 +134,13 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } + t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection") - }, 30*time.Second, 2*time.Second) + assert.NoError(ct, err, "Failed to list nodes after relogin") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node count stability after same-user auth key relogin") for _, node := range listNodes { assertLastSeenSet(t, node) @@ -152,11 +148,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second) + // Wait for Tailscale sync before validating NetInfo to ensure proper state propagation + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + // Validate that all nodes have NetInfo and DERP servers after reconnection - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 1*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -197,65 +197,6 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } -// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database -// and a valid DERP server based on the NetInfo. This function follows the pattern of -// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. 
-func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { - t.Helper() - - startTime := time.Now() - t.Logf("requireAllClientsNetInfoAndDERP: Starting validation at %s - %s", startTime.Format(TimestampFormat), message) - - require.EventuallyWithT(t, func(c *assert.CollectT) { - // Get nodestore state - nodeStore, err := headscale.DebugNodeStore() - assert.NoError(c, err, "Failed to get nodestore debug info") - if err != nil { - return - } - - // Validate node counts first - expectedCount := len(expectedNodes) - assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") - - // Check each expected node - for _, nodeID := range expectedNodes { - node, exists := nodeStore[nodeID] - assert.True(c, exists, "Node %d not found in nodestore", nodeID) - if !exists { - continue - } - - // Validate that the node has Hostinfo - assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo", nodeID, node.Hostname) - if node.Hostinfo == nil { - continue - } - - // Validate that the node has NetInfo - assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo", nodeID, node.Hostname) - if node.Hostinfo.NetInfo == nil { - continue - } - - // Validate that the node has a valid DERP server (PreferredDERP should be > 0) - preferredDERP := node.Hostinfo.NetInfo.PreferredDERP - assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0), got %d", nodeID, node.Hostname, preferredDERP) - - t.Logf("Node %d (%s) has valid NetInfo with DERP server %d", nodeID, node.Hostname, preferredDERP) - } - }, timeout, 2*time.Second, message) - - endTime := time.Now() - duration := endTime.Sub(startTime) - t.Logf("requireAllClientsNetInfoAndDERP: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message) -} - -func assertLastSeenSet(t *testing.T, node *v1.Node) { - assert.NotNil(t, node) - assert.NotNil(t, node.GetLastSeen()) -} - // This test will first log in two sets of nodes to two sets of users, then // it will log out all users from user2 and log them in as user1. 
// This should leave us with all nodes connected to user1, while user2 @@ -269,7 +210,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, @@ -277,18 +218,25 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { hsic.WithTLS(), hsic.WithDERPAsIP(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) listNodes, err := headscale.ListNodes() assert.Len(t, allClients, len(listNodes)) @@ -303,12 +251,15 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) t.Logf("all clients logged out") userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) // Create a new authkey for user1, to be used for all clients key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false) @@ -326,28 +277,41 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error user1Nodes, err = headscale.ListNodes("user1") - assert.NoError(ct, err) - assert.Len(ct, user1Nodes, len(allClients), "User1 should have all clients after re-login") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") + assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) + }, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after auth key relogin") + + // Collect expected node IDs for user1 after relogin + expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes)) + for _, node := range user1Nodes { + expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId())) + } + + // Validate connection state after relogin as user1 + requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute) // Validate that all the old nodes are still present with user2 var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) 
assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error user2Nodes, err = headscale.ListNodes("user2") - assert.NoError(ct, err) - assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should have half the clients") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") + assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) + }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") + t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) - assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1", client.Hostname()) - }, 30*time.Second, 2*time.Second) + assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) + }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) } } @@ -362,7 +326,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) opts := []hsic.Option{ @@ -376,13 +340,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) 
- assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -396,7 +360,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) listNodes, err := headscale.ListNodes() assert.Len(t, allClients, len(listNodes)) @@ -411,7 +382,10 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) t.Logf("all clients logged out") @@ -425,7 +399,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) for _, userName := range spec.Users { key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) @@ -443,7 +417,8 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { "expire", key.GetKey(), }) - assertNoErr(t, err) + require.NoError(t, err) + require.NoError(t, err) err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) assert.ErrorContains(t, err, "authkey expired") diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 751a8d11..e57c0647 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -4,17 +4,20 @@ import ( "maps" "net/netip" "sort" + "strconv" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/oauth2-proxy/mockoidc" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOIDCAuthenticationPingAll(t *testing.T) { @@ -33,7 +36,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -51,16 +54,16 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, 
allClients) @@ -72,10 +75,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) listUsers, err := headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) want := []*v1.User{ { @@ -141,7 +144,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -156,18 +159,18 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { hsic.WithTestName("oidcexpirenodes"), hsic.WithConfigEnv(oidcMap), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) // Record when sync completes to better estimate token expiry timing syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -348,7 +351,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -366,20 +369,20 @@ func TestOIDC024UserCreation(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) // Ensure that the nodes have logged in, this is what // triggers user creation via OIDC. err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) want := tt.want(scenario.mockOIDC.Issuer()) listUsers, err := headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) sort.Slice(listUsers, func(i, j int) bool { return listUsers[i].GetId() < listUsers[j].GetId() @@ -405,7 +408,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -423,17 +426,17 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) // Get all clients and verify they can connect allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -443,6 +446,11 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } +// TestOIDCReloginSameNodeNewUser tests the scenario where: +// 1. A Tailscale client logs in with user1 (creates node1 for user1) +// 2. 
The same client logs out and logs in with user2 (creates node2 for user2) +// 3. The same client logs out and logs in with user1 again (reuses node1, node2 remains) +// This validates that OIDC relogin properly handles node reuse and cleanup. func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) @@ -457,7 +465,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -476,24 +484,25 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithDERPAsIP(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) - assertNoErr(t, err) + require.NoError(t, err) u, err := ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Validating initial user creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 1) + assert.NoError(ct, err, "Failed to list users during initial validation") + assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -509,44 +518,61 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - t.Fatalf("unexpected users: %s", diff) + ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } - }, 30*time.Second, 1*time.Second, "validating users after first login") + }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") - listNodes, err := headscale.ListNodes() - assertNoErr(t, err) - assert.Len(t, listNodes, 1) + t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during initial validation") + assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) + }, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login") + + // Collect expected node IDs for validation after user1 initial login + expectedNodes := make([]types.NodeID, 0, 1) + status := ts.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + + // Validate initial connection state for user1 + validateInitialConnection(t, headscale, expectedNodes) // Log out user1 and log in user2, this should create a new node // for user2, the node should have the same machine key and // a new node key. err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it // manually. 
err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) // Wait for logout to complete and then do second logout + t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check that the first logout completed status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before user2 login") u, err = ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Validating user2 creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 2) + assert.NoError(ct, err, "Failed to list users after user2 login") + assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -569,27 +595,83 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - ct.Errorf("unexpected users: %s", diff) + ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff) } - }, 30*time.Second, 1*time.Second, "validating users after new user login") + }, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login") var listNodesAfterNewUserLogin []*v1.Node + // First, wait for the new node to be created + t.Logf("Waiting for user2 node creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listNodesAfterNewUserLogin, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodesAfterNewUserLogin, 2) + assert.NoError(ct, err, "Failed to list nodes after user2 login") + // We might temporarily have more than 2 nodes during cleanup, so check for at least 2 + assert.GreaterOrEqual(ct, len(listNodesAfterNewUserLogin), 2, "Should have at least 2 nodes after user2 login, got %d (may include temporary nodes during cleanup)", len(listNodesAfterNewUserLogin)) + }, 30*time.Second, 1*time.Second, "waiting for user2 node creation (allowing temporary extra nodes during cleanup)") - // Machine key is the same as the "machine" has not changed, - // but Node key is not as it is a new node - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) - }, 30*time.Second, 1*time.Second, "listing nodes after new user login") + // Then wait for cleanup to stabilize at exactly 2 nodes + t.Logf("Waiting for node cleanup stabilization at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterNewUserLogin, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list 
nodes during cleanup validation") + assert.Len(ct, listNodesAfterNewUserLogin, 2, "Should have exactly 2 nodes after cleanup (1 for user1, 1 for user2), got %d nodes", len(listNodesAfterNewUserLogin)) + + // Validate that both nodes have the same machine key but different node keys + if len(listNodesAfterNewUserLogin) >= 2 { + // Machine key is the same as the "machine" has not changed, + // but Node key is not as it is a new node + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Machine key should be preserved from original node") + assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "Both nodes should share the same machine key") + assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey(), "Node keys should be different between user1 and user2 nodes") + } + }, 90*time.Second, 2*time.Second, "waiting for node count stabilization at exactly 2 nodes after user2 login") + + // Security validation: Only user2's node should be active after user switch + var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { + if node.GetUser().GetId() == 2 { // user2 + activeUser2NodeID = types.NodeID(node.GetId()) + t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break + } + } + + // Validate only user2's node is online (security requirement) + t.Logf("Validating only user2 node is online at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user2 node is online + if node, exists := nodeStore[activeUser2NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User2 node should be online after login") + } + } else { + assert.Fail(c, "User2 node not found in nodestore") + } + }, 60*time.Second, 2*time.Second, "validating only user2 node is online after user switch") + + // Before logging out user2, validate we have exactly 2 nodes and both are stable + t.Logf("Pre-logout validation: checking node stability at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes before user2 logout") + assert.Len(ct, currentNodes, 2, "Should have exactly 2 stable nodes before user2 logout, got %d", len(currentNodes)) + + // Validate node stability - ensure no phantom nodes + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should have a valid user before logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should have a valid machine key before logout", i) + t.Logf("Pre-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating stable node count and integrity before user2 logout") // Log out user2, and log into user1, no new node should be created, // the node should now "become" node1 again err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Logged out take one") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") @@ -598,41 +680,63 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // logs in immediately after the first logout and I 
cannot reproduce it // manually. err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Logged out take two") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") // Wait for logout to complete and then do second logout + t.Logf("Waiting for user2 logout completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check that the first logout completed status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during user2 logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after user2 logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user2 logout to complete before user1 relogin") + + // Before logging back in, ensure we still have exactly 2 nodes + // Note: We skip validateLogoutComplete here since it expects all nodes to be offline, + // but in OIDC scenario we maintain both nodes in DB with only active user online + + // Additional validation that nodes are properly maintained during logout + t.Logf("Post-logout validation: checking node persistence at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after user2 logout") + assert.Len(ct, currentNodes, 2, "Should still have exactly 2 nodes after user2 logout (nodes should persist), got %d", len(currentNodes)) + + // Ensure both nodes are still valid (not cleaned up incorrectly) + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should still have a valid user after user2 logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should still have a valid machine key after user2 logout", i) + t.Logf("Post-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating node persistence and integrity after user2 logout") // We do not actually "change" the user here, it is done by logging in again // as the OIDC mock server is kind of like a stack, and the next user is // prepared and ready to go. 
u, err = ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "Running", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during user1 relogin validation") + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (final login)") t.Logf("Logged back in") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") + t.Logf("Final validation: checking user persistence at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assert.NoError(ct, err) - assert.Len(ct, listUsers, 2) + assert.NoError(ct, err, "Failed to list users during final validation") + assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -655,59 +759,75 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - ct.Errorf("unexpected users: %s", diff) + ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff) } - }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") + }, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)") + var listNodesAfterLoggingBackIn []*v1.Node + // Wait for login to complete and nodes to stabilize + t.Logf("Final node validation: checking node stability after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { - listNodesAfterLoggingBackIn, err := headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodesAfterLoggingBackIn, 2) + listNodesAfterLoggingBackIn, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during final validation") + + // Allow for temporary instability during login process + if len(listNodesAfterLoggingBackIn) < 2 { + ct.Errorf("Not enough nodes yet during final validation, got %d, want at least 2", len(listNodesAfterLoggingBackIn)) + return + } + + // Final check should have exactly 2 nodes + assert.Len(ct, listNodesAfterLoggingBackIn, 2, "Should have exactly 2 nodes after complete relogin cycle, got %d", len(listNodesAfterLoggingBackIn)) // Validate that the machine we had when we logged in the first time, has the same // machine key, but a different ID than the newly logged in version of the same // machine. 
- assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) - assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) - assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Original user1 machine key should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey(), "Original user1 node key should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId(), "Original user1 node ID should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "User1 and user2 nodes should share the same machine key") + assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId(), "User1 and user2 nodes should have different node IDs") + assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId(), "User1 and user2 nodes should belong to different users") // Even tho we are logging in again with the same user, the previous key has been expired // and a new one has been generated. The node entry in the database should be the same // as the user + machinekey still matches. - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) - assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) - assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey(), "Machine key should remain consistent after user1 relogin") + assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey(), "Node key should be regenerated after user1 relogin") + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId(), "Node ID should be preserved for user1 after relogin") // The "logged back in" machine should have the same machinekey but a different nodekey // than the version logged in with a different user. - assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) - assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) - }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") -} + assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey(), "Both final nodes should share the same machine key") + assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey(), "Final nodes should have different node keys for different users") -// assertTailscaleNodesLogout verifies that all provided Tailscale clients -// are in the logged-out state (NeedsLogin). 
-func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { - if h, ok := t.(interface{ Helper() }); ok { - h.Helper() + t.Logf("Final validation complete - node counts and key relationships verified at %s", time.Now().Format(TimestampFormat)) + }, 60*time.Second, 2*time.Second, "validating final node state after complete user1->user2->user1 relogin cycle with detailed key validation") + + // Security validation: Only user1's node should be active after relogin + var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { + if node.GetUser().GetId() == 1 { // user1 + activeUser1NodeID = types.NodeID(node.GetId()) + t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break + } } - for _, client := range clients { - status, err := client.Status() - assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) - assert.Equal(t, "NeedsLogin", status.BackendState, - "client %s should be logged out", client.Hostname()) - } -} + // Validate only user1's node is online (security requirement) + t.Logf("Validating only user1 node is online after relogin at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") -func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { - return mockoidc.MockUser{ - Subject: username, - PreferredUsername: username, - Email: username + "@headscale.net", - EmailVerified: emailVerified, - } + // Check user1 node is online + if node, exists := nodeStore[activeUser1NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User1 node should be online after relogin") + } + } else { + assert.Fail(c, "User1 node not found in nodestore after relogin") + } + }, 60*time.Second, 2*time.Second, "validating only user1 node is online after final relogin") } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index ff190142..d8aac03d 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { @@ -33,16 +34,16 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -63,7 +64,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnvWithLoginURL( @@ -72,16 +73,16 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - 
assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -93,15 +94,22 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, len(allClients), "Node count should match client count after login") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes after web authentication") + assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -122,7 +130,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + validateLogoutComplete(t, headscale, expectedNodes) t.Logf("all clients logged out") @@ -136,7 +147,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged in again") allIps, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs = lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -145,14 +156,18 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { success = pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count after re-login") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes after web flow logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node persistence in database after web flow logout") t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) + // Validate connection state after relogin + validateReloginComplete(t, headscale, expectedNodes) + for _, client := range allClients { ips, err := client.IPs() if err != nil { diff --git a/integration/cli_test.go b/integration/cli_test.go index 98e2ddf3..40afd2c3 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ 
-54,14 +54,14 @@ func TestUserCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var listUsers []*v1.User var result []string @@ -99,7 +99,7 @@ func TestUserCommand(t *testing.T) { "--new-name=newname", }, ) - assertNoErr(t, err) + require.NoError(t, err) var listAfterRenameUsers []*v1.User assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -138,7 +138,7 @@ func TestUserCommand(t *testing.T) { }, &listByUsername, ) - assertNoErr(t, err) + require.NoError(t, err) slices.SortFunc(listByUsername, sortWithID) want := []*v1.User{ @@ -165,7 +165,7 @@ func TestUserCommand(t *testing.T) { }, &listByID, ) - assertNoErr(t, err) + require.NoError(t, err) slices.SortFunc(listByID, sortWithID) want = []*v1.User{ @@ -244,7 +244,7 @@ func TestUserCommand(t *testing.T) { }, &listAfterNameDelete, ) - assertNoErr(t, err) + require.NoError(t, err) require.Empty(t, listAfterNameDelete) } @@ -260,17 +260,17 @@ func TestPreAuthKeyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]*v1.PreAuthKey, count) - assertNoErr(t, err) + require.NoError(t, err) for index := range count { var preAuthKey v1.PreAuthKey @@ -292,7 +292,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &preAuthKey, ) - assertNoErr(t, err) + require.NoError(t, err) keys[index] = &preAuthKey } @@ -313,7 +313,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 4) @@ -372,7 +372,7 @@ func TestPreAuthKeyCommand(t *testing.T) { listedPreAuthKeys[1].GetKey(), }, ) - assertNoErr(t, err) + require.NoError(t, err) var listedPreAuthKeysAfterExpire []v1.PreAuthKey err = executeAndUnmarshal( @@ -388,7 +388,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeysAfterExpire, ) - assertNoErr(t, err) + require.NoError(t, err) assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) @@ -404,14 +404,14 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthKey v1.PreAuthKey err = executeAndUnmarshal( @@ -428,7 +428,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &preAuthKey, ) - assertNoErr(t, err) + require.NoError(t, err) var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( @@ -444,7 +444,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created 
by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) @@ -465,14 +465,14 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthReusableKey v1.PreAuthKey err = executeAndUnmarshal( @@ -489,7 +489,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &preAuthReusableKey, ) - assertNoErr(t, err) + require.NoError(t, err) var preAuthEphemeralKey v1.PreAuthKey err = executeAndUnmarshal( @@ -506,7 +506,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &preAuthEphemeralKey, ) - assertNoErr(t, err) + require.NoError(t, err) assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.False(t, preAuthEphemeralKey.GetReusable()) @@ -525,7 +525,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 3) @@ -543,7 +543,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -552,13 +552,13 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) u2, err := headscale.CreateUser(user2) - assertNoErr(t, err) + require.NoError(t, err) var user2Key v1.PreAuthKey @@ -580,7 +580,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, &user2Key, ) - assertNoErr(t, err) + require.NoError(t, err) var listNodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -592,7 +592,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, 15*time.Second, 1*time.Second) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) require.Len(t, allClients, 1) @@ -600,10 +600,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { // Log out from user1 err = client.Logout() - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleLogout() - assertNoErr(t, err) + require.NoError(t, err) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -613,7 +613,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, 30*time.Second, 2*time.Second) err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) - assertNoErr(t, err) + require.NoError(t, err) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -642,14 +642,14 @@ func TestApiKeyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]string, count) @@ -808,14 +808,14 @@ func TestNodeTagCommand(t 
*testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1007,7 +1007,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1015,10 +1015,10 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(tt.policy), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test list all nodes after added seconds resultMachines := make([]*v1.Node, spec.NodesPerUser) @@ -1058,14 +1058,14 @@ func TestNodeCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1302,14 +1302,14 @@ func TestNodeExpireCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1427,14 +1427,14 @@ func TestNodeRenameCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1462,7 +1462,7 @@ func TestNodeRenameCommand(t *testing.T) { "json", }, ) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1480,7 +1480,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &node, ) - assertNoErr(t, err) + require.NoError(t, err) nodes[index] = &node } @@ -1591,20 +1591,20 @@ func TestNodeMoveCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Randomly generated node key regID := types.MustRegistrationID() userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) _, err = headscale.Execute( []string{ @@ -1753,7 +1753,7 @@ func TestPolicyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1763,10 
+1763,10 @@ func TestPolicyCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) p := policyv2.Policy{ ACLs: []policyv2.ACL{ @@ -1789,7 +1789,7 @@ func TestPolicyCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - assertNoErr(t, err) + require.NoError(t, err) // No policy is present at this time. // Add a new policy from a file. @@ -1803,7 +1803,7 @@ func TestPolicyCommand(t *testing.T) { }, ) - assertNoErr(t, err) + require.NoError(t, err) // Get the current policy and check // if it is the same as the one we set. @@ -1819,7 +1819,7 @@ func TestPolicyCommand(t *testing.T) { }, &output, ) - assertNoErr(t, err) + require.NoError(t, err) assert.Len(t, output.TagOwners, 1) assert.Len(t, output.ACLs, 1) @@ -1834,7 +1834,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1844,10 +1844,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) p := policyv2.Policy{ ACLs: []policyv2.ACL{ @@ -1872,7 +1872,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - assertNoErr(t, err) + require.NoError(t, err) // No policy is present at this time. // Add a new policy from a file. diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 4a5e52ae..60260bb1 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netmon" @@ -23,7 +24,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) - assertNoErr(t, err) + require.NoError(t, err) testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -31,7 +32,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Create cert for headscale certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) - assertNoErr(t, err) + require.NoError(t, err) spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -39,14 +40,14 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) derper, err := scenario.CreateDERPServer("head", dsic.WithCACert(certHeadscale), dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), ) - assertNoErr(t, err) + require.NoError(t, err) derpRegion := tailcfg.DERPRegion{ RegionCode: "test-derpverify", @@ -74,17 +75,17 @@ func TestDERPVerifyEndpoint(t *testing.T) { hsic.WithPort(headscalePort), hsic.WithCustomTLS(certHeadscale, keyHeadscale), hsic.WithDERPConfig(derpMap)) - assertNoErrHeadscaleEnv(t, err) + 
requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) fakeKey := key.NewNode() DERPVerify(t, fakeKey, derpRegion, false) for _, client := range allClients { nodeKey, err := client.GetNodePrivateKey() - assertNoErr(t, err) + require.NoError(t, err) DERPVerify(t, *nodeKey, derpRegion, true) } } diff --git a/integration/dns_test.go b/integration/dns_test.go index 7cac4d47..7267bc09 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" ) @@ -22,26 +23,26 @@ func TestResolveMagicDNS(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) // Poor mans cache _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -78,7 +79,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) const erPath = "/tmp/extra_records.json" @@ -109,29 +110,29 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) // Poor mans cache _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") } hs, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Write the file directly into place from the docker API. 
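 // (b0 below is just a JSON-encoded []tailcfg.DNSRecord slice; the dig assertions
 // that follow depend on this file landing at erPath and being picked up by
 // headscale's filewatcher.)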
b0, _ := json.Marshal([]tailcfg.DNSRecord{ @@ -143,7 +144,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) err = hs.WriteFile(erPath, b0) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "2.2.2.2") @@ -159,9 +160,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { b2, _ := json.Marshal(extraRecords) err = hs.WriteFile(erPath+"2", b2) - assertNoErr(t, err) + require.NoError(t, err) _, err = hs.Execute([]string{"mv", erPath + "2", erPath}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") @@ -179,9 +180,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) err = hs.WriteFile(erPath+"3", b3) - assertNoErr(t, err) + require.NoError(t, err) _, err = hs.Execute([]string{"cp", erPath + "3", erPath}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") @@ -197,7 +198,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") @@ -205,7 +206,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Delete the file and create a new one to ensure it is picked up again. _, err = hs.Execute([]string{"rm", erPath}) - assertNoErr(t, err) + require.NoError(t, err) // The same paths should still be available as it is not cleared on delete. assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -219,7 +220,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Write a new file, the backoff mechanism should make the filewatcher pick it up // again. 
err = hs.WriteFile(erPath, b3) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index e9ba69dd..17cb01af 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -29,7 +30,7 @@ func TestDERPServerScenario(t *testing.T) { derpServerScenario(t, spec, false, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) for _, client := range allClients { @@ -43,7 +44,7 @@ func TestDERPServerScenario(t *testing.T) { } hsServer, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) derpRegion := tailcfg.DERPRegion{ RegionCode: "test-derpverify", @@ -79,7 +80,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) { derpServerScenario(t, spec, true, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) for _, client := range allClients { @@ -108,7 +109,7 @@ func derpServerScenario( IntegrationSkip(t) scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -128,16 +129,16 @@ func derpServerScenario( "HEADSCALE_DERP_SERVER_VERIFY_CLIENTS": "true", }), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { diff --git a/integration/general_test.go b/integration/general_test.go index 65131af0..ab6d4f71 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -10,19 +10,15 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" - "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" "tailscale.com/types/key" @@ -38,7 +34,7 @@ func TestPingAllByIP(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -48,16 +44,16 @@ func TestPingAllByIP(t *testing.T) { hsic.WithTLS(), 
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) hs, err := scenario.Headscale() require.NoError(t, err) @@ -80,7 +76,7 @@ func TestPingAllByIP(t *testing.T) { // Get headscale instance for batcher debug check headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test our DebugBatcher functionality t.Logf("Testing DebugBatcher functionality...") @@ -99,23 +95,23 @@ func TestPingAllByIPPublicDERP(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("pingallbyippubderp"), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -148,11 +144,11 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) headscale, err := scenario.Headscale(opts...) 
- assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) for _, userName := range spec.Users { user, err := scenario.CreateUser(userName) @@ -177,13 +173,13 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -200,7 +196,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) t.Logf("all clients logged out") @@ -222,7 +218,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) headscale, err := scenario.Headscale( @@ -231,7 +227,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", }), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) for _, userName := range spec.Users { user, err := scenario.CreateUser(userName) @@ -256,13 +252,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -344,22 +340,22 @@ func TestPingAllByHostname(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) success := pingAllHelper(t, allClients, allHostnames) @@ -379,7 +375,7 @@ func TestTaildrop(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, @@ -387,17 +383,17 @@ func TestTaildrop(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // This will essentially fetch and cache all the FQDNs _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { if !strings.Contains(client.Hostname(), "head") { @@ -498,7 +494,7 @@ func 
TestTaildrop(t *testing.T) { ) result, _, err := client.Execute(command) - assertNoErrf(t, "failed to execute command to ls taildrop: %s", err) + require.NoErrorf(t, err, "failed to execute command to ls taildrop") log.Printf("Result for %s: %s\n", peer.Hostname(), result) if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result { @@ -528,25 +524,25 @@ func TestUpdateHostnameFromClient(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErrf(t, "failed to create scenario: %s", err) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) // update hostnames using the up command for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) command := []string{ "tailscale", @@ -554,11 +550,11 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--hostname=" + hostnames[string(status.Self.ID)], } _, _, err = client.Execute(command) - assertNoErrf(t, "failed to set hostname: %s", err) + require.NoErrorf(t, err, "failed to set hostname") } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for nodestore batch processing to complete // NodeStore batching timeout is 500ms, so we wait up to 1 second @@ -597,7 +593,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--identifier", strconv.FormatUint(node.GetId(), 10), }) - assertNoErr(t, err) + require.NoError(t, err) } // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged @@ -643,7 +639,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) command := []string{ "tailscale", @@ -651,11 +647,11 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--hostname=" + hostnames[string(status.Self.ID)] + "NEW", } _, _, err = client.Execute(command) - assertNoErrf(t, "failed to set hostname: %s", err) + require.NoErrorf(t, err, "failed to set hostname") } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for nodestore batch processing to complete // NodeStore batching timeout is 500ms, so we wait up to 1 second @@ -696,20 +692,20 @@ func TestExpireNode(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -731,22 +727,22 @@ func TestExpireNode(t *testing.T) { } headscale, err := scenario.Headscale() - assertNoErr(t, err) + 
require.NoError(t, err) // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node err = json.Unmarshal([]byte(result), &node) - assertNoErr(t, err) + require.NoError(t, err) var expiredNodeKey key.NodePublic err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey())) - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) @@ -773,14 +769,14 @@ func TestExpireNode(t *testing.T) { // Verify that the expired node has been marked in all peers list. for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) if client.Hostname() != node.GetName() { t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) // Ensures that the node is present, and that it is expired. if peerStatus, ok := status.Peer[expiredNodeKey]; ok { - assertNotNil(t, peerStatus.Expired) + requireNotNil(t, peerStatus.Expired) assert.NotNil(t, peerStatus.KeyExpiry) t.Logf( @@ -840,20 +836,20 @@ func TestNodeOnlineStatus(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -866,14 +862,14 @@ func TestNodeOnlineStatus(t *testing.T) { for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) // Assert that we have the original count - self assert.Len(t, status.Peers(), len(MustTestVersions)-1) } headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Duration is chosen arbitrarily, 10m is reported in #1561 testDuration := 12 * time.Minute @@ -963,7 +959,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -973,16 +969,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -992,7 +988,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { // Get headscale instance for batcher debug checks headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Initial check: all nodes should be connected to batcher // Extract node IDs for validation @@ -1000,7 +996,7 @@ func 
TestPingAllByIPManyUpDown(t *testing.T) { for _, client := range allClients { status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) - assertNoErr(t, err) + require.NoError(t, err) expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second) @@ -1072,7 +1068,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1081,16 +1077,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -1100,7 +1096,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test list all nodes after added otherUser var nodeList []v1.Node @@ -1170,159 +1166,3 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { assert.True(t, nodeListAfter[0].GetOnline()) assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) } - -// NodeSystemStatus represents the online status of a node across different systems -type NodeSystemStatus struct { - Batcher bool - BatcherConnCount int - MapResponses bool - NodeStore bool -} - -// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore -func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { - t.Helper() - - startTime := time.Now() - t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format(TimestampFormat), message) - - var prevReport string - require.EventuallyWithT(t, func(c *assert.CollectT) { - // Get batcher state - debugInfo, err := headscale.DebugBatcher() - assert.NoError(c, err, "Failed to get batcher debug info") - if err != nil { - return - } - - // Get map responses - mapResponses, err := headscale.GetAllMapReponses() - assert.NoError(c, err, "Failed to get map responses") - if err != nil { - return - } - - // Get nodestore state - nodeStore, err := headscale.DebugNodeStore() - assert.NoError(c, err, "Failed to get nodestore debug info") - if err != nil { - return - } - - // Validate node counts first - expectedCount := len(expectedNodes) - assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch") - assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") - - // Check that we have map responses for expected nodes - mapResponseCount := len(mapResponses) - assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch") - - // Build status map for each node - nodeStatus := make(map[types.NodeID]NodeSystemStatus) - - // Initialize all expected nodes - for _, nodeID := range expectedNodes { - 
nodeStatus[nodeID] = NodeSystemStatus{} - } - - // Check batcher state - for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes { - nodeID := types.MustParseNodeID(nodeIDStr) - if status, exists := nodeStatus[nodeID]; exists { - status.Batcher = nodeInfo.Connected - status.BatcherConnCount = nodeInfo.ActiveConnections - nodeStatus[nodeID] = status - } - } - - // Check map responses using buildExpectedOnlineMap - onlineFromMaps := make(map[types.NodeID]bool) - onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) - for nodeID := range nodeStatus { - NODE_STATUS: - for id, peerMap := range onlineMap { - if id == nodeID { - continue - } - - online := peerMap[nodeID] - // If the node is offline in any map response, we consider it offline - if !online { - onlineFromMaps[nodeID] = false - continue NODE_STATUS - } - - onlineFromMaps[nodeID] = true - } - } - assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") - - // Update status with map response data - for nodeID, online := range onlineFromMaps { - if status, exists := nodeStatus[nodeID]; exists { - status.MapResponses = online - nodeStatus[nodeID] = status - } - } - - // Check nodestore state - for nodeID, node := range nodeStore { - if status, exists := nodeStatus[nodeID]; exists { - // Check if node is online in nodestore - status.NodeStore = node.IsOnline != nil && *node.IsOnline - nodeStatus[nodeID] = status - } - } - - // Verify all systems show nodes in expected state and report failures - allMatch := true - var failureReport strings.Builder - - ids := types.NodeIDs(maps.Keys(nodeStatus)) - slices.Sort(ids) - for _, nodeID := range ids { - status := nodeStatus[nodeID] - systemsMatch := (status.Batcher == expectedOnline) && - (status.MapResponses == expectedOnline) && - (status.NodeStore == expectedOnline) - - if !systemsMatch { - allMatch = false - stateStr := "offline" - if expectedOnline { - stateStr = "online" - } - failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr)) - failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher)) - failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) - failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses)) - failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore)) - } - } - - if !allMatch { - if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { - t.Log("Diff between reports:") - t.Logf("Prev report: \n%s\n", prevReport) - t.Logf("New report: \n%s\n", failureReport.String()) - t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") - prevReport = failureReport.String() - } - - failureReport.WriteString("timestamp: " + time.Now().Format(TimestampFormat) + "\n") - - assert.Fail(c, failureReport.String()) - } - - stateStr := "offline" - if expectedOnline { - stateStr = "online" - } - assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr)) - }, timeout, 2*time.Second, message) - - endTime := time.Now() - duration := endTime.Sub(startTime) - t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message) -} diff --git a/integration/helpers.go b/integration/helpers.go new file mode 100644 index 00000000..37cc2ad8 --- /dev/null +++ b/integration/helpers.go @@ -0,0 +1,899 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "strconv" 
+ "strings" + "sync" + "testing" + "time" + + "github.com/cenkalti/backoff/v5" + "github.com/google/go-cmp/cmp" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +const ( + // derpPingTimeout defines the timeout for individual DERP ping operations + // Used in DERP connectivity tests to verify relay server communication. + derpPingTimeout = 2 * time.Second + + // derpPingCount defines the number of ping attempts for DERP connectivity tests + // Higher count provides better reliability assessment of DERP connectivity. + derpPingCount = 10 + + // TimestampFormat is the standard timestamp format used across all integration tests + // Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps + // suitable for debugging and log correlation in integration tests. + TimestampFormat = "2006-01-02T15-04-05.999999999" + + // TimestampFormatRunID is used for generating unique run identifiers + // Format: "20060102-150405" provides compact date-time for file/directory names. + TimestampFormatRunID = "20060102-150405" +) + +// NodeSystemStatus represents the status of a node across different systems +type NodeSystemStatus struct { + Batcher bool + BatcherConnCount int + MapResponses bool + NodeStore bool +} + +// requireNotNil validates that an object is not nil and fails the test if it is. +// This helper provides consistent error messaging for nil checks in integration tests. +func requireNotNil(t *testing.T, object interface{}) { + t.Helper() + require.NotNil(t, object) +} + +// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded. +// Provides specific error context for headscale environment setup failures. +func requireNoErrHeadscaleEnv(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to create headscale environment") +} + +// requireNoErrGetHeadscale validates that headscale server retrieval succeeded. +// Provides specific error context for headscale server access failures. +func requireNoErrGetHeadscale(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get headscale") +} + +// requireNoErrListClients validates that client listing operations succeeded. +// Provides specific error context for client enumeration failures. +func requireNoErrListClients(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list clients") +} + +// requireNoErrListClientIPs validates that client IP retrieval succeeded. +// Provides specific error context for client IP address enumeration failures. +func requireNoErrListClientIPs(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get client IPs") +} + +// requireNoErrSync validates that client synchronization operations succeeded. +// Provides specific error context for client sync failures across the network. +func requireNoErrSync(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to have all clients sync up") +} + +// requireNoErrListFQDN validates that FQDN listing operations succeeded. 
+// Provides specific error context for DNS name enumeration failures. +func requireNoErrListFQDN(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list FQDNs") +} + +// requireNoErrLogout validates that tailscale node logout operations succeeded. +// Provides specific error context for client logout failures. +func requireNoErrLogout(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to log out tailscale nodes") +} + +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { + t.Helper() + + expectedNodes := make([]types.NodeID, 0, len(clients)) + for _, client := range clients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + return expectedNodes +} + +// validateInitialConnection performs comprehensive validation after initial client login. +// Validates that all nodes are online and have proper NetInfo/DERP configuration, +// essential for ensuring successful initial connection state in relogin tests. +func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) +} + +// validateLogoutComplete performs comprehensive validation after client logout. +// Ensures all nodes are properly offline across all headscale systems, +// critical for validating clean logout state in relogin tests. +func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) +} + +// validateReloginComplete performs comprehensive validation after client relogin. +// Validates that all nodes are back online with proper NetInfo/DERP configuration, +// ensuring successful relogin state restoration in integration tests. 
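+//
+// A relogin test typically chains these helpers roughly as follows (sketch only;
+// the real tests interleave scenario setup with the actual logout/login steps):
+//
+//	expectedNodes := collectExpectedNodeIDs(t, allClients)
+//	validateInitialConnection(t, headscale, expectedNodes)
+//	// ... log all clients out ...
+//	validateLogoutComplete(t, headscale, expectedNodes)
+//	// ... log all clients back in ...
+//	validateReloginComplete(t, headscale, expectedNodes)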
+func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute) +} + +// requireAllClientsOnline validates that all nodes are online/offline across all headscale systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) + + if expectedOnline { + // For online validation, use the existing logic with full timeout + requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout) + } else { + // For offline validation, use staged approach with component-specific timeouts + requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout) + } + + endTime := time.Now() + t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) +} + +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get batcher state + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + // Get map responses + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate node counts first + expectedCount := len(expectedNodes) + assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch - expected %d nodes, got %d", expectedCount, debugInfo.TotalNodes) + assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch - expected %d nodes, got %d", expectedCount, len(nodeStore)) + + // Check that we have map responses for expected nodes + mapResponseCount := len(mapResponses) + assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch - expected %d responses, got %d", expectedCount, mapResponseCount) + + // Build status map for each node + nodeStatus := make(map[types.NodeID]NodeSystemStatus) + + // Initialize all expected nodes + for _, nodeID := range expectedNodes { + nodeStatus[nodeID] = NodeSystemStatus{} + } + + // Check batcher state + for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes { + nodeID := types.MustParseNodeID(nodeIDStr) + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = nodeInfo.Connected + 
status.BatcherConnCount = nodeInfo.ActiveConnections + nodeStatus[nodeID] = status + } + } + + // Check map responses using buildExpectedOnlineMap + onlineFromMaps := make(map[types.NodeID]bool) + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + + // For single node scenarios, we can't validate peer visibility since there are no peers + if len(expectedNodes) == 1 { + // For single node, just check that we have map responses for the node + for nodeID := range nodeStatus { + if _, exists := onlineMap[nodeID]; exists { + onlineFromMaps[nodeID] = true + } else { + onlineFromMaps[nodeID] = false + } + } + } else { + // Multi-node scenario: check peer visibility + for nodeID := range nodeStatus { + NODE_STATUS: + for id, peerMap := range onlineMap { + if id == nodeID { + continue + } + + online := peerMap[nodeID] + // If the node is offline in any map response, we consider it offline + if !online { + onlineFromMaps[nodeID] = false + continue NODE_STATUS + } + + onlineFromMaps[nodeID] = true + } + } + } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") + + // Update status with map response data + for nodeID, online := range onlineFromMaps { + if status, exists := nodeStatus[nodeID]; exists { + status.MapResponses = online + nodeStatus[nodeID] = status + } + } + + // Check nodestore state + for nodeID, node := range nodeStore { + if status, exists := nodeStatus[nodeID]; exists { + // Check if node is online in nodestore + status.NodeStore = node.IsOnline != nil && *node.IsOnline + nodeStatus[nodeID] = status + } + } + + // Verify all systems show nodes in expected state and report failures + allMatch := true + var failureReport strings.Builder + + ids := types.NodeIDs(maps.Keys(nodeStatus)) + slices.Sort(ids) + for _, nodeID := range ids { + status := nodeStatus[nodeID] + systemsMatch := (status.Batcher == expectedOnline) && + (status.MapResponses == expectedOnline) && + (status.NodeStore == expectedOnline) + + if !systemsMatch { + allMatch = false + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) + failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) + failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (expected: %t, down with at least one peer)\n", status.MapResponses, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - nodestore: %t (expected: %t)\n", status.NodeStore, expectedOnline)) + } + } + + if !allMatch { + if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { + t.Logf("Node state validation report changed at %s:", time.Now().Format(TimestampFormat)) + t.Logf("Previous report:\n%s", prevReport) + t.Logf("Current report:\n%s", failureReport.String()) + t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() + } + + failureReport.WriteString(fmt.Sprintf("validation_timestamp: %s\n", time.Now().Format(TimestampFormat))) + // Note: timeout_remaining not available in this context + + assert.Fail(c, failureReport.String()) + } + + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) + }, timeout, 2*time.Second, message) 
+} + +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { + t.Helper() + + // Stage 1: Verify batcher disconnection (should be immediate) + t.Logf("Stage 1: Verifying batcher disconnection for %d nodes", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + allBatcherOffline := true + for _, nodeID := range expectedNodes { + nodeIDStr := fmt.Sprintf("%d", nodeID) + if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { + allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) + } + } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") + }, 15*time.Second, 1*time.Second, "batcher disconnection validation") + + // Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay) + t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + allNodeStoreOffline := true + for _, nodeID := range expectedNodes { + if node, exists := nodeStore[nodeID]; exists { + isOnline := node.IsOnline != nil && *node.IsOnline + if isOnline { + allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) + } + } + } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") + }, 20*time.Second, 1*time.Second, "nodestore offline validation") + + // Stage 3: Verify map response propagation (longest delay due to peer update timing) + t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + allMapResponsesOffline := true + + if len(expectedNodes) == 1 { + // Single node: check if it appears in map responses + for nodeID := range onlineMap { + if slices.Contains(expectedNodes, nodeID) { + allMapResponsesOffline = false + assert.False(c, true, "Node %d should not appear in map responses", nodeID) + } + } + } else { + // Multi-node: check peer visibility + for _, nodeID := range expectedNodes { + for id, peerMap := range onlineMap { + if id == nodeID { + continue // Skip self-references + } + if online, exists := peerMap[nodeID]; exists && online { + allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) + } + } + } + } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") + }, 60*time.Second, 2*time.Second, "map response propagation validation") + + t.Logf("All stages completed: nodes are fully offline across all systems") +} + +// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database +// and 
a valid DERP server based on the NetInfo. This function follows the pattern of +// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message) + + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate node counts first + expectedCount := len(expectedNodes) + assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch during NetInfo validation - expected %d nodes, got %d", expectedCount, len(nodeStore)) + + // Check each expected node + for _, nodeID := range expectedNodes { + node, exists := nodeStore[nodeID] + assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { + continue + } + + // Validate that the node has Hostinfo + assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { + t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has NetInfo + assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { + t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has a valid DERP server (PreferredDERP should be > 0) + preferredDERP := node.Hostinfo.NetInfo.PreferredDERP + assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + + t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) + } + }, timeout, 5*time.Second, message) + + endTime := time.Now() + duration := endTime.Sub(startTime) + t.Logf("requireAllClientsNetInfoAndDERP: Completed NetInfo/DERP validation for %d nodes at %s - Duration: %v - %s", len(expectedNodes), endTime.Format(TimestampFormat), duration, message) +} + +// assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. +// Critical for ensuring node activity tracking is functioning properly. +func assertLastSeenSet(t *testing.T, node *v1.Node) { + assert.NotNil(t, node) + assert.NotNil(t, node.GetLastSeen()) +} + +// assertTailscaleNodesLogout verifies that all provided Tailscale clients +// are in the logged-out state (NeedsLogin). +func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + + for _, client := range clients { + status, err := client.Status() + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.Equal(t, "NeedsLogin", status.BackendState, + "client %s should be logged out", client.Hostname()) + } +} + +// pingAllHelper performs ping tests between all clients and addresses, returning success count. 
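+// A sketch of the usual call pattern (allAddrs is built from the scenario's client IPs):
+//
+//	success := pingAllHelper(t, allClients, allAddrs)
+//	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
+//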
+// This is used to validate network connectivity in integration tests. +// Returns the total number of successful ping operations. +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + err := client.Ping(addr, opts...) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses. +// This specifically tests connectivity through DERP relay servers, which is important +// for validating NAT traversal and relay functionality. Returns success count. +func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + if isSelfClient(client, addr) { + continue + } + + err := client.Ping( + addr, + tsic.WithPingTimeout(derpPingTimeout), + tsic.WithPingCount(derpPingCount), + tsic.WithPingUntilDirect(false), + ) + if err != nil { + t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// isSelfClient determines if the given address belongs to the client itself. +// Used to avoid self-ping operations in connectivity tests by checking +// hostname and IP address matches. +func isSelfClient(client TailscaleClient, addr string) bool { + if addr == client.Hostname() { + return true + } + + ips, err := client.IPs() + if err != nil { + return false + } + + for _, ip := range ips { + if ip.String() == addr { + return true + } + } + + return false +} + +// assertClientsState validates the status and netmap of a list of clients for general connectivity. +// Runs parallel validation of status, netcheck, and netmap for all clients to ensure +// they have proper network configuration for all-to-all connectivity tests. +func assertClientsState(t *testing.T, clients []TailscaleClient) { + t.Helper() + + var wg sync.WaitGroup + + for _, client := range clients { + wg.Add(1) + c := client // Avoid loop pointer + go func() { + defer wg.Done() + assertValidStatus(t, c) + assertValidNetcheck(t, c) + assertValidNetmap(t, c) + }() + } + + t.Logf("waiting for client state checks to finish") + wg.Wait() +} + +// assertValidNetmap validates that a client's netmap has all required fields for proper operation. +// Checks self node and all peers for essential networking data including hostinfo, addresses, +// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. +// This test is not suitable for ACL/partial connection tests. 
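+// Like assertValidStatus and assertValidNetcheck below, it is normally exercised
+// through assertClientsState, e.g.:
+//
+//	assertClientsState(t, allClients)
+//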
+func assertValidNetmap(t *testing.T, client TailscaleClient) {
+	t.Helper()
+
+	if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
+		t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
+
+		return
+	}
+
+	t.Logf("Checking netmap of %q", client.Hostname())
+
+	netmap, err := client.Netmap()
+	if err != nil {
+		t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
+	}
+
+	assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
+	if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
+		assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
+	}
+
+	assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
+	assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
+
+	assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
+
+	assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
+	assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
+	assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
+
+	for _, peer := range netmap.Peers {
+		assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
+		assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
+
+		assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
+		if hi := peer.Hostinfo(); hi.Valid() {
+			assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
+
+			// Netinfo is not always set
+			// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
+			if ni := hi.NetInfo(); ni.Valid() {
+				assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
+			}
+		}
+
+		assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
+		assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
+		assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
+
+		assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
+
+		assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
+		assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
+		assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
+	}
+}
+
+// assertValidStatus validates that a client's status has all required fields for proper operation.
+// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints,
+// and network map presence. This test is not suitable for ACL/partial connection tests.
+func assertValidStatus(t *testing.T, client TailscaleClient) {
+	t.Helper()
+	status, err := client.Status(true)
+	if err != nil {
+		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
+	}
+
+	assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
+	assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
+
+	assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
+
+	// This does not seem to appear until version 1.56
+	if status.Self.AllowedIPs != nil {
+		assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
+	}
+
+	assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
+
+	assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
+
+	assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
+
+	// This isn't really relevant for Self as it won't be in its own socket/wireguard.
+	// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
+	// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
+
+	for _, peer := range status.Peer {
+		assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
+		assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
+
+		assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
+
+		// This does not seem to appear until version 1.56
+		if peer.AllowedIPs != nil {
+			assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
+		}
+
+		// Addrs does not seem to appear in the status from peers.
+		// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
+
+		assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
+		assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
+
+		// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
+		// there might be some interesting stuff to test here in the future.
+		// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
+	}
+}
+
+// assertValidNetcheck validates that a client has a proper DERP relay configured.
+// Ensures the client has discovered and selected a DERP server for relay functionality,
+// which is essential for NAT traversal and connectivity in restricted networks.
+func assertValidNetcheck(t *testing.T, client TailscaleClient) {
+	t.Helper()
+	report, err := client.Netcheck()
+	if err != nil {
+		t.Fatalf("getting netcheck report for %q: %s", client.Hostname(), err)
+	}
+
+	assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
+}
+
+// assertCommandOutputContains executes a command with exponential backoff retry until the output
+// contains the expected string or timeout is reached (10 seconds).
+// This implements eventual consistency patterns and should be used instead of time.Sleep
+// before executing commands that depend on network state propagation.
+//
+// Timeout: 10 seconds with exponential backoff
+// Use cases: DNS resolution, route propagation, policy updates.
+func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
+	t.Helper()
+
+	_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
+		stdout, stderr, err := c.Execute(command)
+		if err != nil {
+			return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
+		}
+
+		if !strings.Contains(stdout, contains) {
+			return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
+		}
+
+		return struct{}{}, nil
+	}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
+
+	assert.NoError(t, err)
+}
+
+// dockertestMaxWait returns the maximum wait time for Docker-based test operations.
+// Uses longer timeouts in CI environments to account for slower resource allocation
+// and higher system load during automated testing.
+func dockertestMaxWait() time.Duration {
+	wait := 300 * time.Second //nolint
+
+	if util.IsCI() {
+		wait = 600 * time.Second //nolint
+	}
+
+	return wait
+}
+
+// didClientUseWebsocketForDERP analyzes client logs to determine if WebSocket was used for DERP.
+// Searches for WebSocket connection indicators in client logs to validate
+// DERP relay communication method for debugging connectivity issues.
+func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
+	t.Helper()
+
+	buf := &bytes.Buffer{}
+	err := client.WriteLogs(buf, buf)
+	if err != nil {
+		t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
+	}
+
+	count, err := countMatchingLines(buf, func(line string) bool {
+		return strings.Contains(line, "websocket: connected to ")
+	})
+	if err != nil {
+		t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err)
+	}
+
+	return count > 0
+}
+
+// countMatchingLines counts lines in a reader that match the given predicate function.
+// Uses optimized buffering for log analysis and provides flexible line-by-line
+// filtering for log parsing and pattern matching in integration tests.
+func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) {
+	count := 0
+	scanner := bufio.NewScanner(in)
+	{
+		const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
+		buff := make([]byte, logBufferInitialSize)
+		scanner.Buffer(buff, len(buff))
+		scanner.Split(bufio.ScanLines)
+	}
+
+	for scanner.Scan() {
+		if predicate(scanner.Text()) {
+			count += 1
+		}
+	}
+
+	return count, scanner.Err()
+}
+
+// wildcard returns a wildcard alias (*) for use in policy v2 configurations.
+// Provides a convenient helper for creating permissive policy rules.
+func wildcard() policyv2.Alias {
+	return policyv2.Wildcard
+}
+
+// usernamep returns a pointer to a Username as an Alias for policy v2 configurations.
+// Used in ACL rules to reference specific users in network access policies. +func usernamep(name string) policyv2.Alias { + return ptr.To(policyv2.Username(name)) +} + +// hostp returns a pointer to a Host as an Alias for policy v2 configurations. +// Used in ACL rules to reference specific hosts in network access policies. +func hostp(name string) policyv2.Alias { + return ptr.To(policyv2.Host(name)) +} + +// groupp returns a pointer to a Group as an Alias for policy v2 configurations. +// Used in ACL rules to reference user groups in network access policies. +func groupp(name string) policyv2.Alias { + return ptr.To(policyv2.Group(name)) +} + +// tagp returns a pointer to a Tag as an Alias for policy v2 configurations. +// Used in ACL rules to reference node tags in network access policies. +func tagp(name string) policyv2.Alias { + return ptr.To(policyv2.Tag(name)) +} + +// prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. +// Converts CIDR notation to policy prefix format for network range specifications. +func prefixp(cidr string) policyv2.Alias { + prefix := netip.MustParsePrefix(cidr) + return ptr.To(policyv2.Prefix(prefix)) +} + +// aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges. +// Combines network targets with specific port restrictions for fine-grained +// access control in policy v2 configurations. +func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { + return policyv2.AliasWithPorts{ + Alias: alias, + Ports: ports, + } +} + +// usernameOwner returns a Username as an Owner for use in TagOwners policies. +// Specifies which users can assign and manage specific tags in ACL configurations. +func usernameOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Username(name)) +} + +// groupOwner returns a Group as an Owner for use in TagOwners policies. +// Specifies which groups can assign and manage specific tags in ACL configurations. +func groupOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Group(name)) +} + +// usernameApprover returns a Username as an AutoApprover for subnet route policies. +// Specifies which users can automatically approve subnet route advertisements. +func usernameApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Username(name)) +} + +// groupApprover returns a Group as an AutoApprover for subnet route policies. +// Specifies which groups can automatically approve subnet route advertisements. +func groupApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Group(name)) +} + +// tagApprover returns a Tag as an AutoApprover for subnet route policies. +// Specifies which tagged nodes can automatically approve subnet route advertisements. +func tagApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Tag(name)) +} + +// oidcMockUser creates a MockUser for OIDC authentication testing. +// Generates consistent test user data with configurable email verification status +// for validating OIDC integration flows in headscale authentication tests. 
+func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: username + "@headscale.net", + EmailVerified: emailVerified, + } +} diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 9c28dc00..baf41dcf 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -13,6 +13,7 @@ import ( "net/http" "net/netip" "os" + "os/exec" "path" "path/filepath" "sort" @@ -460,6 +461,12 @@ func New( dockertestutil.DockerAllowNetworkAdministration, ) if err != nil { + // Try to get more detailed build output + log.Printf("Docker build failed, attempting to get detailed output...") + buildOutput := runDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName) + if buildOutput != "" { + return nil, fmt.Errorf("could not start headscale container: %w\n\nDetailed build output:\n%s", err, buildOutput) + } return nil, fmt.Errorf("could not start headscale container: %w", err) } log.Printf("Created %s container\n", hsic.hostname) @@ -1391,3 +1398,13 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er return nodeStore, nil } + +// runDockerBuildForDiagnostics runs docker build manually to get detailed error output +func runDockerBuildForDiagnostics(contextDir, dockerfile string) string { + cmd := exec.Command("docker", "build", "-f", dockerfile, contextDir) + output, err := cmd.CombinedOutput() + if err != nil { + return string(output) + } + return "" +} diff --git a/integration/route_test.go b/integration/route_test.go index 9aced164..a613c375 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -53,16 +53,16 @@ func TestEnablingRoutes(t *testing.T) { err = scenario.CreateHeadscaleEnv( []tsic.Option{tsic.WithAcceptRoutes()}, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) expectedRoutes := map[string]string{ "1": "10.0.0.0/24", @@ -83,7 +83,7 @@ func TestEnablingRoutes(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) var nodes []*v1.Node // Wait for route advertisements to propagate to NodeStore @@ -256,16 +256,16 @@ func TestHASubnetRouterFailover(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) @@ -319,7 +319,7 @@ func TestHASubnetRouterFailover(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for route configuration changes after advertising routes var nodes []*v1.Node @@ -1341,16 +1341,16 @@ func TestSubnetRouteACL(t *testing.T) { }, }, )) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := 
scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) expectedRoutes := map[string]string{ "1": "10.33.0.0/16", @@ -1393,7 +1393,7 @@ func TestSubnetRouteACL(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for route advertisements to propagate to the server var nodes []*v1.Node @@ -1572,25 +1572,25 @@ func TestEnablingExitRoutes(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErrf(t, "failed to create scenario: %s", err) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}), }, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) nodes, err := headscale.ListNodes() require.NoError(t, err) @@ -1686,16 +1686,16 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) pref, err := scenario.SubnetOfNetwork("usernet1") @@ -1833,16 +1833,16 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) var user1c, user2c TailscaleClient @@ -2247,13 +2247,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { err = scenario.createHeadscaleEnv(tt.withURL, tsOpts, opts..., ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) services, err := scenario.Services("usernet1") require.NoError(t, err) @@ -2263,7 +2263,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) // Add the Docker network route to the auto-approvers @@ -2304,21 +2304,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) 
{ if tt.withURL { u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) body, err := doLoginURL(routerUsernet1.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) scenario.runHeadscaleRegister("user1", body) } else { userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) - assertNoErr(t, err) + require.NoError(t, err) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) - assertNoErr(t, err) + require.NoError(t, err) } // extra creation end. @@ -2893,13 +2893,13 @@ func TestSubnetRouteACLFiltering(t *testing.T) { hsic.WithACLPolicy(aclPolicy), hsic.WithPolicyMode(types.PolicyModeDB), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) // Get the router and node clients by user routerClients, err := scenario.ListTailscaleClients(routerUser) @@ -2944,7 +2944,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { require.NoErrorf(t, err, "failed to advertise routes: %s", err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) var routerNode, nodeNode *v1.Node // Wait for route advertisements to propagate to NodeStore diff --git a/integration/scenario_test.go b/integration/scenario_test.go index ead3f1fd..1e2a151a 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -5,6 +5,7 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" ) // This file is intended to "test the test framework", by proxy it will also test @@ -34,7 +35,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" scenario, err := NewScenario(ScenarioSpec{}) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { @@ -82,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 scenario, err := NewScenario(ScenarioSpec{}) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { diff --git a/integration/ssh_test.go b/integration/ssh_test.go index a5975eb4..1299ba52 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" ) @@ -30,7 +31,7 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce Users: []string{"user1", "user2"}, } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.CreateHeadscaleEnv( []tsic.Option{ @@ -50,13 +51,13 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce hsic.WithACLPolicy(policy), hsic.WithTestName("ssh"), ) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleSync() - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErr(t, err) + require.NoError(t, err) return scenario } @@ -93,19 +94,19 @@ func TestSSHOneUserToAll(t *testing.T) { defer 
scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range user1Clients { for _, peer := range allClients { @@ -160,16 +161,16 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) nsOneClients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) nsTwoClients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) testInterUserSSH := func(sourceClients []TailscaleClient, targetClients []TailscaleClient) { for _, client := range sourceClients { @@ -208,13 +209,13 @@ func TestSSHNoSSHConfigured(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -259,13 +260,13 @@ func TestSSHIsBlockedInACL(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -317,16 +318,16 @@ func TestSSHUserOnlyIsolation(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) ssh1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) ssh2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range ssh1Clients { for _, peer := range ssh2Clients { @@ -422,9 +423,9 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien t.Helper() result, _, err := doSSH(t, client, peer) - assertNoErr(t, err) + require.NoError(t, err) - assertContains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) + require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) } func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { diff --git a/integration/utils.go b/integration/utils.go deleted file mode 100644 index 117bdab7..00000000 --- 
a/integration/utils.go +++ /dev/null @@ -1,533 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "strings" - "sync" - "testing" - "time" - - "github.com/cenkalti/backoff/v5" - policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/tsic" - "github.com/stretchr/testify/assert" - "tailscale.com/tailcfg" - "tailscale.com/types/ptr" -) - -const ( - // derpPingTimeout defines the timeout for individual DERP ping operations - // Used in DERP connectivity tests to verify relay server communication. - derpPingTimeout = 2 * time.Second - - // derpPingCount defines the number of ping attempts for DERP connectivity tests - // Higher count provides better reliability assessment of DERP connectivity. - derpPingCount = 10 - - // TimestampFormat is the standard timestamp format used across all integration tests - // Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps - // suitable for debugging and log correlation in integration tests. - TimestampFormat = "2006-01-02T15-04-05.999999999" - - // TimestampFormatRunID is used for generating unique run identifiers - // Format: "20060102-150405" provides compact date-time for file/directory names. - TimestampFormatRunID = "20060102-150405" -) - -func assertNoErr(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "unexpected error: %s", err) -} - -func assertNoErrf(t *testing.T, msg string, err error) { - t.Helper() - if err != nil { - t.Fatalf(msg, err) - } -} - -func assertNotNil(t *testing.T, thing interface{}) { - t.Helper() - if thing == nil { - t.Fatal("got unexpected nil") - } -} - -func assertNoErrHeadscaleEnv(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to create headscale environment: %s", err) -} - -func assertNoErrGetHeadscale(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get headscale: %s", err) -} - -func assertNoErrListClients(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list clients: %s", err) -} - -func assertNoErrListClientIPs(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get client IPs: %s", err) -} - -func assertNoErrSync(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to have all clients sync up: %s", err) -} - -func assertNoErrListFQDN(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list FQDNs: %s", err) -} - -func assertNoErrLogout(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to log out tailscale nodes: %s", err) -} - -func assertContains(t *testing.T, str, subStr string) { - t.Helper() - if !strings.Contains(str, subStr) { - t.Fatalf("%#v does not contain %#v", str, subStr) - } -} - -func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { - t.Helper() - - buf := &bytes.Buffer{} - err := client.WriteLogs(buf, buf) - if err != nil { - t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) - } - - count, err := countMatchingLines(buf, func(line string) bool { - return strings.Contains(line, "websocket: connected to ") - }) - if err != nil { - t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err) - } - - return count > 0 -} - -// pingAllHelper performs ping tests between all clients and addresses, returning success count. -// This is used to validate network connectivity in integration tests. -// Returns the total number of successful ping operations. 
-func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - err := client.Ping(addr, opts...) - if err != nil { - t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses. -// This specifically tests connectivity through DERP relay servers, which is important -// for validating NAT traversal and relay functionality. Returns success count. -func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - if isSelfClient(client, addr) { - continue - } - - err := client.Ping( - addr, - tsic.WithPingTimeout(derpPingTimeout), - tsic.WithPingCount(derpPingCount), - tsic.WithPingUntilDirect(false), - ) - if err != nil { - t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -// assertClientsState validates the status and netmap of a list of -// clients for the general case of all to all connectivity. -func assertClientsState(t *testing.T, clients []TailscaleClient) { - t.Helper() - - var wg sync.WaitGroup - - for _, client := range clients { - wg.Add(1) - c := client // Avoid loop pointer - go func() { - defer wg.Done() - assertValidStatus(t, c) - assertValidNetcheck(t, c) - assertValidNetmap(t, c) - }() - } - - t.Logf("waiting for client state checks to finish") - wg.Wait() -} - -// assertValidNetmap asserts that the netmap of a client has all -// the minimum required fields set to a known working config for -// the general case. Fields are checked on self, then all peers. -// This test is not suitable for ACL/partial connection tests. -// This test can only be run on clients from 1.56.1. It will -// automatically pass all clients below that and is safe to call -// for all versions. 
-func assertValidNetmap(t *testing.T, client TailscaleClient) { - t.Helper() - - if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { - t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) - - return - } - - t.Logf("Checking netmap of %q", client.Hostname()) - - netmap, err := client.Netmap() - if err != nil { - t.Fatalf("getting netmap for %q: %s", client.Hostname(), err) - } - - assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) - if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { - assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) - } - - assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) - assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - - assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - - assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) - - for _, peer := range netmap.Peers { - assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) - assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) - - assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) - if hi := peer.Hostinfo(); hi.Valid() { - assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) - - // Netinfo is not always set - // assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) - if ni := hi.NetInfo(); ni.Valid() { - assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) - } - } - - assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) - - assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) - - assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) - } -} - -// assertValidStatus asserts that the status of a client has all -// the minimum required fields set to a known working 
config for -// the general case. Fields are checked on self, then all peers. -// This test is not suitable for ACL/partial connection tests. -func assertValidStatus(t *testing.T, client TailscaleClient) { - t.Helper() - status, err := client.Status(true) - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) - - assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) - - // This seem to not appear until version 1.56 - if status.Self.AllowedIPs != nil { - assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) - } - - assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) - - assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) - - assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) - - // This isn't really relevant for Self as it won't be in its own socket/wireguard. - // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) - - for _, peer := range status.Peer { - assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) - - assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) - - // This seem to not appear until version 1.56 - if peer.AllowedIPs != nil { - assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) - } - - // Addrs does not seem to appear in the status from peers. - // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) - assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) - - // TODO(kradalby): InEngine is only true when a proper tunnel is set up, - // there might be some interesting stuff to test here in the future. 
- // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) - } -} - -func assertValidNetcheck(t *testing.T, client TailscaleClient) { - t.Helper() - report, err := client.Netcheck() - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) -} - -// assertCommandOutputContains executes a command with exponential backoff retry until the output -// contains the expected string or timeout is reached (10 seconds). -// This implements eventual consistency patterns and should be used instead of time.Sleep -// before executing commands that depend on network state propagation. -// -// Timeout: 10 seconds with exponential backoff -// Use cases: DNS resolution, route propagation, policy updates. -func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { - t.Helper() - - _, err := backoff.Retry(t.Context(), func() (struct{}, error) { - stdout, stderr, err := c.Execute(command) - if err != nil { - return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) - } - - if !strings.Contains(stdout, contains) { - return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) - } - - return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) - - assert.NoError(t, err) -} - -func isSelfClient(client TailscaleClient, addr string) bool { - if addr == client.Hostname() { - return true - } - - ips, err := client.IPs() - if err != nil { - return false - } - - for _, ip := range ips { - if ip.String() == addr { - return true - } - } - - return false -} - -func dockertestMaxWait() time.Duration { - wait := 300 * time.Second //nolint - - if util.IsCI() { - wait = 600 * time.Second //nolint - } - - return wait -} - -func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) { - count := 0 - scanner := bufio.NewScanner(in) - { - const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB - buff := make([]byte, logBufferInitialSize) - scanner.Buffer(buff, len(buff)) - scanner.Split(bufio.ScanLines) - } - - for scanner.Scan() { - if predicate(scanner.Text()) { - count += 1 - } - } - - return count, scanner.Err() -} - -// func dockertestCommandTimeout() time.Duration { -// timeout := 10 * time.Second //nolint -// -// if isCI() { -// timeout = 60 * time.Second //nolint -// } -// -// return timeout -// } - -// pingAllNegativeHelper is intended to have 1 or more nodes timing out from the ping, -// it counts failures instead of successes. -// func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { -// t.Helper() -// failures := 0 -// -// timeout := 100 -// count := 3 -// -// for _, client := range clients { -// for _, addr := range addrs { -// err := client.Ping( -// addr, -// tsic.WithPingTimeout(time.Duration(timeout)*time.Millisecond), -// tsic.WithPingCount(count), -// ) -// if err != nil { -// failures++ -// } -// } -// } -// -// return failures -// } - -// // findPeerByIP takes an IP and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given IP. If no peer is found, nil is returned. 
-// func findPeerByIP( -// ip netip.Addr, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// for _, peerIP := range peer.TailscaleIPs { -// if ip == peerIP { -// return peer -// } -// } -// } -// -// return nil -// } - -// Helper functions for creating typed policy entities - -// wildcard returns a wildcard alias (*). -func wildcard() policyv2.Alias { - return policyv2.Wildcard -} - -// usernamep returns a pointer to a Username as an Alias. -func usernamep(name string) policyv2.Alias { - return ptr.To(policyv2.Username(name)) -} - -// hostp returns a pointer to a Host. -func hostp(name string) policyv2.Alias { - return ptr.To(policyv2.Host(name)) -} - -// groupp returns a pointer to a Group as an Alias. -func groupp(name string) policyv2.Alias { - return ptr.To(policyv2.Group(name)) -} - -// tagp returns a pointer to a Tag as an Alias. -func tagp(name string) policyv2.Alias { - return ptr.To(policyv2.Tag(name)) -} - -// prefixp returns a pointer to a Prefix from a CIDR string. -func prefixp(cidr string) policyv2.Alias { - prefix := netip.MustParsePrefix(cidr) - return ptr.To(policyv2.Prefix(prefix)) -} - -// aliasWithPorts creates an AliasWithPorts structure from an alias and ports. -func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { - return policyv2.AliasWithPorts{ - Alias: alias, - Ports: ports, - } -} - -// usernameOwner returns a Username as an Owner for use in TagOwners. -func usernameOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Username(name)) -} - -// groupOwner returns a Group as an Owner for use in TagOwners. -func groupOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Group(name)) -} - -// usernameApprover returns a Username as an AutoApprover. -func usernameApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Username(name)) -} - -// groupApprover returns a Group as an AutoApprover. -func groupApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Group(name)) -} - -// tagApprover returns a Tag as an AutoApprover. -func tagApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Tag(name)) -} - -// -// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given hostname. If no peer is found, nil is returned. -// func findPeerByHostname( -// hostname string, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// if hostname == peer.HostName { -// return peer -// } -// } -// -// return nil -// }