From 9b47f71f370088b5119cacf5880a7914b936deeb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 9 Jul 2025 11:15:48 +0000 Subject: [PATCH] integration: replace time.Sleep with Eventually sleeping in tests is a big no no, its time to stop. Sleeping only works well on the same machine under the same conditions we rather wait for something as things take time on different machines --- .github/workflows/docs-deploy.yml | 3 +- .../workflows/integration-test-template.yml | 2 +- .github/workflows/lint.yml | 6 +- cmd/headscale/cli/debug.go | 2 +- cmd/headscale/cli/mockoidc.go | 3 +- cmd/headscale/cli/nodes.go | 27 +-- cmd/headscale/cli/users.go | 37 ++-- cmd/hi/cleanup.go | 12 +- cmd/hi/docker.go | 41 ++-- cmd/hi/tar_utils.go | 9 +- flake.nix | 1 + hscontrol/auth.go | 6 +- hscontrol/capver/capver.go | 3 +- hscontrol/capver/capver_generated.go | 25 ++- hscontrol/db/db.go | 5 +- hscontrol/derp/server/derp_server.go | 2 +- hscontrol/dns/extrarecords.go | 1 - hscontrol/grpcv1.go | 5 +- hscontrol/handlers.go | 3 +- hscontrol/mapper/mapper.go | 2 +- hscontrol/mapper/mapper_test.go | 6 +- hscontrol/mapper/tail.go | 2 +- hscontrol/metrics.go | 1 + hscontrol/notifier/notifier.go | 6 +- hscontrol/notifier/notifier_test.go | 15 +- hscontrol/oidc.go | 12 +- hscontrol/policy/matcher/matcher.go | 5 +- hscontrol/policy/pm.go | 1 - hscontrol/policy/policy.go | 3 +- hscontrol/policy/policy_test.go | 4 +- hscontrol/policy/v2/filter.go | 7 +- hscontrol/policy/v2/policy.go | 10 +- hscontrol/policy/v2/policy_test.go | 2 +- hscontrol/policy/v2/types.go | 86 +++++---- hscontrol/policy/v2/types_test.go | 22 ++- hscontrol/policy/v2/utils_test.go | 8 +- hscontrol/routes/primary.go | 1 + hscontrol/state/state.go | 4 +- hscontrol/tailsql.go | 4 +- hscontrol/templates/apple.go | 12 +- hscontrol/templates/windows.go | 4 +- hscontrol/types/common.go | 1 + hscontrol/types/config.go | 5 +- hscontrol/types/config_test.go | 1 + hscontrol/types/node.go | 18 +- hscontrol/types/node_test.go | 2 +- hscontrol/types/preauth_key.go | 2 +- hscontrol/types/preauth_key_test.go | 4 +- hscontrol/types/users.go | 6 +- hscontrol/types/version.go | 6 +- hscontrol/util/dns.go | 11 +- hscontrol/util/log.go | 2 +- hscontrol/util/net.go | 1 + hscontrol/util/util.go | 40 ++-- integration/acl_test.go | 3 - integration/auth_key_test.go | 145 +++++++++----- integration/auth_oidc_test.go | 70 ++++--- integration/auth_web_flow_test.go | 8 +- integration/cli_test.go | 133 ++++++------- integration/derp_verify_endpoint_test.go | 3 +- integration/dns_test.go | 16 +- integration/dockertestutil/config.go | 13 +- integration/dockertestutil/execute.go | 6 +- integration/dsic/dsic.go | 3 +- integration/embedded_derp_test.go | 11 +- integration/general_test.go | 177 +++++++++++------- integration/hsic/hsic.go | 48 ++--- integration/route_test.go | 27 +-- integration/scenario.go | 15 +- integration/scenario_test.go | 2 - integration/ssh_test.go | 72 ++++--- integration/tsic/tsic.go | 20 +- integration/utils.go | 6 +- 73 files changed, 675 insertions(+), 612 deletions(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 15637069..7d06b6a6 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -48,5 +48,4 @@ jobs: - name: Deploy stable docs from tag if: startsWith(github.ref, 'refs/tags/v') # This assumes that only newer tags are pushed - run: - mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest + run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest diff --git a/.github/workflows/integration-test-template.yml b/.github/workflows/integration-test-template.yml index 1c621192..939451d4 100644 --- a/.github/workflows/integration-test-template.yml +++ b/.github/workflows/integration-test-template.yml @@ -75,7 +75,7 @@ jobs: # Some of the jobs might still require manual restart as they are really # slow and this will cause them to eventually be killed by Github actions. attempt_delay: 300000 # 5 min - attempt_limit: 3 + attempt_limit: 2 command: | nix develop --command -- hi run "^${{ inputs.test }}$" \ --timeout=120m \ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 49334233..1e06f4de 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -36,8 +36,7 @@ jobs: - name: golangci-lint if: steps.changed-files.outputs.files == 'true' - run: - nix develop --command -- golangci-lint run + run: nix develop --command -- golangci-lint run --new-from-rev=${{github.event.pull_request.base.sha}} --format=colored-line-number @@ -75,8 +74,7 @@ jobs: - name: Prettify code if: steps.changed-files.outputs.files == 'true' - run: - nix develop --command -- prettier --no-error-on-unmatched-pattern + run: nix develop --command -- prettier --no-error-on-unmatched-pattern --ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} proto-lint: diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 41b46fb0..8ce5f237 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), + "Cannot create node: "+status.Convert(err).Message(), output, ) } diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 309ad67d..9969f7c6 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -2,6 +2,7 @@ package cli import ( "encoding/json" + "errors" "fmt" "net" "net/http" @@ -68,7 +69,7 @@ func mockOIDC() error { userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return fmt.Errorf("MOCKOIDC_USERS not defined") + return errors.New("MOCKOIDC_USERS not defined") } var users []mockoidc.MockUser diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 00d803b2..fb49f4a3 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + "Cannot get nodes: "+status.Convert(err).Message(), output, ) } @@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error getting node node: %s", - status.Convert(err).Message(), - ), + "Error getting node node: "+status.Convert(err).Message(), output, ) @@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error deleting node: %s", - status.Convert(err).Message(), - ), + "Error deleting node: "+status.Convert(err).Message(), output, ) @@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error getting node: %s", - status.Convert(err).Message(), - ), + "Error getting node: "+status.Convert(err).Message(), output, ) @@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error moving node: %s", - status.Convert(err).Message(), - ), + "Error moving node: "+status.Convert(err).Message(), output, ) @@ -567,10 +555,7 @@ be assigned to nodes.`, if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error backfilling IPs: %s", - status.Convert(err).Message(), - ), + "Error backfilling IPs: "+status.Convert(err).Message(), output, ) diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index b5f1bc49..c482299c 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/url" + "strconv" survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { err := errors.New("--name or --identifier flag is required") ErrorOutput( err, - fmt.Sprintf( - "Cannot rename user: %s", - status.Convert(err).Message(), - ), + "Cannot rename user: "+status.Convert(err).Message(), "", ) } @@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot create user: %s", - status.Convert(err).Message(), - ), + "Cannot create user: "+status.Convert(err).Message(), output, ) } @@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } if len(users.GetUsers()) != 1 { - err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } @@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot destroy user: %s", - status.Convert(err).Message(), - ), + "Cannot destroy user: "+status.Convert(err).Message(), output, ) } @@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), + "Cannot get users: "+status.Convert(err).Message(), output, ) } @@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{ tableData = append( tableData, []string{ - fmt.Sprintf("%d", user.GetId()), + strconv.FormatUint(user.GetId(), 10), user.GetDisplayName(), user.GetName(), user.GetEmail(), @@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } if len(users.GetUsers()) != 1 { - err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID") + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) } @@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot rename user: %s", - status.Convert(err).Message(), - ), + "Cannot rename user: "+status.Convert(err).Message(), output, ) } diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 080266d8..fd78c66f 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error { if cont.State == "running" { _ = cli.ContainerKill(ctx, cont.ID, "KILL") } - + // Then remove the container with retry logic if removeContainerWithRetry(ctx, cli, cont.ID) { removed++ @@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error { func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool { maxRetries := 3 baseDelay := 100 * time.Millisecond - - for attempt := 0; attempt < maxRetries; attempt++ { + + for attempt := range maxRetries { err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{ Force: true, }) if err == nil { return true } - + // If this is the last attempt, don't wait if attempt == maxRetries-1 { break } - + // Wait with exponential backoff delay := baseDelay * time.Duration(1< diff --git a/hscontrol/auth.go b/hscontrol/auth.go index f9de67e7..986bbabc 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode( return nil, nil } - } n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry) @@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, changed, err := h.state.HandleNodeFromPreAuthKey( regReq, machineKey, @@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - if perr, ok := err.(types.PAKError); ok { + var perr types.PAKError + if errors.As(err, &perr) { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } + return nil, err } diff --git a/hscontrol/capver/capver.go b/hscontrol/capver/capver.go index 7ad5074d..347ec981 100644 --- a/hscontrol/capver/capver.go +++ b/hscontrol/capver/capver.go @@ -1,11 +1,10 @@ package capver import ( + "slices" "sort" "strings" - "slices" - xmaps "golang.org/x/exp/maps" "tailscale.com/tailcfg" "tailscale.com/util/set" diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go index f192fad4..687e3d51 100644 --- a/hscontrol/capver/capver_generated.go +++ b/hscontrol/capver/capver_generated.go @@ -1,6 +1,6 @@ package capver -//Generated DO NOT EDIT +// Generated DO NOT EDIT import "tailscale.com/tailcfg" @@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ "v1.82.5": 115, } - var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ - 87: "v1.60.0", - 88: "v1.62.0", - 90: "v1.64.0", - 95: "v1.66.0", - 97: "v1.68.0", - 102: "v1.70.0", - 104: "v1.72.0", - 106: "v1.74.0", - 109: "v1.78.0", - 113: "v1.80.0", - 115: "v1.82.0", + 87: "v1.60.0", + 88: "v1.62.0", + 90: "v1.64.0", + 95: "v1.66.0", + 97: "v1.68.0", + 102: "v1.70.0", + 104: "v1.72.0", + 106: "v1.74.0", + 109: "v1.78.0", + 113: "v1.80.0", + 115: "v1.82.0", } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 56d7860b..abda802c 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -764,13 +764,13 @@ AND auth_key_id NOT IN ( // Drop all indexes first to avoid conflicts indexesToDrop := []string{ "idx_users_deleted_at", - "idx_provider_identifier", + "idx_provider_identifier", "idx_name_provider_identifier", "idx_name_no_provider_identifier", "idx_api_keys_prefix", "idx_policies_deleted_at", } - + for _, index := range indexesToDrop { _ = tx.Exec("DROP INDEX IF EXISTS " + index).Error } @@ -927,6 +927,7 @@ AND auth_key_id NOT IN ( } log.Info().Msg("Schema recreation completed successfully") + return nil }, Rollback: func(db *gorm.DB) error { return nil }, diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index ae7bf03e..fee395f1 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { Avoid: false, Nodes: []*tailcfg.DERPNode{ { - Name: fmt.Sprintf("%d", d.cfg.ServerRegionID), + Name: strconv.Itoa(d.cfg.ServerRegionID), RegionID: d.cfg.ServerRegionID, HostName: host, DERPPort: port, diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 6ea3aa35..82b3078b 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() { return struct{}{}, nil }, backoff.WithBackOff(backoff.NewExponentialBackOff())) - if err != nil { log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete") continue diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index e098b766..7df4c92e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode( api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) } - ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) + ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname) + api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID) + + ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname) api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) log.Trace(). diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index f32aea96..590541b0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -32,7 +32,7 @@ const ( reservedResponseHeaderSize = 4 ) -// httpError logs an error and sends an HTTP error response with the given +// httpError logs an error and sends an HTTP error response with the given. func httpError(w http.ResponseWriter, err error) { var herr HTTPError if errors.As(err, &herr) { @@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest( resp := &tailcfg.DERPAdmitClientResponse{ Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), } + return json.NewEncoder(writer).Encode(resp) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 49a99351..553658f5 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types. } // ListNodes queries the database for either all nodes if no parameters are given -// or for the given nodes if at least one node ID is given as parameter +// or for the given nodes if at least one node ID is given as parameter. func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { nodes, err := m.state.ListNodes(nodeIDs...) if err != nil { diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 71b9e4b9..b5747c2b 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) { } } -// mockState is a mock implementation that provides the required methods +// mockState is a mock implementation that provides the required methods. type mockState struct { polMan policy.PolicyManager derpMap *tailcfg.DERPMap @@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ } } } + return filtered, nil } // Return all peers except the node itself @@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ filtered = append(filtered, peer) } } + return filtered, nil } @@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { } } } + return filtered, nil } + return m.nodes, nil } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9b58ad34..9729301d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -11,7 +11,7 @@ import ( "tailscale.com/types/views" ) -// NodeCanHaveTagChecker is an interface for checking if a node can have a tag +// NodeCanHaveTagChecker is an interface for checking if a node can have a tag. type NodeCanHaveTagChecker interface { NodeCanHaveTag(node types.NodeView, tag string) bool } diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index cb01838c..ef427afb 100644 --- a/hscontrol/metrics.go +++ b/hscontrol/metrics.go @@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) { } n, err := r.ResponseWriter.Write(b) r.written += int64(n) + return n, err } diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 2e6b9b0b..6bd990c7 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier { n.b = b go b.doWork() + return n } @@ -72,7 +73,7 @@ func (n *Notifier) Close() { n.nodes = make(map[types.NodeID]chan<- types.StateUpdate) } -// safeCloseChannel closes a channel and panic recovers if already closed +// safeCloseChannel closes a channel and panic recovers if already closed. func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) { defer func() { if r := recover(); r != nil { @@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool { if val, ok := n.connected.Load(nodeID); ok { return val } + return false } @@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { return false } -// LikelyConnectedMap returns a thread safe map of connected nodes +// LikelyConnectedMap returns a thread safe map of connected nodes. func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { return n.connected } diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go index 9654cfc8..c3e96a8d 100644 --- a/hscontrol/notifier/notifier_test.go +++ b/hscontrol/notifier/notifier_test.go @@ -1,17 +1,15 @@ package notifier import ( - "context" "fmt" "math/rand" "net/netip" + "slices" "sort" "sync" "testing" "time" - "slices" - "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) { defer n.RemoveNode(1, ch) for _, u := range tt.updates { - n.NotifyAll(context.Background(), u) + n.NotifyAll(t.Context(), u) } n.b.flush() @@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) { // TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected // Multiple goroutines calling AddNode and RemoveNode cause panics when trying to // close a channel that was already closed, which can happen when a node changes -// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting +// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting. func TestIsLikelyConnectedRaceCondition(t *testing.T) { // mock config for the notifier cfg := &types.Config{ @@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) { for range iterations { // Simulate race by having some goroutines check IsLikelyConnected // while others add/remove the node - if routineID%3 == 0 { + switch routineID % 3 { + case 0: // This goroutine checks connection status isConnected := notifier.IsLikelyConnected(nodeID) if isConnected != true && isConnected != false { errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected) } - } else if routineID%3 == 1 { + case 1: // This goroutine removes the node notifier.RemoveNode(nodeID, updateChan) - } else { + default: // This goroutine adds the node back notifier.AddNode(nodeID, updateChan) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 1f08adf8..5f1935e5 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -84,11 +84,8 @@ func NewAuthProviderOIDC( ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, Endpoint: oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf( - "%s/oidc/callback", - strings.TrimSuffix(serverURL, "/"), - ), - Scopes: cfg.Scope, + RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback", + Scopes: cfg.Scope, } registrationCache := zcache.New[string, RegistrationInfo]( @@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler( req *http.Request, ) { vars := mux.Vars(req) - registrationIdStr, _ := vars["registration_id"] + registrationIdStr := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render @@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } oauth2Token, err := a.getOauth2Token(req.Context(), code, state) - if err != nil { httpError(writer, err) return @@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Neither node nor machine key was found in the state cache meaning // that we could not reauth nor register the node. httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + return } @@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if err != nil { return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) } + return oauth2Token, err } diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index d246d5e2..aac5a5f3 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -2,9 +2,8 @@ package matcher import ( "net/netip" - "strings" - "slices" + "strings" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" @@ -28,6 +27,7 @@ func (m Match) DebugString() string { for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + return sb.String() } @@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { for _, rule := range rules { matches = append(matches, MatchFromFilterRule(rule)) } + return matches } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index cfeb65a1..3a59b25f 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -4,7 +4,6 @@ import ( "net/netip" "github.com/juanfont/headscale/hscontrol/policy/matcher" - policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 4efd1e01..5a9103e5 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -5,7 +5,6 @@ import ( "slices" "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/samber/lo" @@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf // AutoApproveRoutes approves any route that can be autoapproved from // the nodes perspective according to the given policy. // It reports true if any routes were approved. -// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes +// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes. func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { if pm == nil { return false diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9f2f7573..f19ac3d3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -7,9 +7,8 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) { } } } + func TestReduceRoutes(t *testing.T) { type args struct { node *types.Node diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 1825926f..9d838e56 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -13,9 +13,7 @@ import ( "tailscale.com/types/views" ) -var ( - ErrInvalidAction = errors.New("invalid action") -) +var ErrInvalidAction = errors.New("invalid action") // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. @@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules( var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { - ips, err := dest.Alias.Resolve(pol, users, nodes) + ips, err := dest.Resolve(pol, users, nodes) if err != nil { log.Trace().Err(err).Msgf("resolving destination ips") } @@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string { for _, pref := range ips.Prefixes() { out = append(out, pref.String()) } + return out } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index cbc34215..2f4be34e 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -4,19 +4,17 @@ import ( "encoding/json" "fmt" "net/netip" + "slices" "strings" "sync" "github.com/juanfont/headscale/hscontrol/policy/matcher" - - "slices" - "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/util/deephash" "tailscale.com/types/views" + "tailscale.com/util/deephash" ) type PolicyManager struct { @@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { pm.mu.Lock() defer pm.mu.Unlock() + return pm.filter, pm.matchers } @@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() pm.users = users + return pm.updateLocked() } @@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes + return pm.updateLocked() } @@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // cannot just lookup in the prefix map and have to check // if there is a "parent" prefix available. for prefix, approveAddrs := range pm.autoApproveMap { - // Check if prefix is larger (so containing) and then overlaps // the route to see if the node can approve a subset of an autoapprover if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index b3540e63..a91831ad 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -1,10 +1,10 @@ package v2 import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "testing" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/require" "gorm.io/gorm" diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 550287c2..c38d1991 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -6,9 +6,9 @@ import ( "errors" "fmt" "net/netip" - "strings" - "slices" + "strconv" + "strings" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Check if it's the wildcard port range if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 { - return json.Marshal(fmt.Sprintf("%s:*", alias)) + return json.Marshal(alias + ":*") } // Otherwise, format as "alias:ports" var ports []string for _, port := range a.Ports { if port.First == port.Last { - ports = append(ports, fmt.Sprintf("%d", port.First)) + ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) } else { ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last)) } @@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error { if err := u.Validate(); err != nil { return err } + return nil } @@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. return buildIPSetMultiErr(&ips, errs) } -// Group is a special string which is always prefixed with `group:` +// Group is a special string which is always prefixed with `group:`. type Group string func (g Group) Validate() error { @@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error { if err := g.Validate(); err != nil { return err } + return nil } @@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod return buildIPSetMultiErr(&ips, errs) } -// Tag is a special string which is always prefixed with `tag:` +// Tag is a special string which is always prefixed with `tag:`. type Tag string func (t Tag) Validate() error { @@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error { if err := t.Validate(); err != nil { return err } + return nil } @@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error { if err := h.Validate(); err != nil { return err } + return nil } @@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error { } *p = Prefix(addrPref) + return nil } @@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error { return err } *p = Prefix(pref) + return nil } @@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err := p.Validate(); err != nil { return err } + return nil } @@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild } } -// AutoGroup is a special string which is always prefixed with `autogroup:` +// AutoGroup is a special string which is always prefixed with `autogroup:`. type AutoGroup string const ( @@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { if err := ag.Validate(); err != nil { return err } + return nil } @@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } - if err := ve.Alias.Validate(); err != nil { + if err := ve.Validate(); err != nil { return err } default: return fmt.Errorf("type %T not supported", vs) } + return nil } @@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { return err } ve.Alias = ptr + return nil } @@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { for i, alias := range aliases { (*a)[i] = alias.Alias } + return nil } @@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I return ips, multierr.New(append(errs, err)...) } -// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer +// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer. func unmarshalPointer[T any]( b []byte, parseFunc func(string) (T, error), @@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { for i, autoApprover := range autoApprovers { (*aa)[i] = autoApprover.AutoApprover } + return nil } @@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { return err } ve.AutoApprover = ptr + return nil } @@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { return err } ve.Owner = ptr + return nil } @@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error { for i, owner := range owners { (*o)[i] = owner.Owner } + return nil } @@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) { case isGroup(s): return ptr.To(Group(s)), nil } + return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: - user (containing an "@") - group (starting with "group:") @@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { (*g)[group] = usernames } + return nil } @@ -1252,7 +1269,7 @@ type Policy struct { // We use the default JSON marshalling behavior provided by the Go runtime. var ( - // TODO(kradalby): Add these checks for tagOwners and autoApprovers + // TODO(kradalby): Add these checks for tagOwners and autoApprovers. autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged} autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} @@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSrc, *src) { @@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSSHSrc, *src) { @@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) } if !slices.Contains(autogroupForSSHDst, *dst) { @@ -1360,14 +1377,14 @@ func (p *Policy) validate() error { for _, acl := range p.ACLs { for _, src := range acl.Sources { - switch src.(type) { + switch src := src.(type) { case *Host: - h := src.(*Host) + h := src if !p.Hosts.exist(*h) { errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) } case *AutoGroup: - ag := src.(*AutoGroup) + ag := src if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) @@ -1379,12 +1396,12 @@ func (p *Policy) validate() error { continue } case *Group: - g := src.(*Group) + g := src if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := src.(*Tag) + tagOwner := src if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1440,9 +1457,9 @@ func (p *Policy) validate() error { } for _, src := range ssh.Sources { - switch src.(type) { + switch src := src.(type) { case *AutoGroup: - ag := src.(*AutoGroup) + ag := src if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) @@ -1454,21 +1471,21 @@ func (p *Policy) validate() error { continue } case *Group: - g := src.(*Group) + g := src if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := src.(*Tag) + tagOwner := src if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } } } for _, dst := range ssh.Destinations { - switch dst.(type) { + switch dst := dst.(type) { case *AutoGroup: - ag := dst.(*AutoGroup) + ag := dst if err := validateAutogroupSupported(ag); err != nil { errs = append(errs, err) continue @@ -1479,7 +1496,7 @@ func (p *Policy) validate() error { continue } case *Tag: - tagOwner := dst.(*Tag) + tagOwner := dst if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1489,9 +1506,9 @@ func (p *Policy) validate() error { for _, tagOwners := range p.TagOwners { for _, tagOwner := range tagOwners { - switch tagOwner.(type) { + switch tagOwner := tagOwner.(type) { case *Group: - g := tagOwner.(*Group) + g := tagOwner if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } @@ -1501,14 +1518,14 @@ func (p *Policy) validate() error { for _, approvers := range p.AutoApprovers.Routes { for _, approver := range approvers { - switch approver.(type) { + switch approver := approver.(type) { case *Group: - g := approver.(*Group) + g := approver if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := approver.(*Tag) + tagOwner := approver if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1517,14 +1534,14 @@ func (p *Policy) validate() error { } for _, approver := range p.AutoApprovers.ExitNode { - switch approver.(type) { + switch approver := approver.(type) { case *Group: - g := approver.(*Group) + g := approver if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } case *Tag: - tagOwner := approver.(*Tag) + tagOwner := approver if err := p.TagOwners.Contains(tagOwner); err != nil { errs = append(errs, err) } @@ -1536,6 +1553,7 @@ func (p *Policy) validate() error { } p.validated = true + return nil } @@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { ) } } + return nil } @@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { ) } } + return nil } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 8cddfeba..4aca150e 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -5,13 +5,13 @@ import ( "net/netip" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" @@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) { // Marshal the policy to JSON marshalled, err := json.MarshalIndent(policy, "", " ") require.NoError(t, err) - + // Make sure all expected fields are present in the JSON jsonString := string(marshalled) assert.Contains(t, jsonString, "group:example") @@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) { assert.Contains(t, jsonString, "accept") assert.Contains(t, jsonString, "tcp") assert.Contains(t, jsonString, "80") - + // Unmarshal back to verify round trip var roundTripped Policy err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) - + // Compare the original and round-tripped policies - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), cmpopts.EquateEmpty(), ) - + if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) } @@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) { }, } - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), ) - + // For round-trip testing, we'll normalize the policies before comparing for _, tt := range tests { @@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) { } else if !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr) } + return // Skip the rest of the test if we expected an error } @@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) { if err != nil { t.Fatalf("round-trip unmarshalling: %v", err) } - + // Add EquateEmpty to handle nil vs empty maps/slices - roundTripCmps := append(cmps, + roundTripCmps := append(cmps, cmpopts.EquateEmpty(), cmpopts.IgnoreUnexported(Policy{}), ) @@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { builder.AddPrefix(mp(p)) } ipSet, _ := builder.IPSet() + return ipSet } diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index d1645071..2084b22f 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) { expected []tailcfg.PortRange err string }{ - {"80", []tailcfg.PortRange{{80, 80}}, ""}, - {"80-90", []tailcfg.PortRange{{80, 90}}, ""}, - {"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""}, - {"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""}, + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, {"80-", nil, "invalid port range format"}, {"-90", nil, "invalid port range format"}, diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 67eb8d1f..f65d9122 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } tsaddr.SortPrefixes(routes) + return routes } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 0d8a2a8e..b754e594 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) { if err != nil { return types.NodeView{}, err } + return node.View(), nil } @@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er if err != nil { return types.NodeView{}, err } + return node.View(), nil } @@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey( if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) { nodeToRegister.Expiry = ®Req.Expiry } else if !regReq.Expiry.IsZero() { - // If client is sending an expired time (e.g., after logout), + // If client is sending an expired time (e.g., after logout), // don't set expiry so the node won't be considered expired log.Debug(). Time("requested_expiry", regReq.Expiry). diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 82e82d78..1a949173 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/http" "os" @@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return fmt.Errorf("no cert domains available for HTTPS") + return errors.New("no cert domains available for HTTPS") } base := "https://" + certDomains[0] go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") + return tsNode.Close() } diff --git a/hscontrol/templates/apple.go b/hscontrol/templates/apple.go index 99b1cc8e..84928ed5 100644 --- a/hscontrol/templates/apple.go +++ b/hscontrol/templates/apple.go @@ -62,7 +62,7 @@ func Apple(url string) *elem.Element { ), elem.Pre(nil, elem.Code(nil, - elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)), + elem.Text("tailscale login --login-server "+url), ), ), headerTwo("GUI"), @@ -143,10 +143,7 @@ func Apple(url string) *elem.Element { elem.Code( nil, elem.Text( - fmt.Sprintf( - `defaults write io.tailscale.ipn.macos ControlURL %s`, - url, - ), + "defaults write io.tailscale.ipn.macos ControlURL "+url, ), ), ), @@ -155,10 +152,7 @@ func Apple(url string) *elem.Element { elem.Code( nil, elem.Text( - fmt.Sprintf( - `defaults write io.tailscale.ipn.macsys ControlURL %s`, - url, - ), + "defaults write io.tailscale.ipn.macsys ControlURL "+url, ), ), ), diff --git a/hscontrol/templates/windows.go b/hscontrol/templates/windows.go index 680d6655..ecf7d77c 100644 --- a/hscontrol/templates/windows.go +++ b/hscontrol/templates/windows.go @@ -1,8 +1,6 @@ package templates import ( - "fmt" - "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" ) @@ -31,7 +29,7 @@ func Windows(url string) *elem.Element { ), elem.Pre(nil, elem.Code(nil, - elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)), + elem.Text("tailscale login --login-server "+url), ), ), ), diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 69c298b9..51e11757 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID { if err != nil { panic(err) } + return rid } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 03c1e7ea..1e35303e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error { log.Warn().Msg("No config file found, using defaults") return nil } + return fmt.Errorf("fatal error reading config file: %w", err) } @@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") } allocStr := viper.GetString("prefixes.allocation") @@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { s := len(serverDomainParts) b := len(baseDomainParts) - for i := range len(baseDomainParts) { + for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { return nil } diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 7ae3db59..6b9fc2ef 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) { assert.Equal(t, "trace", viper.GetString("log.level")) assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4")) assert.False(t, viper.GetBool("database.sqlite.write_ahead_log")) + return nil, nil }, want: nil, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 11383950..32f0274c 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -28,8 +28,10 @@ var ( ErrNodeUserHasNoName = errors.New("node user has no name") ) -type NodeID uint64 -type NodeIDs []NodeID +type ( + NodeID uint64 + NodeIDs []NodeID +) func (n NodeIDs) Len() int { return len(n) } func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } @@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool { return true } } + return false } @@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool { // and therefore should not be treated as a // user owned device. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (node *Node) IsTagged() bool { if len(node.ForcedTags) > 0 { return true @@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool { // HasTag reports if a node has a given tag. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (node *Node) HasTag(tag string) bool { return slices.Contains(node.Tags(), tag) } @@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string { sb.WriteString(node.DebugString()) sb.WriteString("\n") } + return sb.String() } @@ -590,6 +594,7 @@ func (node Node) DebugString() string { fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) sb.WriteString("\n") + return sb.String() } @@ -689,7 +694,7 @@ func (v NodeView) Tags() []string { // and therefore should not be treated as a // user owned device. // Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) +// via CLI ("forced tags" and preauthkeys). func (v NodeView) IsTagged() bool { if !v.Valid() { return false @@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC // GetFQDN returns the fully qualified domain name for the node. func (v NodeView) GetFQDN(baseDomain string) (string, error) { if !v.Valid() { - return "", fmt.Errorf("failed to create valid FQDN: node view is invalid") + return "", errors.New("failed to create valid FQDN: node view is invalid") } return v.ж.GetFQDN(baseDomain) } @@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string { } return v.ж.IPsAsString() } - diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index c7261587..f6d1d027 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -2,7 +2,6 @@ package types import ( "fmt" - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "testing" @@ -10,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/types/key" diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 51c474eb..e47666ff 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -11,7 +11,7 @@ import ( type PAKError string func (e PAKError) Error() string { return string(e) } -func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) } +func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) } // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go index 3f7eb269..4ab1c717 100644 --- a/hscontrol/types/preauth_key_test.go +++ b/hscontrol/types/preauth_key_test.go @@ -1,6 +1,7 @@ package types import ( + "errors" "testing" "time" @@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) { if err == nil { t.Errorf("expected error but got none") } else { - httpErr, ok := err.(PAKError) + var httpErr PAKError + ok := errors.As(err, &httpErr) if !ok { t.Errorf("expected HTTPError but got %T", err) } else { diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 6cd2c41a..69377b95 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string { // - Remove empty path segments // - For non-URL identifiers, it joins non-empty segments with a single slash // - Returns empty string for identifiers with only slashes -// - Normalize URL schemes to lowercase +// - Normalize URL schemes to lowercase. func CleanIdentifier(identifier string) string { if identifier == "" { return identifier @@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, part) } } - + if len(cleanParts) == 0 { u.Path = "" } else { @@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string { } // Ensure scheme is lowercase u.Scheme = strings.ToLower(u.Scheme) + return u.String() } @@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string { if len(cleanParts) == 0 { return "" } + return strings.Join(cleanParts, "/") } diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go index e84087fb..7fe23250 100644 --- a/hscontrol/types/version.go +++ b/hscontrol/types/version.go @@ -1,4 +1,6 @@ package types -var Version = "dev" -var GitCommitHash = "dev" +var ( + Version = "dev" + GitCommitHash = "dev" +) diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index 3a08fc3a..65194720 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "regexp" + "strconv" "strings" "unicode" @@ -21,8 +22,10 @@ const ( LabelHostnameLength = 63 ) -var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") -var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") +var ( + invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") + invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") +) var ErrInvalidUserName = errors.New("invalid user name") @@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.) rdnsSlice := []string{} for i := lastOctet - 1; i >= 0; i-- { - rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i])) + rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10)) } rdnsSlice = append(rdnsSlice, "in-addr.arpa.") rdnsBase := strings.Join(rdnsSlice, ".") @@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) { prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".") - return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix)) + return dnsname.ToFQDN(prefix + ".ip6.arpa") } var fqdns []dnsname.FQDN diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go index 12f646b1..936b374c 100644 --- a/hscontrol/util/log.go +++ b/hscontrol/util/log.go @@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq "rowsAffected": rowsAffected, } - if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) { + if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) { l.Logger.Error().Err(err).Fields(fields).Msgf("") return } diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go index 0d6b4412..e28bb00b 100644 --- a/hscontrol/util/net.go +++ b/hscontrol/util/net.go @@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet { internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) theInternetSet, _ := internetBuilder.IPSet() + return theInternetSet }) diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 4f6660be..a44a6e97 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { } type TraceroutePath struct { - // Hop is the current jump in the total traceroute. - Hop int + // Hop is the current jump in the total traceroute. + Hop int - // Hostname is the resolved hostname or IP address identifying the jump - Hostname string + // Hostname is the resolved hostname or IP address identifying the jump + Hostname string - // IP is the IP address of the jump - IP netip.Addr + // IP is the IP address of the jump + IP netip.Addr - // Latencies is a list of the latencies for this jump - Latencies []time.Duration + // Latencies is a list of the latencies for this jump + Latencies []time.Duration } type Traceroute struct { - // Hostname is the resolved hostname or IP address identifying the target - Hostname string + // Hostname is the resolved hostname or IP address identifying the target + Hostname string - // IP is the IP address of the target - IP netip.Addr + // IP is the IP address of the target + IP netip.Addr - // Route is the path taken to reach the target if successful. The list is ordered by the path taken. - Route []TraceroutePath + // Route is the path taken to reach the target if successful. The list is ordered by the path taken. + Route []TraceroutePath - // Success indicates if the traceroute was successful. - Success bool + // Success indicates if the traceroute was successful. + Success bool - // Err contains an error if the traceroute was not successful. - Err error + // Err contains an error if the traceroute was not successful. + Err error } -// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct +// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct. func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { @@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Parse each hop line - hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) + hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") for i := 1; i < len(lines); i++ { matches := hopRegex.FindStringSubmatch(lines[i]) diff --git a/integration/acl_test.go b/integration/acl_test.go index 193b6669..3aef521e 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { func TestACLAutogroupMember(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := aclScenario(t, &policyv2.Policy{ @@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) { func TestACLAutogroupTagged(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := aclScenario(t, &policyv2.Policy{ diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index d54ff593..061c2595 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -3,12 +3,11 @@ package integration import ( "fmt" "net/netip" + "slices" "strconv" "testing" "time" - "slices" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -19,7 +18,6 @@ import ( func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { @@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -87,7 +85,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("all clients logged out") listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) for _, node := range listNodes { assertLastSeenSet(t, node) @@ -99,26 +97,48 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - time.Sleep(5 * time.Minute) - } + userMap, err := headscale.MapUsers() + assertNoErr(t, err) - userMap, err := headscale.MapUsers() - assertNoErr(t, err) - - for _, userName := range spec.Users { - key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) - if err != nil { - t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + // Create auth keys once outside the retry loop + userKeys := make(map[string]string) + for _, userName := range spec.Users { + key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) + assertNoErr(t, err) + userKeys[userName] = key.GetKey() } - err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) - if err != nil { - t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + // Wait for the 2-minute noise dial memory to expire + // The Tailscale commit shows clients remember noise dials for 2 minutes + t.Logf("Waiting 2.5 minutes for Tailscale noise dial memory to expire...") + time.Sleep(2*time.Minute + 30*time.Second) + + // Wait for clients to be ready to reconnect over HTTP after HTTPS + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, userName := range spec.Users { + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userKeys[userName]) + assert.NoError(ct, err, "Client should be able to reconnect over HTTP") + } + }, 6*time.Minute, 30*time.Second) + } else { + userMap, err := headscale.MapUsers() + assertNoErr(t, err) + + for _, userName := range spec.Users { + key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } } } listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) for _, node := range listNodes { assertLastSeenSet(t, node) @@ -155,18 +175,17 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) for _, node := range listNodes { assertLastSeenSet(t, node) } }) } - } func assertLastSeenSet(t *testing.T, node *v1.Node) { assert.NotNil(t, node) - assert.NotNil(t, node.LastSeen) + assert.NotNil(t, node.GetLastSeen()) } // This test will first log in two sets of nodes to two sets of users, then @@ -175,7 +194,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) { // still has nodes, but they are not connected. func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -204,7 +222,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -259,7 +277,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { IntegrationSkip(t) - t.Parallel() for _, https := range []bool{true, false} { t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { @@ -303,7 +320,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -325,32 +342,62 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - time.Sleep(5 * time.Minute) - } - - userMap, err := headscale.MapUsers() - assertNoErr(t, err) - - for _, userName := range spec.Users { - key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) - if err != nil { - t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) - } - - // Expire the key so it can't be used - _, err = headscale.Execute( - []string{ - "headscale", - "preauthkeys", - "--user", - strconv.FormatUint(userMap[userName].GetId(), 10), - "expire", - key.Key, - }) + userMap, err := headscale.MapUsers() assertNoErr(t, err) - err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) - assert.ErrorContains(t, err, "authkey expired") + // Create and expire auth keys once outside the retry loop + userExpiredKeys := make(map[string]string) + for _, userName := range spec.Users { + key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) + assertNoErr(t, err) + + // Expire the key so it can't be used + _, err = headscale.Execute( + []string{ + "headscale", + "preauthkeys", + "--user", + strconv.FormatUint(userMap[userName].GetId(), 10), + "expire", + key.GetKey(), + }) + assertNoErr(t, err) + userExpiredKeys[userName] = key.GetKey() + } + + // Wait for clients to be ready to reconnect over HTTP after HTTPS + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, userName := range spec.Users { + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userExpiredKeys[userName]) + assert.Error(ct, err, "Should get error when using expired key") + assert.Contains(ct, err.Error(), "authkey expired") + } + }, 6*time.Minute, 30*time.Second) + } else { + userMap, err := headscale.MapUsers() + assertNoErr(t, err) + + for _, userName := range spec.Users { + key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + // Expire the key so it can't be used + _, err = headscale.Execute( + []string{ + "headscale", + "preauthkeys", + "--user", + strconv.FormatUint(userMap[userName].GetId(), 10), + "expire", + key.GetKey(), + }) + assertNoErr(t, err) + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + assert.ErrorContains(t, err, "authkey expired") + } } }) } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 53c74577..d118b643 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,14 +1,12 @@ package integration import ( - "fmt" + "maps" "net/netip" "sort" "testing" "time" - "maps" - "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -21,7 +19,6 @@ import ( func TestOIDCAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins @@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { // This test is really flaky. func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { IntegrationSkip(t) - t.Parallel() shortAccessTTL := 5 * time.Minute @@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { // of safety reasons) before checking if the clients have logged out. // The Wait function can't do it itself as it has an upper bound of 1 // min. - time.Sleep(shortAccessTTL + 10*time.Second) - - assertTailscaleNodesLogout(t, allClients) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + } + }, shortAccessTTL+10*time.Second, 5*time.Second) } func TestOIDC024UserCreation(t *testing.T) { @@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) { spec := ScenarioSpec{ NodesPerUser: 1, } - for _, user := range tt.cliUsers { - spec.Users = append(spec.Users, user) - } + spec.Users = append(spec.Users, tt.cliUsers...) for _, user := range tt.oidcUsers { spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) @@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) { func TestOIDCAuthenticationWithPKCE(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Single user with one node for testing PKCE flow spec := ScenarioSpec{ @@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Create no nodes and no users scenario, err := NewScenario(ScenarioSpec{ @@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assertNoErr(t, err) - assert.Len(t, listUsers, 0) + assert.Empty(t, listUsers) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) assertNoErr(t, err) @@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { err = ts.Logout() assertNoErr(t, err) - time.Sleep(5 * time.Second) + // Wait for logout to complete and then do second logout + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + }, 5*time.Second, 1*time.Second) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it @@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Machine key is the same as the "machine" has not changed, // but Node key is not as it is a new node - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) - assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) - assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) // Log out user2, and log into user1, no new node should be created, // the node should now "become" node1 again err = ts.Logout() assertNoErr(t, err) - time.Sleep(5 * time.Second) + // Wait for logout to complete and then do second logout + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + }, 5*time.Second, 1*time.Second) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it @@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Validate that the machine we had when we logged in the first time, has the same // machine key, but a different ID than the newly logged in version of the same // machine. - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) - assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey) - assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id) - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) - assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id) - assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) + assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) + assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) // Even tho we are logging in again with the same user, the previous key has been expired // and a new one has been generated. The node entry in the database should be the same // as the user + machinekey still matches. - assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey) - assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey) - assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id) + assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) + assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) + assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) // The "logged back in" machine should have the same machinekey but a different nodekey // than the version logged in with a different user. - assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey) - assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) + assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) + assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) } func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { @@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { return mockoidc.MockUser{ Subject: username, PreferredUsername: username, - Email: fmt.Sprintf("%s@headscale.net", username), + Email: username + "@headscale.net", EmailVerified: emailVerified, } } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 64cace7b..83413e0d 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -2,9 +2,8 @@ package integration import ( "net/netip" - "testing" - "slices" + "testing" "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" @@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { assertNoErrGetHeadscale(t, err) listNodes, err := headscale.ListNodes() - assert.Equal(t, len(listNodes), len(allClients)) + assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) listNodes, err = headscale.ListNodes() - require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + require.Len(t, listNodes, nodeCountBeforeLogout) t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) for _, client := range allClients { diff --git a/integration/cli_test.go b/integration/cli_test.go index 2cff0500..fd9c49a7 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -18,8 +18,8 @@ import ( "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/tailcfg" "golang.org/x/exp/slices" + "tailscale.com/tailcfg" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul err = json.Unmarshal([]byte(str), result) if err != nil { - return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str) + return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str) } return nil @@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int { func TestUserCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1", "user2"}, @@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) { "--identifier=1", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterIDDelete []*v1.User @@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) { "--name=newname", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User @@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) { ) assertNoErr(t, err) - require.Len(t, listAfterNameDelete, 0) + require.Empty(t, listAfterNameDelete) } func TestPreAuthKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "preauthkeyspace" count := 3 @@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) + assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags()) } // Test key expiry @@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) { func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-without-exp-user" spec := ScenarioSpec{ @@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-reus-ephm-user" spec := ScenarioSpec{ @@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() user1 := "user1" user2 := "user2" @@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { assertNoErr(t, err) listNodes, err := headscale.ListNodes() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, listNodes, 1) assert.Equal(t, user1, listNodes[0].GetUser().GetName()) @@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { } listNodes, err = headscale.ListNodes() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, listNodes, 2) assert.Equal(t, user1, listNodes[0].GetUser().GetName()) assert.Equal(t, user2, listNodes[1].GetUser().GetName()) @@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { func TestApiKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() count := 5 @@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeys, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listedAPIKeys, 5) @@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.Nil(t, err) + assert.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAfterExpireAPIKeys, ) - assert.Nil(t, err) + assert.NoError(t, err) for index := range listedAfterExpireAPIKeys { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { @@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - assert.Nil(t, err) + assert.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey err = executeAndUnmarshal(headscale, @@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeysAfterDelete, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listedAPIKeysAfterDelete, 4) } func TestNodeTagCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1"}, @@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) @@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) { }, &resultMachines, ) - assert.Nil(t, err) + assert.NoError(t, err) found := false for _, node := range resultMachines { if node.GetForcedTags() != nil { @@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) { } } } - assert.Equal( + assert.True( t, - true, found, "should find a node with the tag 'tag:test' in the list of nodes", ) } - - func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() tests := []struct { name string @@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { }, &resultMachines, ) - assert.Nil(t, err) + assert.NoError(t, err) found := false for _, node := range resultMachines { if tags := node.GetValidTags(); tags != nil { @@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { func TestNodeCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-user", "other-user"}, @@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) { types.MustRegistrationID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range otherUserRegIDs { _, err := headscale.Execute( @@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) otherUserMachines[index] = &node } @@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) { }, &listAllWithotherUser, ) - assert.Nil(t, err) + assert.NoError(t, err) // All nodes, nodes + otherUser assert.Len(t, listAllWithotherUser, 7) @@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyotherUserMachineUser, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listOnlyotherUserMachineUser, 2) @@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.Nil(t, err) + assert.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyMachineUserAfterDelete, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listOnlyMachineUserAfterDelete, 4) } func TestNodeExpireCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-expire-user"}, @@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) nodes[index] = &node } @@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) { "nodes", "expire", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.Nil(t, err) + assert.NoError(t, err) } var listAllAfterExpiry []v1.Node @@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAllAfterExpiry, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterExpiry, 5) @@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) { func TestNodeRenameCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"node-rename-command"}, @@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) { types.MustRegistrationID().String(), } nodes := make([]*v1.Node, len(regIDs)) - assert.Nil(t, err) + assert.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAll, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAll, 5) @@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) { "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Contains(t, res, "Node renamed") } @@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRename, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterRename, 5) @@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) { "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[4].GetId()), + strconv.FormatUint(listAll[4].GetId(), 10), strings.Repeat("t", 64), }, ) @@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRenameAttempt, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, listAllAfterRenameAttempt, 5) @@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) { func TestNodeMoveCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"old-user", "new-user"}, @@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, "nomad-node", node.GetName()) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) - nodeID := fmt.Sprintf("%d", node.GetId()) + nodeID := strconv.FormatUint(node.GetId(), 10) err = executeAndUnmarshal( headscale, @@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", node.GetUser().GetName()) var allNodes []v1.Node err = executeAndUnmarshal( @@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &allNodes, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, allNodes, 1) assert.Equal(t, allNodes[0].GetId(), node.GetId()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) - assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", allNodes[0].GetUser().GetName()) _, err = headscale.Execute( []string{ @@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) { err, "user not found", ) - assert.Equal(t, node.GetUser().GetName(), "new-user") + assert.Equal(t, "new-user", node.GetUser().GetName()) err = executeAndUnmarshal( headscale, @@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) err = executeAndUnmarshal( headscale, @@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Equal(t, "old-user", node.GetUser().GetName()) } func TestPolicyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ Users: []string{"user1"}, @@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) { func TestPolicyBrokenConfigCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 23879d56..4a5e52ae 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -1,7 +1,6 @@ package integration import ( - "context" "fmt" "net" "strconv" @@ -104,7 +103,7 @@ func DERPVerify( defer c.Close() var result error - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(t.Context()); err != nil { result = fmt.Errorf("client Connect: %w", err) } if m, err := c.Recv(); err != nil { diff --git a/integration/dns_test.go b/integration/dns_test.go index ef6c479b..456895cc 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -15,7 +15,6 @@ import ( func TestResolveMagicDNS(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) { // It is safe to ignore this error as we handled it when caching it peerFQDN, _ := peer.FQDN() - assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN) + assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN) command := []string{ "tailscale", @@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) { func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { _, err = hs.Execute([]string{"rm", erPath}) assertNoErr(t, err) - time.Sleep(2 * time.Second) - // The same paths should still be available as it is not cleared on delete. - for _, client := range allClients { - assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"}) + assert.NoError(ct, err) + assert.Contains(ct, result, "9.9.9.9") + } + }, 10*time.Second, 1*time.Second) // Write a new file, the backoff mechanism should make the filewatcher pick it up // again. diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index f8bbde5f..dc8391d7 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { } // GenerateRunID creates a unique run identifier with timestamp and random hash. -// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3) +// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3). func GenerateRunID() string { now := time.Now() timestamp := now.Format("20060102-150405") - + // Add a short random hash to ensure uniqueness randomHash := util.MustGenerateRandomStringDNSSafe(6) + return fmt.Sprintf("%s-%s", timestamp, randomHash) } // ExtractRunIDFromContainerName extracts the run ID from container name. -// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH" +// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH". func ExtractRunIDFromContainerName(containerName string) string { parts := strings.Split(containerName, "-") if len(parts) >= 3 { // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) return strings.Join(parts[len(parts)-3:], "-") } - - panic(fmt.Sprintf("unexpected container name format: %s", containerName)) + + panic("unexpected container name format: " + containerName) } // IsRunningInContainer checks if the current process is running inside a Docker container. @@ -62,4 +63,4 @@ func IsRunningInContainer() bool { // This could be improved with more robust detection if needed _, err := os.Stat("/.dockerenv") return err == nil -} \ No newline at end of file +} diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index e77b7cb8..e4b39efb 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption { }) } -// buffer is a goroutine safe bytes.buffer +// buffer is a goroutine safe bytes.buffer. type buffer struct { store bytes.Buffer mutex sync.Mutex @@ -58,8 +58,8 @@ func ExecuteCommand( env []string, options ...ExecuteCommandOption, ) (string, string, error) { - var stdout = buffer{} - var stderr = buffer{} + stdout := buffer{} + stderr := buffer{} execConfig := ExecuteCommandConfig{ timeout: dockerExecuteTimeout, diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index 857a5def..dd6c6978 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,7 +159,6 @@ func New( }, } - if dsic.workdir != "" { runOptions.WorkingDir = dsic.workdir } @@ -192,7 +191,7 @@ func New( } // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(runOptions, "derp") - + container, err = pool.BuildAndRunWithBuildOptions( buildOptions, runOptions, diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index ca4e8a14..b1d947cd 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -2,13 +2,13 @@ package integration import ( "strings" - "tailscale.com/tailcfg" - "tailscale.com/types/key" "testing" "time" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "tailscale.com/tailcfg" + "tailscale.com/types/key" ) type ClientsSpec struct { @@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) { NodesPerUser: 1, Users: []string{"user1", "user2", "user3"}, Networks: map[string][]string{ - "usernet1": []string{"user1"}, - "usernet2": []string{"user2"}, - "usernet3": []string{"user3"}, + "usernet1": {"user1"}, + "usernet2": {"user2"}, + "usernet3": {"user3"}, }, } @@ -106,7 +106,6 @@ func derpServerScenario( furtherAssertions ...func(*Scenario), ) { IntegrationSkip(t) - // t.Parallel() scenario, err := NewScenario(spec) assertNoErr(t, err) diff --git a/integration/general_test.go b/integration/general_test.go index 292eb5ca..c60c2f46 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -26,7 +26,6 @@ import ( func TestPingAllByIP(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) { func TestPingAllByIPPublicDERP(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) { func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { // deleted by accident if they are still online and active. func TestEphemeral2006DeletedTooQuickly(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { // Wait a bit and bring up the clients again before the expiry // time of the ephemeral nodes. // Nodes should be able to reconnect and work fine. - time.Sleep(30 * time.Second) - for _, client := range allClients { err := client.Up() if err != nil { t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) } } - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - success = pingAllHelper(t, allClients, allAddrs) + // Wait for clients to sync and be able to ping each other after reconnection + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = scenario.WaitForTailscaleSync() + assert.NoError(ct, err) + + success = pingAllHelper(t, allClients, allAddrs) + assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping") + }, 60*time.Second, 2*time.Second) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) // Take down all clients, this should start an expiry timer for each. @@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { // This time wait for all of the nodes to expire and check that they are no longer // registered. - time.Sleep(3 * time.Minute) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, userName := range spec.Users { + nodes, err := headscale.ListNodes(userName) + assert.NoError(ct, err) + assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName) + } + }, 4*time.Minute, 10*time.Second) for _, userName := range spec.Users { nodes, err := headscale.ListNodes(userName) @@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) { // nolint:tparallel func TestTaildrop(t *testing.T) { IntegrationSkip(t) - t.Parallel() - - retry := func(times int, sleepInterval time.Duration, doWork func() error) error { - var err error - for range times { - err = doWork() - if err == nil { - return nil - } - time.Sleep(sleepInterval) - } - - return err - } spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) { "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets", } - err = retry(10, 1*time.Second, func() error { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { result, _, err := client.Execute(curlCommand) - if err != nil { - return err - } + assert.NoError(ct, err) + var fts []apitype.FileTarget err = json.Unmarshal([]byte(result), &fts) - if err != nil { - return err - } + assert.NoError(ct, err) if len(fts) != len(allClients)-1 { ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname()) for _, ft := range fts { ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) } - return fmt.Errorf( - "client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", - client.Hostname(), + assert.Failf(ct, "client %s does not have all its peers as FileTargets", + "got %d, want: %d\n%s", len(fts), len(allClients)-1, ftStr, ) } - - return err - }) - if err != nil { - t.Errorf( - "failed to query localapi for filetarget on %s, err: %s", - client.Hostname(), - err, - ) - } + }, 10*time.Second, 1*time.Second) } for _, client := range allClients { @@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) { fmt.Sprintf("%s:", peerFQDN), } - err := retry(10, 1*time.Second, func() error { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { t.Logf( "Sending file from %s to %s\n", client.Hostname(), peer.Hostname(), ) _, _, err := client.Execute(command) - - return err - }) - if err != nil { - t.Fatalf( - "failed to send taildrop file on %s with command %q, err: %s", - client.Hostname(), - strings.Join(command, " "), - err, - ) - } + assert.NoError(ct, err) + }, 10*time.Second, 1*time.Second) }) } } @@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) { func TestUpdateHostnameFromClient(t *testing.T) { IntegrationSkip(t) - t.Parallel() hostnames := map[string]string{ "1": "user1-host", @@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) { assertNoErr(t, err) } - time.Sleep(5 * time.Second) + // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Build a map of expected DNSNames by node ID + expectedDNSNames := make(map[string]string) + for _, node := range nodes { + nodeID := strconv.FormatUint(node.GetId(), 10) + expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId()) + } + + // Verify from each client's perspective + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + // Check self node + selfID := string(status.Self.ID) + expectedDNS := expectedDNSNames[selfID] + assert.Equal(ct, expectedDNS, status.Self.DNSName, + "Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[selfID] + assert.Equal(ct, originalHostname, status.Self.HostName, + "Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID) + + // Check peers + for _, peer := range status.Peer { + peerID := string(peer.ID) + if expectedDNS, ok := expectedDNSNames[peerID]; ok { + assert.Equal(ct, expectedDNS, peer.DNSName, + "Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname()) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[peerID] + assert.Equal(ct, originalHostname, peer.HostName, + "Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname()) + } + } + } + }, 60*time.Second, 2*time.Second) - // Verify that the clients can see the new hostname, but no givenName for _, client := range allClients { status, err := client.Status() assertNoErr(t, err) @@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) { func TestExpireNode(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) { t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) - time.Sleep(2 * time.Minute) + // Verify that the expired node has been marked in all peers list. + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + if client.Hostname() != node.GetName() { + // Check if the expired node appears as expired in this client's peer list + for key, peer := range status.Peer { + if key == expiredNodeKey { + assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname()) + break + } + } + } + } + }, 3*time.Minute, 10*time.Second) now := time.Now() @@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) { func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) { // five times ensuring they are able to restablish connectivity. func TestPingAllByIPManyUpDown(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { t.Fatalf("failed to take down all nodes: %s", err) } - time.Sleep(5 * time.Second) - for _, client := range allClients { c := client wg.Go(func() error { @@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) { t.Fatalf("failed to take down all nodes: %s", err) } - time.Sleep(5 * time.Second) + // Wait for sync and successful pings after nodes come back up + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = scenario.WaitForTailscaleSync() + assert.NoError(ct, err) - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + success := pingAllHelper(t, allClients, allAddrs) + assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up") + }, 30*time.Second, 2*time.Second) success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) @@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { func Test2118DeletingOnlineNodePanics(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { ) require.NoError(t, err) - time.Sleep(2 * time.Second) - // Ensure that the node has been deleted, this did not occur due to a panic. var nodeListAfter []v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &nodeListAfter, + ) + assert.NoError(ct, err) + assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list") + }, 10*time.Second, 1*time.Second) + err = executeAndUnmarshal( headscale, []string{ diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 9c6816fa..c300a205 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -191,7 +191,7 @@ func WithPostgres() Option { } } -// WithPolicy sets the policy mode for headscale +// WithPolicy sets the policy mode for headscale. func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { hsic.policyMode = mode @@ -279,7 +279,7 @@ func New( return nil, err } - hostname := fmt.Sprintf("hs-%s", hash) + hostname := "hs-" + hash hsic := &HeadscaleInContainer{ hostname: hostname, @@ -308,14 +308,14 @@ func New( if hsic.postgres { hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres" - hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash) + hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale" hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale" delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH") pgRunOptions := &dockertest.RunOptions{ - Name: fmt.Sprintf("postgres-%s", hash), + Name: "postgres-" + hash, Repository: "postgres", Tag: "latest", Networks: networks, @@ -328,7 +328,7 @@ func New( // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres") - + pg, err := pool.RunWithOptions(pgRunOptions) if err != nil { return nil, fmt.Errorf("starting postgres container: %w", err) @@ -373,7 +373,6 @@ func New( Env: env, } - if len(hsic.hostPortBindings) > 0 { runOptions.PortBindings = map[docker.Port][]docker.PortBinding{} for port, hostPorts := range hsic.hostPortBindings { @@ -396,7 +395,7 @@ func New( // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale") - + container, err := pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, runOptions, @@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, 0755); err != nil { + if err := os.MkdirAll(targetDir, 0o755); err != nil { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } @@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error { } targetDir := path.Join(savePath, t.hostname+"-pprof") + return extractTarToDirectory(tarFile, targetDir) } @@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { } targetDir := path.Join(savePath, t.hostname+"-mapresponses") + return extractTarToDirectory(tarFile, targetDir) } @@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err != nil { return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err) } - + if strings.TrimSpace(schemaCheck) == "" { - return fmt.Errorf("database file exists but has no schema (empty database)") + return errors.New("database file exists but has no schema (empty database)") } - + // Show a preview of the schema (first 500 chars) schemaPreview := schemaCheck if len(schemaPreview) > 500 { schemaPreview = schemaPreview[:500] + "..." } - log.Printf("Database schema preview:\n%s", schemaPreview) tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") if err != nil { @@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } } - return fmt.Errorf("no regular file found in database tar archive") + return errors.New("no regular file found in database tar archive") } // Execute runs a command inside the Headscale container and returns the @@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute( // GetPort returns the docker container port as a string. func (t *HeadscaleInContainer) GetPort() string { - return fmt.Sprintf("%d", t.port) + return strconv.Itoa(t.port) } // GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer // instance. func (t *HeadscaleInContainer) GetHealthEndpoint() string { - return fmt.Sprintf("%s/health", t.GetEndpoint()) + return t.GetEndpoint() + "/health" } // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer. @@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string { t.port) if t.hasTLS() { - return fmt.Sprintf("https://%s", hostEndpoint) + return "https://" + hostEndpoint } - return fmt.Sprintf("http://%s", hostEndpoint) + return "http://" + hostEndpoint } // GetCert returns the public certificate of the HeadscaleInContainer. @@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes( } ret = append(ret, nodes...) + return nil } @@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes( sort.Slice(ret, func(i, j int) bool { return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 }) + return ret, nil } @@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) { var userMap map[string][]*v1.Node for _, node := range nodes { - if _, ok := userMap[node.User.Name]; !ok { - mak.Set(&userMap, node.User.Name, []*v1.Node{node}) + if _, ok := userMap[node.GetUser().GetName()]; !ok { + mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node}) } else { - userMap[node.User.Name] = append(userMap[node.User.Name], node) + userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node) } } @@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) { var userMap map[string]*v1.User for _, user := range users { - mak.Set(&userMap, user.Name, user) + mak.Set(&userMap, user.GetName(), user) } return userMap, nil @@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { case 1: return pids[0], nil default: - return 0, fmt.Errorf("multiple headscale processes running") + return 0, errors.New("multiple headscale processes running") } } @@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( "headscale", "nodes", "approve-routes", "--output", "json", "--identifier", strconv.FormatUint(id, 10), - fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")), + "--routes=" + strings.Join(util.PrefixesToString(routes), ","), } result, _, err := dockertestutil.ExecuteCommand( diff --git a/integration/route_test.go b/integration/route_test.go index 053b4582..64677aec 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -4,13 +4,12 @@ import ( "encoding/json" "fmt" "net/netip" + "slices" "sort" "strings" "testing" "time" - "slices" - cmpdiff "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff} // routes. func TestEnablingRoutes(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 3, @@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) { for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - if peerStatus.ID == "1" { + switch peerStatus.ID { + case "1": requirePeerSubnetRoutes(t, peerStatus, nil) - } else if peerStatus.ID == "2" { + case "2": requirePeerSubnetRoutes(t, peerStatus, nil) - } else { + default: requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) } } @@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) { func TestHASubnetRouterFailover(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 3, @@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) { // https://github.com/juanfont/headscale/issues/1604 func TestSubnetRouteACL(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "user4" @@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) { // set during login instead of set. func TestEnablingExitRoutes(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "user2" @@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) { // subnet router is working as expected. func TestSubnetRouterMultiNetwork(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Enable route _, err = headscale.ApproveRoutes( - nodes[0].Id, + nodes[0].GetId(), []netip.Prefix{*pref}, ) require.NoError(t, err) @@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { IntegrationSkip(t) - t.Parallel() spec := ScenarioSpec{ NodesPerUser: 1, @@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { } // Enable route - _, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()}) + _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()}) require.NoError(t, err) time.Sleep(5 * time.Second) @@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) assertNoErr(t, err) - err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key) + err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) assertNoErr(t, err) } // extra creation end. @@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub // that are explicitly allowed in the ACL. func TestSubnetRouteACLFiltering(t *testing.T) { IntegrationSkip(t) - t.Parallel() // Use router and node users for better clarity routerUser := "router" @@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) // Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24) - aclPolicyStr := fmt.Sprintf(`{ + aclPolicyStr := `{ "hosts": { "router": "100.64.0.1/32", "node": "100.64.0.2/32" @@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { ] } ] - }`) + }` route, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) diff --git a/integration/scenario.go b/integration/scenario.go index 358291ff..b235cf34 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -123,7 +123,7 @@ type ScenarioSpec struct { // NodesPerUser is how many nodes should be attached to each user. NodesPerUser int - // Networks, if set, is the seperate Docker networks that should be + // Networks, if set, is the separate Docker networks that should be // created and a list of the users that should be placed in those networks. // If not set, a single network will be created and all users+nodes will be // added there. @@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + hostname := "hs-oidcmock-" + hash usersJSON, err := json.Marshal(users) if err != nil { @@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse }, Networks: s.Networks(), Env: []string{ - fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), + "MOCKOIDC_ADDR=" + hostname, fmt.Sprintf("MOCKOIDC_PORT=%d", port), "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", - fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), - fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), + "MOCKOIDC_ACCESS_TTL=" + accessTTL.String(), + "MOCKOIDC_USERS=" + string(usersJSON), }, } - headscaleBuildOptions := &dockertest.BuildOptions{ Dockerfile: hsic.IntegrationTestDockerFileName, ContextDir: dockerContextPath, @@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") - + if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, mockOidcOptions, @@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - hostname := fmt.Sprintf("hs-webservice-%s", hash) + hostname := "hs-webservice-" + hash network, ok := s.networks[s.prefixedNetworkName(networkName)] if !ok { diff --git a/integration/scenario_test.go b/integration/scenario_test.go index ac0ff238..ead3f1fd 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) { // nolint:tparallel func TestHeadscale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error @@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) { // nolint:tparallel func TestTailscaleNodesJoiningHeadcale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error diff --git a/integration/ssh_test.go b/integration/ssh_test.go index cf08613d..236aba20 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool { strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node") } -var retry = func(times int, sleepInterval time.Duration, - doWork func() (string, string, error), -) (string, string, error) { - var result string - var stderr string - var err error - - for range times { - tempResult, tempStderr, err := doWork() - - result += tempResult - stderr += tempStderr - - if err == nil { - return result, stderr, nil - } - - // If we get a permission denied error, we can fail immediately - // since that is something we won-t recover from by retrying. - if err != nil && isSSHNoAccessStdError(stderr) { - return result, stderr, err - } - - time.Sleep(sleepInterval) - } - - return result, stderr, err -} - func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { t.Helper() @@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce func TestSSHOneUserToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) { func TestSSHMultipleUsersAllToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { func TestSSHNoSSHConfigured(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) { func TestSSHIsBlockedInACL(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) { func TestSSHUserOnlyIsolation(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, &policyv2.Policy{ @@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) { } func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + return doSSHWithRetry(t, client, peer, true) +} + +func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + return doSSHWithRetry(t, client, peer, false) +} + +func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) { t.Helper() peerFQDN, _ := peer.FQDN() @@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname()) log.Printf("Command: %s", strings.Join(command, " ")) - return retry(10, 1*time.Second, func() (string, string, error) { - return client.Execute(command) - }) + var result, stderr string + var err error + + if retry { + // Use assert.EventuallyWithT to retry SSH connections for success cases + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, stderr, err = client.Execute(command) + + // If we get a permission denied error, we can fail immediately + // since that is something we won't recover from by retrying. + if err != nil && isSSHNoAccessStdError(stderr) { + return // Don't retry permission denied errors + } + + // For all other errors, assert no error to trigger retry + assert.NoError(ct, err) + }, 10*time.Second, 1*time.Second) + } else { + // For failure cases, just execute once + result, stderr, err = client.Execute(command) + } + + return result, stderr, err } func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) { @@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, err := doSSH(t, client, peer) + result, stderr, err := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) @@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, _ := doSSH(t, client, peer) + result, stderr, _ := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index d2738c55..3e4847eb 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -251,7 +251,6 @@ func New( Env: []string{}, } - if tsic.withWebsocketDERP { if version != VersionHead { return tsic, errInvalidClientConfig @@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand( if len(t.withTags) > 0 { command = append(command, - fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")), + "--advertise-tags="+strings.Join(t.withTags, ","), ) } @@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID { // Panics if version is lower then minimum. func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { - panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version)) + panic("tsic.Netmap() called with unsupported version: " + t.version) } command := []string{ @@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err "tailscale", "ping", fmt.Sprintf("--timeout=%s", args.timeout), fmt.Sprintf("--c=%d", args.count), - fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)), + "--until-direct=" + strconv.FormatBool(args.direct), } command = append(command, hostnameOrIP) @@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err command := []string{ "curl", "--silent", - "--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())), - "--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())), - "--retry", fmt.Sprintf("%d", args.retry), - "--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())), - "--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())), + "--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())), + "--max-time", strconv.Itoa(int(args.maxTime.Seconds())), + "--retry", strconv.Itoa(args.retry), + "--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())), + "--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())), url, } @@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) { } if out.Len() == 0 { - return nil, fmt.Errorf("file is empty") + return nil, errors.New("file is empty") } return out.Bytes(), nil @@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { if err = json.Unmarshal(currentProfile, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err) } + return &p.Persist.PrivateNodeKey, nil } diff --git a/integration/utils.go b/integration/utils.go index bcf488e2..c19f6459 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -3,7 +3,6 @@ package integration import ( "bufio" "bytes" - "context" "fmt" "io" "net/netip" @@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) { // This isn't really relevant for Self as it won't be in its own socket/wireguard. // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname()) + // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) for _, peer := range status.Peer { assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) @@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) { func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { t.Helper() - _, err := backoff.Retry(context.Background(), func() (struct{}, error) { + _, err := backoff.Retry(t.Context(), func() (struct{}, error) { stdout, stderr, err := c.Execute(command) if err != nil { return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) @@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover { func tagApprover(name string) policyv2.AutoApprover { return ptr.To(policyv2.Tag(name)) } + // // // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus // // if there is a peer with the given hostname. If no peer is found, nil is returned.