mirror of
https://github.com/juanfont/headscale.git
synced 2025-12-09 20:04:54 +01:00
modernize: run gopls modernize to bring up to 1.25 (#2920)
This commit is contained in:
parent
bfcd9d261d
commit
eec196d200
@ -130,7 +130,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
return ctx, client, conn, cancel
|
return ctx, client, conn, cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
func output(result interface{}, override string, outputFormat string) string {
|
func output(result any, override string, outputFormat string) string {
|
||||||
var jsonBytes []byte
|
var jsonBytes []byte
|
||||||
var err error
|
var err error
|
||||||
switch outputFormat {
|
switch outputFormat {
|
||||||
@ -158,7 +158,7 @@ func output(result interface{}, override string, outputFormat string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SuccessOutput prints the result to stdout and exits with status code 0.
|
// SuccessOutput prints the result to stdout and exits with status code 0.
|
||||||
func SuccessOutput(result interface{}, override string, outputFormat string) {
|
func SuccessOutput(result any, override string, outputFormat string) {
|
||||||
fmt.Println(output(result, override, outputFormat))
|
fmt.Println(output(result, override, outputFormat))
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -456,8 +456,8 @@ func boolToInt(b bool) int {
|
|||||||
// DockerContext represents Docker context information.
|
// DockerContext represents Docker context information.
|
||||||
type DockerContext struct {
|
type DockerContext struct {
|
||||||
Name string `json:"Name"`
|
Name string `json:"Name"`
|
||||||
Metadata map[string]interface{} `json:"Metadata"`
|
Metadata map[string]any `json:"Metadata"`
|
||||||
Endpoints map[string]interface{} `json:"Endpoints"`
|
Endpoints map[string]any `json:"Endpoints"`
|
||||||
Current bool `json:"Current"`
|
Current bool `json:"Current"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,7 +473,7 @@ func createDockerClient() (*client.Client, error) {
|
|||||||
|
|
||||||
if contextInfo != nil {
|
if contextInfo != nil {
|
||||||
if endpoints, ok := contextInfo.Endpoints["docker"]; ok {
|
if endpoints, ok := contextInfo.Endpoints["docker"]; ok {
|
||||||
if endpointMap, ok := endpoints.(map[string]interface{}); ok {
|
if endpointMap, ok := endpoints.(map[string]any); ok {
|
||||||
if host, ok := endpointMap["Host"].(string); ok {
|
if host, ok := endpointMap["Host"].(string); ok {
|
||||||
if runConfig.Verbose {
|
if runConfig.Verbose {
|
||||||
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
||||||
|
|||||||
@ -2760,7 +2760,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||||||
require.Equal(t, 2, user2NodesAfter.Len(), "user2 should still have 2 nodes (old nodes from original registration)")
|
require.Equal(t, 2, user2NodesAfter.Len(), "user2 should still have 2 nodes (old nodes from original registration)")
|
||||||
|
|
||||||
// Verify original nodes still exist with original users
|
// Verify original nodes still exist with original users
|
||||||
for i := 0; i < 2; i++ {
|
for i := range 2 {
|
||||||
node := nodes[i]
|
node := nodes[i]
|
||||||
// User1's original nodes should still be owned by user1
|
// User1's original nodes should still be owned by user1
|
||||||
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
|
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
|
||||||
@ -3195,6 +3195,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
|
|||||||
assert.Error(t, err, "expired pre-auth key should be rejected")
|
assert.Error(t, err, "expired pre-auth key should be rejected")
|
||||||
assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration")
|
assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey tests that an existing node
|
// TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey tests that an existing node
|
||||||
// can re-register using a pre-auth key that's already marked as Used=true, as long as:
|
// can re-register using a pre-auth key that's already marked as Used=true, as long as:
|
||||||
// 1. The node is re-registering with the same MachineKey it originally used
|
// 1. The node is re-registering with the same MachineKey it originally used
|
||||||
@ -3204,6 +3205,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
|
|||||||
//
|
//
|
||||||
// Background: When Docker/Kubernetes containers restart, they keep their persistent state
|
// Background: When Docker/Kubernetes containers restart, they keep their persistent state
|
||||||
// (including the MachineKey), but container entrypoints unconditionally run:
|
// (including the MachineKey), but container entrypoints unconditionally run:
|
||||||
|
//
|
||||||
// tailscale up --authkey=$TS_AUTHKEY
|
// tailscale up --authkey=$TS_AUTHKEY
|
||||||
//
|
//
|
||||||
// This caused nodes to be rejected after restart because the pre-auth key was already
|
// This caused nodes to be rejected after restart because the pre-auth key was already
|
||||||
|
|||||||
@ -31,7 +31,7 @@ func decodingError(name string, err error) error {
|
|||||||
// have a type that implements encoding.TextUnmarshaler.
|
// have a type that implements encoding.TextUnmarshaler.
|
||||||
type TextSerialiser struct{}
|
type TextSerialiser struct{}
|
||||||
|
|
||||||
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue any) error {
|
||||||
fieldValue := reflect.New(field.FieldType)
|
fieldValue := reflect.New(field.FieldType)
|
||||||
|
|
||||||
// If the field is a pointer, we need to dereference it to get the actual type
|
// If the field is a pointer, we need to dereference it to get the actual type
|
||||||
@ -77,10 +77,10 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue any) (any, error) {
|
||||||
switch v := fieldValue.(type) {
|
switch v := fieldValue.(type) {
|
||||||
case encoding.TextMarshaler:
|
case encoding.TextMarshaler:
|
||||||
// If the value is nil, we return nil, however, go nil values are not
|
// If the value is nil, we return nil, however, go nil values are not
|
||||||
|
|||||||
@ -1136,13 +1136,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
// First connection
|
// First connection
|
||||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||||
}()
|
})
|
||||||
|
|
||||||
// Add real work during connection chaos
|
// Add real work during connection chaos
|
||||||
if i%10 == 0 {
|
if i%10 == 0 {
|
||||||
@ -1152,24 +1148,17 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
// Rapid second connection - should replace ch1
|
// Rapid second connection - should replace ch1
|
||||||
ch2 := make(chan *tailcfg.MapResponse, 1)
|
ch2 := make(chan *tailcfg.MapResponse, 1)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
time.Sleep(1 * time.Microsecond)
|
time.Sleep(1 * time.Microsecond)
|
||||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||||
}()
|
})
|
||||||
|
|
||||||
// Remove second connection
|
// Remove second connection
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
time.Sleep(2 * time.Microsecond)
|
time.Sleep(2 * time.Microsecond)
|
||||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||||
}()
|
})
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
@ -1789,10 +1778,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
// This ensures some nodes stay connected to continue receiving updates
|
// This ensures some nodes stay connected to continue receiving updates
|
||||||
startIdx := cycle % len(testNodes)
|
startIdx := cycle % len(testNodes)
|
||||||
|
|
||||||
endIdx := startIdx + len(testNodes)/4
|
endIdx := min(startIdx+len(testNodes)/4, len(testNodes))
|
||||||
if endIdx > len(testNodes) {
|
|
||||||
endIdx = len(testNodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
if startIdx >= endIdx {
|
if startIdx >= endIdx {
|
||||||
startIdx = 0
|
startIdx = 0
|
||||||
@ -2313,7 +2299,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
receivedCount := 0
|
receivedCount := 0
|
||||||
timeout := time.After(500 * time.Millisecond)
|
timeout := time.After(500 * time.Millisecond)
|
||||||
|
|
||||||
for i := 0; i < len(allNodes); i++ {
|
for i := range allNodes {
|
||||||
select {
|
select {
|
||||||
case update := <-newChannels[i]:
|
case update := <-newChannels[i]:
|
||||||
if update != nil {
|
if update != nil {
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package v2
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
@ -178,11 +179,8 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||||||
for _, ips := range resolvedSrcIPs {
|
for _, ips := range resolvedSrcIPs {
|
||||||
for _, n := range sameUserNodes {
|
for _, n := range sameUserNodes {
|
||||||
// Check if any of this node's IPs are in the source set
|
// Check if any of this node's IPs are in the source set
|
||||||
for _, nodeIP := range n.IPs() {
|
if slices.ContainsFunc(n.IPs(), ips.Contains) {
|
||||||
if ips.Contains(nodeIP) {
|
|
||||||
n.AppendToIPSet(&srcIPs)
|
n.AppendToIPSet(&srcIPs)
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -375,11 +373,8 @@ func (pol *Policy) compileSSHPolicy(
|
|||||||
var filteredSrcIPs netipx.IPSetBuilder
|
var filteredSrcIPs netipx.IPSetBuilder
|
||||||
for _, n := range sameUserNodes {
|
for _, n := range sameUserNodes {
|
||||||
// Check if any of this node's IPs are in the source set
|
// Check if any of this node's IPs are in the source set
|
||||||
for _, nodeIP := range n.IPs() {
|
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||||
if srcIPs.Contains(nodeIP) {
|
n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next
|
||||||
n.AppendToIPSet(&filteredSrcIPs)
|
|
||||||
break // Found this node, move to next
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package v2
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -906,14 +907,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, expectedIP := range expectedDestIPs {
|
for _, expectedIP := range expectedDestIPs {
|
||||||
found := false
|
found := slices.Contains(actualDestIPs, expectedIP)
|
||||||
|
|
||||||
for _, actualIP := range actualDestIPs {
|
|
||||||
if actualIP == expectedIP {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
t.Errorf("expected destination IP %s to be included, got: %v", expectedIP, actualDestIPs)
|
t.Errorf("expected destination IP %s to be included, got: %v", expectedIP, actualDestIPs)
|
||||||
|
|||||||
@ -1007,7 +1007,7 @@ func (g Groups) Contains(group *Group) error {
|
|||||||
// with "group:". If any group name is invalid, an error is returned.
|
// with "group:". If any group name is invalid, an error is returned.
|
||||||
func (g *Groups) UnmarshalJSON(b []byte) error {
|
func (g *Groups) UnmarshalJSON(b []byte) error {
|
||||||
// First unmarshal as a generic map to validate group names first
|
// First unmarshal as a generic map to validate group names first
|
||||||
var rawMap map[string]interface{}
|
var rawMap map[string]any
|
||||||
if err := json.Unmarshal(b, &rawMap); err != nil {
|
if err := json.Unmarshal(b, &rawMap); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1024,7 +1024,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||||||
rawGroups := make(map[string][]string)
|
rawGroups := make(map[string][]string)
|
||||||
for key, value := range rawMap {
|
for key, value := range rawMap {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case []interface{}:
|
case []any:
|
||||||
// Convert []interface{} to []string
|
// Convert []interface{} to []string
|
||||||
var stringSlice []string
|
var stringSlice []string
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
|
|||||||
@ -39,9 +39,10 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var portRanges []tailcfg.PortRange
|
var portRanges []tailcfg.PortRange
|
||||||
parts := strings.Split(portDef, ",")
|
|
||||||
|
|
||||||
for _, part := range parts {
|
parts := strings.SplitSeq(portDef, ",")
|
||||||
|
|
||||||
|
for part := range parts {
|
||||||
if strings.Contains(part, "-") {
|
if strings.Contains(part, "-") {
|
||||||
rangeParts := strings.Split(part, "-")
|
rangeParts := strings.Split(part, "-")
|
||||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||||
|
|||||||
@ -200,9 +200,9 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DebugRegistrationCache returns debug information about the registration cache.
|
// DebugRegistrationCache returns debug information about the registration cache.
|
||||||
func (s *State) DebugRegistrationCache() map[string]interface{} {
|
func (s *State) DebugRegistrationCache() map[string]any {
|
||||||
// The cache doesn't expose internal statistics, so we provide basic info
|
// The cache doesn't expose internal statistics, so we provide basic info
|
||||||
result := map[string]interface{}{
|
result := map[string]any{
|
||||||
"type": "zcache",
|
"type": "zcache",
|
||||||
"expiration": registerCacheExpiration.String(),
|
"expiration": registerCacheExpiration.String(),
|
||||||
"cleanup": registerCacheCleanup.String(),
|
"cleanup": registerCacheCleanup.String(),
|
||||||
|
|||||||
@ -872,7 +872,7 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) {
|
|||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
results := make(chan bool, concurrentOps)
|
results := make(chan bool, concurrentOps)
|
||||||
for i := 0; i < concurrentOps; i++ {
|
for i := range concurrentOps {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(nodeID int) {
|
go func(nodeID int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -904,7 +904,7 @@ func TestNodeStoreBatchingEfficiency(t *testing.T) {
|
|||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
results := make(chan bool, ops)
|
results := make(chan bool, ops)
|
||||||
for i := 0; i < ops; i++ {
|
for i := range ops {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(nodeID int) {
|
go func(nodeID int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -941,11 +941,12 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(gid int) {
|
go func(gid int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := 0; j < opsPerGoroutine; j++ {
|
|
||||||
|
for j := range opsPerGoroutine {
|
||||||
switch j % 3 {
|
switch j % 3 {
|
||||||
case 0:
|
case 0:
|
||||||
resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
|
resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
|
||||||
@ -993,7 +994,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||||||
afterStartGoroutines := runtime.NumGoroutine()
|
afterStartGoroutines := runtime.NumGoroutine()
|
||||||
|
|
||||||
const ops = 100
|
const ops = 100
|
||||||
for i := 0; i < ops; i++ {
|
for i := range ops {
|
||||||
nodeID := types.NodeID(i + 1)
|
nodeID := types.NodeID(i + 1)
|
||||||
node := createConcurrentTestNode(nodeID, "cleanup-node")
|
node := createConcurrentTestNode(nodeID, "cleanup-node")
|
||||||
resultNode := store.PutNode(node)
|
resultNode := store.PutNode(node)
|
||||||
@ -1100,7 +1101,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||||||
|
|
||||||
// --- Edge case: update non-existent node ---
|
// --- Edge case: update non-existent node ---
|
||||||
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||||
store.Start()
|
store.Start()
|
||||||
nonExistentID := types.NodeID(999 + i)
|
nonExistentID := types.NodeID(999 + i)
|
||||||
@ -1124,8 +1125,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) {
|
|||||||
store.Start()
|
store.Start()
|
||||||
defer store.Stop()
|
defer store.Stop()
|
||||||
|
|
||||||
b.ResetTimer()
|
for i := 0; b.Loop(); i++ {
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
nodeID := types.NodeID(i + 1)
|
nodeID := types.NodeID(i + 1)
|
||||||
node := createConcurrentTestNode(nodeID, "bench-node")
|
node := createConcurrentTestNode(nodeID, "bench-node")
|
||||||
store.PutNode(node)
|
store.PutNode(node)
|
||||||
|
|||||||
@ -220,10 +220,12 @@ func DefaultBatcherWorkers() int {
|
|||||||
// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
|
// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
|
||||||
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
||||||
func DefaultBatcherWorkersFor(cpuCount int) int {
|
func DefaultBatcherWorkersFor(cpuCount int) int {
|
||||||
defaultWorkers := (cpuCount * 3) / 4
|
const (
|
||||||
if defaultWorkers < 1 {
|
workerNumerator = 3
|
||||||
defaultWorkers = 1
|
workerDenominator = 4
|
||||||
}
|
)
|
||||||
|
|
||||||
|
defaultWorkers := max((cpuCount*workerNumerator)/workerDenominator, 1)
|
||||||
|
|
||||||
return defaultWorkers
|
return defaultWorkers
|
||||||
}
|
}
|
||||||
|
|||||||
@ -49,22 +49,22 @@ func (l *DBLogWrapper) LogMode(gormLogger.LogLevel) gormLogger.Interface {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...interface{}) {
|
func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...any) {
|
||||||
l.Logger.Info().Msgf(msg, data...)
|
l.Logger.Info().Msgf(msg, data...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...interface{}) {
|
func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...any) {
|
||||||
l.Logger.Warn().Msgf(msg, data...)
|
l.Logger.Warn().Msgf(msg, data...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...interface{}) {
|
func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...any) {
|
||||||
l.Logger.Error().Msgf(msg, data...)
|
l.Logger.Error().Msgf(msg, data...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||||
elapsed := time.Since(begin)
|
elapsed := time.Since(begin)
|
||||||
sql, rowsAffected := fc()
|
sql, rowsAffected := fc()
|
||||||
fields := map[string]interface{}{
|
fields := map[string]any{
|
||||||
"duration": elapsed,
|
"duration": elapsed,
|
||||||
"sql": sql,
|
"sql": sql,
|
||||||
"rowsAffected": rowsAffected,
|
"rowsAffected": rowsAffected,
|
||||||
@ -83,7 +83,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
|
|||||||
l.Logger.Debug().Fields(fields).Msgf("")
|
l.Logger.Debug().Fields(fields).Msgf("")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) {
|
||||||
if l.ParameterizedQueries {
|
if l.ParameterizedQueries {
|
||||||
return sql, nil
|
return sql, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -98,7 +98,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||||||
|
|
||||||
// Should NOT contain user data after "Unauthorized"
|
// Should NOT contain user data after "Unauthorized"
|
||||||
// This is the security bypass - if users array is present, auth was bypassed
|
// This is the security bypass - if users array is present, auth was bypassed
|
||||||
var jsonCheck map[string]interface{}
|
var jsonCheck map[string]any
|
||||||
jsonErr := json.Unmarshal(body, &jsonCheck)
|
jsonErr := json.Unmarshal(body, &jsonCheck)
|
||||||
|
|
||||||
// If we can unmarshal JSON and it contains "users", that's the bypass
|
// If we can unmarshal JSON and it contains "users", that's the bypass
|
||||||
@ -278,8 +278,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||||||
var responseBody string
|
var responseBody string
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if strings.HasPrefix(line, "HTTP_CODE:") {
|
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||||
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
|
httpCode = after
|
||||||
} else {
|
} else {
|
||||||
responseBody += line
|
responseBody += line
|
||||||
}
|
}
|
||||||
@ -324,8 +324,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||||||
var responseBody string
|
var responseBody string
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if strings.HasPrefix(line, "HTTP_CODE:") {
|
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||||
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
|
httpCode = after
|
||||||
} else {
|
} else {
|
||||||
responseBody += line
|
responseBody += line
|
||||||
}
|
}
|
||||||
@ -359,8 +359,8 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||||||
var responseBody string
|
var responseBody string
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if strings.HasPrefix(line, "HTTP_CODE:") {
|
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||||
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
|
httpCode = after
|
||||||
} else {
|
} else {
|
||||||
responseBody += line
|
responseBody += line
|
||||||
}
|
}
|
||||||
|
|||||||
@ -56,13 +56,6 @@ type NodeSystemStatus struct {
|
|||||||
NodeStore bool
|
NodeStore bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// requireNotNil validates that an object is not nil and fails the test if it is.
|
|
||||||
// This helper provides consistent error messaging for nil checks in integration tests.
|
|
||||||
func requireNotNil(t *testing.T, object interface{}) {
|
|
||||||
t.Helper()
|
|
||||||
require.NotNil(t, object)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded.
|
// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded.
|
||||||
// Provides specific error context for headscale environment setup failures.
|
// Provides specific error context for headscale environment setup failures.
|
||||||
func requireNoErrHeadscaleEnv(t *testing.T, err error) {
|
func requireNoErrHeadscaleEnv(t *testing.T, err error) {
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@ -132,9 +133,7 @@ func WithCustomTLS(cert, key []byte) Option {
|
|||||||
// can be used to override Headscale configuration.
|
// can be used to override Headscale configuration.
|
||||||
func WithConfigEnv(configEnv map[string]string) Option {
|
func WithConfigEnv(configEnv map[string]string) Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
for key, value := range configEnv {
|
maps.Copy(hsic.env, configEnv)
|
||||||
hsic.env[key] = value
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -1159,12 +1160,10 @@ func (s *Scenario) FindTailscaleClientByIP(ip netip.Addr) (TailscaleClient, erro
|
|||||||
|
|
||||||
for _, client := range clients {
|
for _, client := range clients {
|
||||||
ips, _ := client.IPs()
|
ips, _ := client.IPs()
|
||||||
for _, ip2 := range ips {
|
if slices.Contains(ips, ip) {
|
||||||
if ip == ip2 {
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errNoClientFound
|
return nil, errNoClientFound
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -116,10 +117,7 @@ func calculateMinSupportedCapabilityVersion(versions map[string]tailcfg.Capabili
|
|||||||
sort.Strings(majorMinors)
|
sort.Strings(majorMinors)
|
||||||
|
|
||||||
// Take the latest 10 versions
|
// Take the latest 10 versions
|
||||||
supportedCount := supportedMajorMinorVersions
|
supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
|
||||||
if len(majorMinors) < supportedCount {
|
|
||||||
supportedCount = len(majorMinors)
|
|
||||||
}
|
|
||||||
|
|
||||||
if supportedCount == 0 {
|
if supportedCount == 0 {
|
||||||
return fallbackCapVer
|
return fallbackCapVer
|
||||||
@ -168,9 +166,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
|
|||||||
}
|
}
|
||||||
|
|
||||||
capsSorted := xmaps.Keys(capVarToTailscaleVer)
|
capsSorted := xmaps.Keys(capVarToTailscaleVer)
|
||||||
sort.Slice(capsSorted, func(i, j int) bool {
|
slices.Sort(capsSorted)
|
||||||
return capsSorted[i] < capsSorted[j]
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, capVer := range capsSorted {
|
for _, capVer := range capsSorted {
|
||||||
fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
|
fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
|
||||||
@ -223,10 +219,7 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
|
|||||||
sort.Strings(majorMinors)
|
sort.Strings(majorMinors)
|
||||||
|
|
||||||
// Take latest 10
|
// Take latest 10
|
||||||
supportedCount := supportedMajorMinorVersions
|
supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
|
||||||
if len(majorMinors) < supportedCount {
|
|
||||||
supportedCount = len(majorMinors)
|
|
||||||
}
|
|
||||||
|
|
||||||
latest10 := majorMinors[len(majorMinors)-supportedCount:]
|
latest10 := majorMinors[len(majorMinors)-supportedCount:]
|
||||||
latest3 := majorMinors[len(majorMinors)-3:]
|
latest3 := majorMinors[len(majorMinors)-3:]
|
||||||
@ -308,9 +301,7 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
|
|||||||
|
|
||||||
// Add a few more test cases
|
// Add a few more test cases
|
||||||
capsSorted := xmaps.Keys(capVerToTailscaleVer)
|
capsSorted := xmaps.Keys(capVerToTailscaleVer)
|
||||||
sort.Slice(capsSorted, func(i, j int) bool {
|
slices.Sort(capsSorted)
|
||||||
return capsSorted[i] < capsSorted[j]
|
|
||||||
})
|
|
||||||
|
|
||||||
testCount := 0
|
testCount := 0
|
||||||
for _, capVer := range capsSorted {
|
for _, capVer := range capsSorted {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user