diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go
index f6b5f71a..0d0025d3 100644
--- a/cmd/headscale/cli/utils.go
+++ b/cmd/headscale/cli/utils.go
@@ -130,7 +130,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 	return ctx, client, conn, cancel
 }
 
-func output(result interface{}, override string, outputFormat string) string {
+func output(result any, override string, outputFormat string) string {
 	var jsonBytes []byte
 	var err error
 	switch outputFormat {
@@ -158,7 +158,7 @@ func output(result interface{}, override string, outputFormat string) string {
 }
 
 // SuccessOutput prints the result to stdout and exits with status code 0.
-func SuccessOutput(result interface{}, override string, outputFormat string) {
+func SuccessOutput(result any, override string, outputFormat string) {
 	fmt.Println(output(result, override, outputFormat))
 	os.Exit(0)
 }
diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go
index 1a47b7ff..3df7c7f6 100644
--- a/cmd/hi/docker.go
+++ b/cmd/hi/docker.go
@@ -455,10 +455,10 @@ func boolToInt(b bool) int {
 
 // DockerContext represents Docker context information.
 type DockerContext struct {
-	Name      string                 `json:"Name"`
-	Metadata  map[string]interface{} `json:"Metadata"`
-	Endpoints map[string]interface{} `json:"Endpoints"`
-	Current   bool                   `json:"Current"`
+	Name      string         `json:"Name"`
+	Metadata  map[string]any `json:"Metadata"`
+	Endpoints map[string]any `json:"Endpoints"`
+	Current   bool           `json:"Current"`
 }
 
 // createDockerClient creates a Docker client with context detection.
@@ -473,7 +473,7 @@ func createDockerClient() (*client.Client, error) {
 
 	if contextInfo != nil {
 		if endpoints, ok := contextInfo.Endpoints["docker"]; ok {
-			if endpointMap, ok := endpoints.(map[string]interface{}); ok {
+			if endpointMap, ok := endpoints.(map[string]any); ok {
 				if host, ok := endpointMap["Host"].(string); ok {
 					if runConfig.Verbose {
 						log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
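
Note: any has been a predeclared alias for interface{} since Go 1.18, so the
rewrites above are purely mechanical and change no behavior. A minimal
illustrative sketch of the equivalence (not code from this PR):

	package main

	import "fmt"

	func main() {
		// The two declarations below have the identical type; gofmt's
		// modernizer rewrites the first form into the second.
		var a interface{} = 42
		var b any = 42
		fmt.Println(a == b) // true
	}
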
diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go
index 9a5566c6..c142bf9d 100644
--- a/hscontrol/auth_test.go
+++ b/hscontrol/auth_test.go
@@ -2760,7 +2760,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
 	require.Equal(t, 2, user2NodesAfter.Len(), "user2 should still have 2 nodes (old nodes from original registration)")
 
 	// Verify original nodes still exist with original users
-	for i := 0; i < 2; i++ {
+	for i := range 2 {
 		node := nodes[i]
 
 		// User1's original nodes should still be owned by user1
 		registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
@@ -3195,6 +3195,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
 	assert.Error(t, err, "expired pre-auth key should be rejected")
 	assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration")
 }
+
 // TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey tests that an existing node
 // can re-register using a pre-auth key that's already marked as Used=true, as long as:
 // 1. The node is re-registering with the same MachineKey it originally used
@@ -3204,7 +3205,8 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
 //
 // Background: When Docker/Kubernetes containers restart, they keep their persistent state
 // (including the MachineKey), but container entrypoints unconditionally run:
-//   tailscale up --authkey=$TS_AUTHKEY
+//
+//	tailscale up --authkey=$TS_AUTHKEY
 //
 // This caused nodes to be rejected after restart because the pre-auth key was already
 // marked as Used=true from the initial registration. The fix allows re-registration of
diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go
index 1652901f..6172e7e0 100644
--- a/hscontrol/db/text_serialiser.go
+++ b/hscontrol/db/text_serialiser.go
@@ -31,7 +31,7 @@ func decodingError(name string, err error) error {
 // have a type that implements encoding.TextUnmarshaler.
 type TextSerialiser struct{}
 
-func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
+func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue any) error {
 	fieldValue := reflect.New(field.FieldType)
 
 	// If the field is a pointer, we need to dereference it to get the actual type
@@ -77,10 +77,10 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
 		}
 	}
 
-	return err
+	return nil
 }
 
-func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
+func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue any) (any, error) {
 	switch v := fieldValue.(type) {
 	case encoding.TextMarshaler:
 		// If the value is nil, we return nil, however, go nil values are not
diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go
index 24910947..f98bb6c8 100644
--- a/hscontrol/mapper/batcher_test.go
+++ b/hscontrol/mapper/batcher_test.go
@@ -1136,13 +1136,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
 
 			// First connection
 			ch1 := make(chan *tailcfg.MapResponse, 1)
-			wg.Add(1)
-
-			go func() {
-				defer wg.Done()
-
+			wg.Go(func() {
 				batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
-			}()
+			})
 
 			// Add real work during connection chaos
 			if i%10 == 0 {
@@ -1152,24 +1148,17 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
 
 			// Rapid second connection - should replace ch1
 			ch2 := make(chan *tailcfg.MapResponse, 1)
-			wg.Add(1)
-
-			go func() {
-				defer wg.Done()
-
+			wg.Go(func() {
 				time.Sleep(1 * time.Microsecond)
 				batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
-			}()
+			})
 
 			// Remove second connection
-			wg.Add(1)
-
-			go func() {
-				defer wg.Done()
+			wg.Go(func() {
 				time.Sleep(2 * time.Microsecond)
 				batcher.RemoveNode(testNode.n.ID, ch2)
-			}()
+			})
 
 			wg.Wait()
@@ -1789,10 +1778,7 @@ func XTestBatcherScalability(t *testing.T) {
 
 			// This ensures some nodes stay connected to continue receiving updates
 			startIdx := cycle % len(testNodes)
-			endIdx := startIdx + len(testNodes)/4
-			if endIdx > len(testNodes) {
-				endIdx = len(testNodes)
-			}
+			endIdx := min(startIdx+len(testNodes)/4, len(testNodes))
 
 			if startIdx >= endIdx {
 				startIdx = 0
@@ -2313,7 +2299,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
 	receivedCount := 0
 	timeout := time.After(500 * time.Millisecond)
 
-	for i := 0; i < len(allNodes); i++ {
+	for i := range allNodes {
 		select {
 		case update := <-newChannels[i]:
 			if update != nil {
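
Note: sync.WaitGroup.Go folds the Add(1)/defer Done() bookkeeping into a
single call (it requires a Go 1.25 toolchain), and `for i := range n`
(Go 1.22) replaces the counted loop. An illustrative sketch, not code from
this PR:

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		var wg sync.WaitGroup

		for i := range 3 { // range-over-int counts 0, 1, 2
			// wg.Go is equivalent to:
			//   wg.Add(1)
			//   go func() { defer wg.Done(); ... }()
			wg.Go(func() {
				fmt.Println("worker", i) // loop vars are per-iteration since Go 1.22
			})
		}

		wg.Wait()
	}
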
diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go
index dd8e70c5..11d57087 100644
--- a/hscontrol/policy/v2/filter.go
+++ b/hscontrol/policy/v2/filter.go
@@ -3,6 +3,7 @@ package v2
 import (
 	"errors"
 	"fmt"
+	"slices"
 	"time"
 
 	"github.com/juanfont/headscale/hscontrol/types"
@@ -178,11 +179,8 @@ func (pol *Policy) compileACLWithAutogroupSelf(
 	for _, ips := range resolvedSrcIPs {
 		for _, n := range sameUserNodes {
 			// Check if any of this node's IPs are in the source set
-			for _, nodeIP := range n.IPs() {
-				if ips.Contains(nodeIP) {
-					n.AppendToIPSet(&srcIPs)
-					break
-				}
+			if slices.ContainsFunc(n.IPs(), ips.Contains) {
+				n.AppendToIPSet(&srcIPs)
 			}
 		}
 	}
@@ -375,11 +373,8 @@ func (pol *Policy) compileSSHPolicy(
 		var filteredSrcIPs netipx.IPSetBuilder
 		for _, n := range sameUserNodes {
 			// Check if any of this node's IPs are in the source set
-			for _, nodeIP := range n.IPs() {
-				if srcIPs.Contains(nodeIP) {
-					n.AppendToIPSet(&filteredSrcIPs)
-					break // Found this node, move to next
-				}
+			if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
+				n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next
 			}
 		}
diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go
index 37ff8730..9cb0f9b8 100644
--- a/hscontrol/policy/v2/filter_test.go
+++ b/hscontrol/policy/v2/filter_test.go
@@ -3,6 +3,7 @@ package v2
 import (
 	"encoding/json"
 	"net/netip"
+	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -906,14 +907,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
 			}
 
 			for _, expectedIP := range expectedDestIPs {
-				found := false
-
-				for _, actualIP := range actualDestIPs {
-					if actualIP == expectedIP {
-						found = true
-						break
-					}
-				}
+				found := slices.Contains(actualDestIPs, expectedIP)
 
 				if !found {
 					t.Errorf("expected destination IP %s to be included, got: %v", expectedIP, actualDestIPs)
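
Note: slices.Contains and slices.ContainsFunc report whether any element
matches, so the loop-with-break pattern collapses to one call; passing the
IP set's Contains method as the predicate works because a method value
satisfies func(netip.Addr) bool. An illustrative sketch with made-up
addresses:

	package main

	import (
		"fmt"
		"net/netip"
		"slices"
	)

	func main() {
		ips := []netip.Addr{
			netip.MustParseAddr("100.64.0.1"),
			netip.MustParseAddr("fd7a:115c:a1e0::1"),
		}
		isV4 := func(a netip.Addr) bool { return a.Is4() }

		// True as soon as any element matches; no manual break needed.
		fmt.Println(slices.Contains(ips, netip.MustParseAddr("100.64.0.1"))) // true
		fmt.Println(slices.ContainsFunc(ips, isV4))                          // true
	}
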
diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go
index 2d2f2f19..7b4b2b28 100644
--- a/hscontrol/policy/v2/types.go
+++ b/hscontrol/policy/v2/types.go
@@ -1007,7 +1007,7 @@ func (g Groups) Contains(group *Group) error {
 
 // with "group:". If any group name is invalid, an error is returned.
 func (g *Groups) UnmarshalJSON(b []byte) error {
 	// First unmarshal as a generic map to validate group names first
-	var rawMap map[string]interface{}
+	var rawMap map[string]any
 	if err := json.Unmarshal(b, &rawMap); err != nil {
 		return err
 	}
@@ -1024,7 +1024,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
 	rawGroups := make(map[string][]string)
 	for key, value := range rawMap {
 		switch v := value.(type) {
-		case []interface{}:
+		case []any:
 			// Convert []interface{} to []string
 			var stringSlice []string
 			for _, item := range v {
diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go
index 7482c97b..a4367775 100644
--- a/hscontrol/policy/v2/utils.go
+++ b/hscontrol/policy/v2/utils.go
@@ -39,9 +39,10 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
 	}
 
 	var portRanges []tailcfg.PortRange
-	parts := strings.Split(portDef, ",")
-	for _, part := range parts {
+	parts := strings.SplitSeq(portDef, ",")
+
+	for part := range parts {
 		if strings.Contains(part, "-") {
 			rangeParts := strings.Split(part, "-")
 			rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go
index 03d6854f..757591ad 100644
--- a/hscontrol/state/debug.go
+++ b/hscontrol/state/debug.go
@@ -200,9 +200,9 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy {
 }
 
 // DebugRegistrationCache returns debug information about the registration cache.
-func (s *State) DebugRegistrationCache() map[string]interface{} {
+func (s *State) DebugRegistrationCache() map[string]any {
 	// The cache doesn't expose internal statistics, so we provide basic info
-	result := map[string]interface{}{
+	result := map[string]any{
 		"type":       "zcache",
 		"expiration": registerCacheExpiration.String(),
 		"cleanup":    registerCacheCleanup.String(),
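
Note: strings.SplitSeq (Go 1.24) returns an iterator (iter.Seq[string])
rather than allocating a []string, which suits single-pass consumers like
parsePortRange. An illustrative sketch:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		// Each comma-separated part is yielded lazily; no slice is built.
		for part := range strings.SplitSeq("22,80,8000-8080", ",") {
			fmt.Println(part)
		}
	}
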
diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go
index 79f3b1e0..788721b9 100644
--- a/hscontrol/state/node_store_test.go
+++ b/hscontrol/state/node_store_test.go
@@ -872,7 +872,7 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) {
 	var wg sync.WaitGroup
 	results := make(chan bool, concurrentOps)
 
-	for i := 0; i < concurrentOps; i++ {
+	for i := range concurrentOps {
 		wg.Add(1)
 		go func(nodeID int) {
 			defer wg.Done()
@@ -904,7 +904,7 @@ func TestNodeStoreBatchingEfficiency(t *testing.T) {
 	var wg sync.WaitGroup
 	results := make(chan bool, ops)
 
-	for i := 0; i < ops; i++ {
+	for i := range ops {
 		wg.Add(1)
 		go func(nodeID int) {
 			defer wg.Done()
@@ -941,11 +941,12 @@ func TestNodeStoreRaceConditions(t *testing.T) {
 	var wg sync.WaitGroup
 	errors := make(chan error, numGoroutines*opsPerGoroutine)
 
-	for i := 0; i < numGoroutines; i++ {
+	for i := range numGoroutines {
 		wg.Add(1)
 		go func(gid int) {
 			defer wg.Done()
-			for j := 0; j < opsPerGoroutine; j++ {
+
+			for j := range opsPerGoroutine {
 				switch j % 3 {
 				case 0:
 					resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
@@ -993,7 +994,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
 	afterStartGoroutines := runtime.NumGoroutine()
 
 	const ops = 100
-	for i := 0; i < ops; i++ {
+	for i := range ops {
 		nodeID := types.NodeID(i + 1)
 		node := createConcurrentTestNode(nodeID, "cleanup-node")
 		resultNode := store.PutNode(node)
@@ -1100,7 +1101,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
 
 // --- Edge case: update non-existent node ---
 func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
-	for i := 0; i < 10; i++ {
+	for i := range 10 {
 		store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
 		store.Start()
 		nonExistentID := types.NodeID(999 + i)
@@ -1124,8 +1125,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) {
 	store.Start()
 	defer store.Stop()
 
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
+	for i := 0; b.Loop(); i++ {
 		nodeID := types.NodeID(i + 1)
 		node := createConcurrentTestNode(nodeID, "bench-node")
 		store.PutNode(node)
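
Note: b.Loop (Go 1.24) replaces the classic `for i := 0; i < b.N; i++`
benchmark loop; setup before the loop is excluded from timing, so the
explicit b.ResetTimer() call becomes redundant. A counter can still be
threaded through, as the benchmark above does. Illustrative sketch, with
expensiveSetup as a hypothetical placeholder:

	package bench_test

	import "testing"

	func BenchmarkSketch(b *testing.B) {
		expensiveSetup() // not timed; b.Loop() handles the timer reset

		sum := 0
		for i := 0; b.Loop(); i++ {
			sum += i // stand-in for the real per-iteration work
		}
		_ = sum
	}

	func expensiveSetup() {} // hypothetical placeholder
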
diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go
index a7d815bf..f4814519 100644
--- a/hscontrol/types/common.go
+++ b/hscontrol/types/common.go
@@ -220,10 +220,12 @@ func DefaultBatcherWorkers() int {
 // DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
 // Default to 3/4 of CPU cores, minimum 1, no maximum.
 func DefaultBatcherWorkersFor(cpuCount int) int {
-	defaultWorkers := (cpuCount * 3) / 4
-	if defaultWorkers < 1 {
-		defaultWorkers = 1
-	}
+	const (
+		workerNumerator   = 3
+		workerDenominator = 4
+	)
+
+	defaultWorkers := max((cpuCount*workerNumerator)/workerDenominator, 1)
 
 	return defaultWorkers
 }
diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go
index 936b374c..f28cd4a3 100644
--- a/hscontrol/util/log.go
+++ b/hscontrol/util/log.go
@@ -49,22 +49,22 @@ func (l *DBLogWrapper) LogMode(gormLogger.LogLevel) gormLogger.Interface {
 	return l
 }
 
-func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...interface{}) {
+func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...any) {
 	l.Logger.Info().Msgf(msg, data...)
 }
 
-func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...interface{}) {
+func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...any) {
 	l.Logger.Warn().Msgf(msg, data...)
 }
 
-func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...interface{}) {
+func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...any) {
 	l.Logger.Error().Msgf(msg, data...)
 }
 
 func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
 	elapsed := time.Since(begin)
 	sql, rowsAffected := fc()
-	fields := map[string]interface{}{
+	fields := map[string]any{
 		"duration":     elapsed,
 		"sql":          sql,
 		"rowsAffected": rowsAffected,
@@ -83,7 +83,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
 	l.Logger.Debug().Fields(fields).Msgf("")
 }
 
-func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
+func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) {
 	if l.ParameterizedQueries {
 		return sql, nil
 	}
diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go
index 6c2d07e4..df5f2455 100644
--- a/integration/api_auth_test.go
+++ b/integration/api_auth_test.go
@@ -98,7 +98,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
 
 	// Should NOT contain user data after "Unauthorized"
 	// This is the security bypass - if users array is present, auth was bypassed
-	var jsonCheck map[string]interface{}
+	var jsonCheck map[string]any
 	jsonErr := json.Unmarshal(body, &jsonCheck)
 
 	// If we can unmarshal JSON and it contains "users", that's the bypass
@@ -278,8 +278,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
 	var responseBody string
 
 	for _, line := range lines {
-		if strings.HasPrefix(line, "HTTP_CODE:") {
-			httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
+		if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
+			httpCode = after
 		} else {
 			responseBody += line
 		}
@@ -324,8 +324,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
 	var responseBody string
 
 	for _, line := range lines {
-		if strings.HasPrefix(line, "HTTP_CODE:") {
-			httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
+		if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
+			httpCode = after
 		} else {
 			responseBody += line
 		}
@@ -359,8 +359,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
 	var responseBody string
 
 	for _, line := range lines {
-		if strings.HasPrefix(line, "HTTP_CODE:") {
-			httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
+		if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
+			httpCode = after
 		} else {
 			responseBody += line
 		}
@@ -459,9 +459,9 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
 	outputStr := strings.ToLower(output)
 	assert.True(t, strings.Contains(outputStr, "unauthenticated") ||
-		strings.Contains(outputStr, "invalid token") ||
-		strings.Contains(outputStr, "failed to validate token") ||
-		strings.Contains(outputStr, "authentication"),
+			strings.Contains(outputStr, "invalid token") ||
+			strings.Contains(outputStr, "failed to validate token") ||
+			strings.Contains(outputStr, "authentication"),
 		"Error should indicate authentication failure, got: %s", output)
 
 	// Should NOT leak user data
@@ -609,9 +609,9 @@ cli:
 	outputStr := strings.ToLower(output)
 	assert.True(t, strings.Contains(outputStr, "unauthenticated") ||
-		strings.Contains(outputStr, "invalid token") ||
-		strings.Contains(outputStr, "failed to validate token") ||
-		strings.Contains(outputStr, "authentication"),
+			strings.Contains(outputStr, "invalid token") ||
+			strings.Contains(outputStr, "failed to validate token") ||
+			strings.Contains(outputStr, "authentication"),
 		"Error should indicate authentication failure, got: %s", output)
 
 	// Should NOT leak user data
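
Note: strings.CutPrefix (Go 1.20) combines the HasPrefix test and the
TrimPrefix call into one operation, so the prefix literal is written only
once. Illustrative sketch:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		line := "HTTP_CODE:401"

		// after holds the remainder only when the prefix was present.
		if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
			fmt.Println("status:", after) // status: 401
		}
	}
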
diff --git a/integration/helpers.go b/integration/helpers.go
index 133a175b..7d40c8e6 100644
--- a/integration/helpers.go
+++ b/integration/helpers.go
@@ -56,13 +56,6 @@ type NodeSystemStatus struct {
 	NodeStore bool
 }
 
-// requireNotNil validates that an object is not nil and fails the test if it is.
-// This helper provides consistent error messaging for nil checks in integration tests.
-func requireNotNil(t *testing.T, object interface{}) {
-	t.Helper()
-	require.NotNil(t, object)
-}
-
 // requireNoErrHeadscaleEnv validates that headscale environment creation succeeded.
 // Provides specific error context for headscale environment setup failures.
 func requireNoErrHeadscaleEnv(t *testing.T, err error) {
diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go
index 81c33120..67b57896 100644
--- a/integration/hsic/hsic.go
+++ b/integration/hsic/hsic.go
@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"maps"
 	"net/http"
 	"net/netip"
 	"os"
@@ -132,9 +133,7 @@ func WithCustomTLS(cert, key []byte) Option {
 // can be used to override Headscale configuration.
 func WithConfigEnv(configEnv map[string]string) Option {
 	return func(hsic *HeadscaleInContainer) {
-		for key, value := range configEnv {
-			hsic.env[key] = value
-		}
+		maps.Copy(hsic.env, configEnv)
 	}
 }
diff --git a/integration/scenario.go b/integration/scenario.go
index fc3ce44d..d584a3ef 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -14,6 +14,7 @@ import (
 	"net/netip"
 	"net/url"
 	"os"
+	"slices"
 	"strconv"
 	"strings"
 	"sync"
@@ -1159,10 +1160,8 @@ func (s *Scenario) FindTailscaleClientByIP(ip netip.Addr) (TailscaleClient, erro
 	for _, client := range clients {
 		ips, _ := client.IPs()
 
-		for _, ip2 := range ips {
-			if ip == ip2 {
-				return client, nil
-			}
+		if slices.Contains(ips, ip) {
+			return client, nil
 		}
 	}
diff --git a/tools/capver/main.go b/tools/capver/main.go
index 0c9066ba..ec2e4d10 100644
--- a/tools/capver/main.go
+++ b/tools/capver/main.go
@@ -11,6 +11,7 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"slices"
 	"sort"
 	"strconv"
 	"strings"
@@ -116,10 +117,7 @@ func calculateMinSupportedCapabilityVersion(versions map[string]tailcfg.Capabili
 	sort.Strings(majorMinors)
 
 	// Take the latest 10 versions
-	supportedCount := supportedMajorMinorVersions
-	if len(majorMinors) < supportedCount {
-		supportedCount = len(majorMinors)
-	}
+	supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
 
 	if supportedCount == 0 {
 		return fallbackCapVer
@@ -168,9 +166,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
 	}
 
 	capsSorted := xmaps.Keys(capVarToTailscaleVer)
-	sort.Slice(capsSorted, func(i, j int) bool {
-		return capsSorted[i] < capsSorted[j]
-	})
+	slices.Sort(capsSorted)
 
 	for _, capVer := range capsSorted {
 		fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
@@ -223,10 +219,7 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
 	sort.Strings(majorMinors)
 
 	// Take latest 10
-	supportedCount := supportedMajorMinorVersions
-	if len(majorMinors) < supportedCount {
-		supportedCount = len(majorMinors)
-	}
+	supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
 
 	latest10 := majorMinors[len(majorMinors)-supportedCount:]
 	latest3 := majorMinors[len(majorMinors)-3:]
@@ -308,9 +301,7 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
 
 	// Add a few more test cases
 	capsSorted := xmaps.Keys(capVerToTailscaleVer)
-	sort.Slice(capsSorted, func(i, j int) bool {
-		return capsSorted[i] < capsSorted[j]
-	})
+	slices.Sort(capsSorted)
 
 	testCount := 0
 	for _, capVer := range capsSorted {
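
Note: maps.Copy (Go 1.21) replaces the manual key/value copy loop, and
slices.Sort replaces sort.Slice with a less-func for ordered element types;
both are drop-in rewrites. An illustrative sketch with hypothetical values:

	package main

	import (
		"fmt"
		"maps"
		"slices"
	)

	func main() {
		env := map[string]string{"HS_LOG_LEVEL": "info"}
		overrides := map[string]string{"HS_LOG_LEVEL": "debug", "HS_PORT": "8080"}

		// maps.Copy inserts every key/value of overrides into env,
		// overwriting existing keys, matching the removed loop.
		maps.Copy(env, overrides)
		fmt.Println(env["HS_LOG_LEVEL"]) // debug

		caps := []int{106, 88, 95}
		slices.Sort(caps) // ascending order without a comparison closure
		fmt.Println(caps) // [88 95 106]
	}
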