diff --git a/.golangci.yaml b/.golangci.yaml index 7e1ab297..5ebd698a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -18,6 +18,7 @@ linters: - lll - maintidx - makezero + - mnd - musttag - nestif - nolintlint diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index d5da5852..8204ecc2 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "fmt" "net" @@ -99,7 +100,7 @@ func mockOIDC() error { return err } - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port)) + listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port)) if err != nil { return err } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index b5192318..827d72e7 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -243,10 +243,7 @@ var listNodeRoutesCmd = &cobra.Command{ return } - tableData, err := nodeRoutesToPtables(nodes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } + tableData := nodeRoutesToPtables(nodes) err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { @@ -629,7 +626,7 @@ func nodesToPtables( func nodeRoutesToPtables( nodes []*v1.Node, -) (pterm.TableData, error) { +) pterm.TableData { tableHeader := []string{ "ID", "Hostname", @@ -653,7 +650,7 @@ func nodeRoutesToPtables( ) } - return tableData, nil + return tableData } var tagCmd = &cobra.Command{ diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index dc65ec3d..f7ea957a 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -480,7 +480,7 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { - cmd := exec.Command("docker", "context", "inspect") + cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect") output, err := cmd.Output() if err != nil { diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 62e9cf29..c30a1ca9 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -265,7 +265,7 @@ func checkGoInstallation() DoctorResult { } } - cmd := exec.Command("go", "version") + cmd := exec.CommandContext(context.Background(), "go", "version") output, err := cmd.Output() if err != nil { @@ -287,7 +287,7 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. func checkGitRepository() DoctorResult { - cmd := exec.Command("git", "rev-parse", "--git-dir") + cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir") err := cmd.Run() if err != nil { @@ -320,7 +320,7 @@ func checkRequiredFiles() DoctorResult { var missingFiles []string for _, file := range requiredFiles { - cmd := exec.Command("test", "-e", file) + cmd := exec.CommandContext(context.Background(), "test", "-e", file) err := cmd.Run() if err != nil { diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index 191bbdd0..9da6d5a8 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -403,25 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary { return StatsSummary{} } - min := values[0] - max := values[0] + minVal := values[0] + maxVal := values[0] sum := 0.0 for _, value := range values { - if value < min { - min = value + if value < minVal { + minVal = value } - if value > max { - max = value + if value > maxVal { + maxVal = value } sum += value } return StatsSummary{ - Min: min, - Max: max, + Min: minVal, + Max: maxVal, Average: sum / float64(len(values)), } } diff --git a/hscontrol/app.go b/hscontrol/app.go index 822de62b..4a8e5658 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -623,7 +623,7 @@ func (h *Headscale) Serve() error { return fmt.Errorf("setting up unix socket: %w", err) } - socketListener, err := net.Listen("unix", h.cfg.UnixSocket) + socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket) if err != nil { return fmt.Errorf("setting up gRPC socket: %w", err) } @@ -716,7 +716,7 @@ func (h *Headscale) Serve() error { v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) reflection.Register(grpcServer) - grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) + grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr) if err != nil { return fmt.Errorf("binding to TCP address: %w", err) } @@ -751,7 +751,7 @@ func (h *Headscale) Serve() error { httpServer.TLSConfig = tlsConfig httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig) } else { - httpListener, err = net.Listen("tcp", h.cfg.Addr) + httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr) } if err != nil { @@ -788,12 +788,14 @@ func (h *Headscale) Serve() error { if tailsqlEnabled { if h.cfg.Database.Type != types.DatabaseSqlite { + //nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start log.Fatal(). Str("type", h.cfg.Database.Type). Msgf("tailsql only support %q", types.DatabaseSqlite) } if tailsqlTSKey == "" { + //nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") } diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 47a527b9..f93b9ef8 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "os" "os/exec" @@ -177,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { return err } - _, err = db.Exec(string(schemaContent)) + _, err = db.ExecContext(context.Background(), string(schemaContent)) return err } @@ -322,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { } // Construct the pg_restore command - cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) // Set the output streams cmd.Stdout = os.Stdout diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 46dd58ba..83289055 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -27,6 +27,9 @@ import ( const ( NodeGivenNameHashLength = 8 NodeGivenNameTrimSize = 2 + + // defaultTestNodePrefix is the default hostname prefix for nodes created in tests. + defaultTestNodePrefix = "testnode" ) // ErrNodeNameNotUnique is returned when a node name is not unique. @@ -669,7 +672,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) panic("CreateNodeForTest requires a valid user") } - nodeName := "testnode" + nodeName := defaultTestNodePrefix if len(hostname) > 0 && hostname[0] != "" { nodeName = hostname[0] } @@ -741,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname panic("CreateNodesForTest requires a valid user") } - prefix := "testnode" + prefix := defaultTestNodePrefix if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { prefix = hostnamePrefix[0] } @@ -764,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int panic("CreateRegisteredNodesForTest requires a valid user") } - prefix := "testnode" + prefix := defaultTestNodePrefix if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { prefix = hostnamePrefix[0] } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e6602564..b247bd61 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -625,10 +625,10 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { } //nolint:unused -func generateRandomNumber(t *testing.T, max int64) int64 { +func generateRandomNumber(t *testing.T, maxVal int64) int64 { t.Helper() - maxB := big.NewInt(max) + maxB := big.NewInt(maxVal) n, err := rand.Int(rand.Reader, maxB) if err != nil { diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index b411daeb..00adaa64 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -1,6 +1,7 @@ package sqliteconfig import ( + "context" "database/sql" "path/filepath" "strings" @@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection - if err := db.Ping(); err != nil { + ctx := context.Background() + + err = db.PingContext(ctx) + if err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -112,7 +116,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { query := "PRAGMA " + pragma - err := db.QueryRow(query).Scan(&actualValue) + err := db.QueryRowContext(ctx, query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) } @@ -165,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } defer db.Close() + ctx := context.Background() + // Create test tables with foreign key relationship schema := ` CREATE TABLE parent ( @@ -180,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` - if _, err := db.Exec(schema); err != nil { + _, err = db.ExecContext(ctx, schema) + if err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record - if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + _, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')") + if err != nil { t.Fatalf("Failed to insert parent: %v", err) } // Test 1: Valid foreign key should work - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + _, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") if err != nil { t.Fatalf("Valid foreign key insert failed: %v", err) } // Test 2: Invalid foreign key should fail - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + _, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") if err == nil { t.Error("Expected foreign key constraint violation, but insert succeeded") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -206,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } // Test 3: Deleting referenced parent should fail - _, err = db.Exec("DELETE FROM parent WHERE id = 1") + _, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1") if err == nil { t.Error("Expected foreign key constraint violation when deleting referenced parent") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -252,7 +260,7 @@ func TestJournalModeValidation(t *testing.T) { var actualMode string - err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) } diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 6c3df807..f965ee6b 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -101,12 +101,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { - ips, err := net.LookupIP(host) + ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host) if err != nil { log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host) } else if len(ips) > 0 { // Use the first IP address - ipStr := ips[0].String() + ipStr := ips[0].IP.String() log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr) host = ipStr } @@ -355,7 +355,7 @@ func DERPBootstrapDNSHandler( // ServeSTUN starts a STUN server on the configured addr. func (d *DERPServer) ServeSTUN() { - packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr) + packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr) if err != nil { log.Fatal().Msgf("failed to open STUN listener: %v", err) } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 648e72a7..c26385bc 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -382,16 +382,14 @@ type UpdateInfo struct { } // parseUpdateAndAnalyze parses an update and returns detailed information. -func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) { - info := UpdateInfo{ +func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo { + return UpdateInfo{ PeerCount: len(resp.Peers), PatchCount: len(resp.PeersChangedPatch), IsFull: len(resp.Peers) > 0, IsPatch: len(resp.PeersChangedPatch) > 0, IsDERP: resp.DERPMap != nil, } - - return info, nil } // start begins consuming updates from the node's channel and tracking stats. @@ -413,7 +411,8 @@ func (n *node) start() { atomic.AddInt64(&n.updateCount, 1) // Parse update and track detailed stats - if info, err := parseUpdateAndAnalyze(data); err == nil { + info := parseUpdateAndAnalyze(data) + { // Track update types if info.IsFull { atomic.AddInt64(&n.fullCount, 1) @@ -840,7 +839,7 @@ func TestBatcherBasicOperations(t *testing.T) { } // Drain any initial messages from first node - drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) + drainChannelTimeout(tn.ch, 100*time.Millisecond) // Add the second node and verify update message _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) @@ -909,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) { } } -func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { - count := 0 - +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) { timer := time.NewTimer(timeout) defer timer.Stop() for { select { - case data := <-ch: - count++ - // Optional: add debug output if needed - _ = data + case <-ch: + // Drain message case <-timer.C: return } @@ -1431,6 +1426,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for i := range numCycles { for j := range churningNodes { node := &churningNodes[j] + wg.Add(2) // Connect churning node @@ -1661,7 +1657,7 @@ func XTestBatcherScalability(t *testing.T) { description string } - var testCases []testCase + testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes)) // Generate all combinations of the test matrix for _, nodeCount := range nodes { @@ -1773,6 +1769,7 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := &testNodes[i] _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true @@ -2303,6 +2300,7 @@ func TestBatcherRapidReconnection(t *testing.T) { for i := range allNodes { node := &allNodes[i] + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node %d: %v", i, err) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 18b2ac1d..368e1829 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,14 +3,10 @@ package mapper import ( "fmt" "net/netip" - "slices" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -81,97 +77,3 @@ func TestDNSConfigMapResponse(t *testing.T) { }) } } - -// mockState is a mock implementation that provides the required methods. -type mockState struct { - polMan policy.PolicyManager - derpMap *tailcfg.DERPMap - primary *routes.PrimaryRoutes - nodes types.Nodes - peers types.Nodes -} - -func (m *mockState) DERPMap() *tailcfg.DERPMap { - return m.derpMap -} - -func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { - if m.polMan == nil { - return tailcfg.FilterAllowAll, nil - } - - return m.polMan.Filter() -} - -func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - if m.polMan == nil { - return nil, nil - } - - return m.polMan.SSHPolicy(node) -} - -func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { - if m.polMan == nil { - return false - } - - return m.polMan.NodeCanHaveTag(node, tag) -} - -func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { - if m.primary == nil { - return nil - } - - return m.primary.PrimaryRoutes(nodeID) -} - -func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - if len(peerIDs) > 0 { - // Filter peers by the provided IDs - var filtered types.Nodes - - for _, peer := range m.peers { - if slices.Contains(peerIDs, peer.ID) { - filtered = append(filtered, peer) - } - } - - return filtered, nil - } - // Return all peers except the node itself - var filtered types.Nodes - - for _, peer := range m.peers { - if peer.ID != nodeID { - filtered = append(filtered, peer) - } - } - - return filtered, nil -} - -func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - if len(nodeIDs) > 0 { - // Filter nodes by the provided IDs - var filtered types.Nodes - - for _, node := range m.nodes { - if slices.Contains(nodeIDs, node.ID) { - filtered = append(filtered, node) - } - } - - return filtered, nil - } - - return m.nodes, nil -} - -func Test_fullMapResponse(t *testing.T) { - t.Skip("Test needs to be refactored for new state-based architecture") - // TODO: Refactor this test to work with the new state-based mapper - // The test architecture needs to be updated to work with the state interface - // instead of the old direct dependency injection pattern -} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index c07694e3..0c4347fd 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -327,11 +327,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } // TODO(kradalby): replace with go-elem - content, err := renderOIDCCallbackTemplate(user, verb) - if err != nil { - httpError(writer, err) - return - } + content := renderOIDCCallbackTemplate(user, verb) writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -604,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration( func renderOIDCCallbackTemplate( user *types.User, verb string, -) (*bytes.Buffer, error) { +) *bytes.Buffer { html := templates.OIDCCallback(user.Display(), verb).Render() - return bytes.NewBufferString(html), nil + return bytes.NewBufferString(html) } // getCookieName generates a unique cookie name based on a cookie value. diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index b0c26f30..6dfacd91 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -70,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ } func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { - var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) + polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1) polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { return policyv2.NewPolicyManager(pol, u, n) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 0d5c2b75..486fdec7 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -799,7 +799,7 @@ func TestReduceNodes(t *testing.T) { func TestReduceNodesFromPolicy(t *testing.T) { n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { - var routes []netip.Prefix + routes := make([]netip.Prefix, 0, len(routess)) for _, route := range routess { routes = append(routes, netip.MustParsePrefix(route)) } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 67736177..2bce5bfc 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -48,7 +48,7 @@ func (pol *Policy) compileFilterRules( continue } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange @@ -165,7 +165,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( } } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var rules []*tailcfg.FilterRule diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 15c81af4..062bcc04 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -14,7 +14,7 @@ import ( "tailscale.com/types/ptr" ) -func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { +func node(name, ipv4, ipv6 string, user types.User) *types.Node { return &types.Node{ ID: 0, Hostname: name, @@ -22,7 +22,6 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) IPv6: ap(ipv6), User: ptr.To(user), UserID: ptr.To(user.ID), - Hostinfo: hostinfo, } } @@ -89,10 +88,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { }` initialNodes := types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), } for i, n := range initialNodes { @@ -119,10 +118,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "no_changes", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 0, description: "No changes should clear no cache entries", @@ -130,11 +129,11 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "node_added", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0], nil), // New node - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0]), // New node + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's existing nodes should be cleared description: "Adding a node should clear cache for that user's existing nodes", @@ -142,10 +141,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "node_removed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), // user1-node2 removed - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's remaining node + removed node should be cleared description: "Removing a node should clear cache for that user's remaining nodes", @@ -153,10 +152,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "user_changed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2], nil), // Changed to user3 - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2]), // Changed to user3 + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 3, // user1's node + user2's node + user3's nodes should be cleared description: "Changing a node's user should clear cache for both old and new users", @@ -164,10 +163,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "ip_changed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0], nil), // IP changed - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0]), // IP changed + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's nodes should be cleared description: "Changing a node's IP should clear cache for that user's nodes", @@ -381,9 +380,9 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { users := types.Users{user1, user2} // Create two nodes - node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil) + node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1) node1.ID = 1 - node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil) + node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2) node2.ID = 2 nodes := types.Nodes{node1, node2} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index ad475220..d161138b 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1455,55 +1455,49 @@ func (p Protocol) Description() string { } } -// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() ([]int, bool) { +func (p Protocol) parseProtocol() []int { switch p { case "": // Empty protocol applies to TCP, UDP, ICMP, and ICMPv6 traffic // This matches Tailscale's behavior for protocol defaults - return []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, false + return []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP} case ProtocolNameWildcard: // Wildcard protocol - defensive handling (should not reach here due to validation) - return nil, false + return nil case ProtocolNameIGMP: - return []int{ProtocolIGMP}, true + return []int{ProtocolIGMP} case ProtocolNameIPv4, ProtocolNameIPInIP: - return []int{ProtocolIPv4}, true + return []int{ProtocolIPv4} case ProtocolNameTCP: - return []int{ProtocolTCP}, false + return []int{ProtocolTCP} case ProtocolNameEGP: - return []int{ProtocolEGP}, true + return []int{ProtocolEGP} case ProtocolNameIGP: - return []int{ProtocolIGP}, true + return []int{ProtocolIGP} case ProtocolNameUDP: - return []int{ProtocolUDP}, false + return []int{ProtocolUDP} case ProtocolNameGRE: - return []int{ProtocolGRE}, true + return []int{ProtocolGRE} case ProtocolNameESP: - return []int{ProtocolESP}, true + return []int{ProtocolESP} case ProtocolNameAH: - return []int{ProtocolAH}, true + return []int{ProtocolAH} case ProtocolNameSCTP: - return []int{ProtocolSCTP}, false + return []int{ProtocolSCTP} case ProtocolNameICMP: // ICMP only - use "ipv6-icmp" or protocol number 58 for ICMPv6 - return []int{ProtocolICMP}, true + return []int{ProtocolICMP} case ProtocolNameIPv6ICMP: - return []int{ProtocolIPv6ICMP}, true + return []int{ProtocolIPv6ICMP} case ProtocolNameFC: - return []int{ProtocolFC}, true + return []int{ProtocolFC} default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling protocolNumber, _ := strconv.Atoi(string(p)) - - // Determine if wildcard is needed based on protocol number - needsWildcard := protocolNumber != ProtocolTCP && - protocolNumber != ProtocolUDP && - protocolNumber != ProtocolSCTP - - return []int{protocolNumber}, needsWildcard + return []int{protocolNumber} } } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 8a94c43b..8a842e49 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -1,15 +1,12 @@ package state import ( - "net/netip" "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") }) } - -// Simple helper function for tests. -func createTestNodeSimple(id types.NodeID) *types.Node { - user := types.User{ - Name: "test-user", - } - - machineKey := key.NewMachine() - nodeKey := key.NewNode() - - node := &types.Node{ - ID: id, - Hostname: "test-node", - UserID: ptr.To(uint(id)), - User: &user, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &netip.Addr{}, - IPv6: &netip.Addr{}, - } - - return node -} diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index e29857b5..72ea1f81 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -931,8 +931,6 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) { // --- Batching: concurrent ops fit in one batch ---. func TestNodeStoreBatchingEfficiency(t *testing.T) { - const batchSize = 10 - const ops = 15 // more than batchSize store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index ce679063..013cf56d 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -230,6 +230,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // propagate correctly when switching between policy types. s.nodeStore.RebuildPeerMaps() + //nolint:prealloc // cs starts with one element and may grow cs := []change.Change{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index fbcaa663..1b5a3806 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -181,10 +181,10 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // wildcardBits is the number of bits not under the mask in the lastOctet wildcardBits := ByteSize - maskBits%ByteSize - // min is the value in the lastOctet byte of the IP - // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 - min := uint(netRange.IP[lastOctet]) - max := (min + 1<