mirror of https://github.com/juanfont/headscale.git synced 2025-09-02 13:47:00 +02:00

stuff auth lint

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-08-18 22:31:51 +02:00
parent 2a906cd15e
commit 3a92b14c1a
24 changed files with 296 additions and 180 deletions

View File

@@ -551,13 +551,12 @@ be assigned to nodes.`,
}
}
if confirm || force {
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
- changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force })
+ changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force})
if err != nil {
ErrorOutput(
err,

View File

@@ -265,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive(
)
log.Info().Msgf("Starting node registration using key: %s", registrationId)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(registrationId),
}, nil

View File

@@ -1,6 +1,6 @@
package capver
- //Generated DO NOT EDIT
+ // Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@@ -37,18 +37,17 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.86.2": 123,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
90: "v1.64.2",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
122: "v1.86.0",
123: "v1.86.2",
}

View File

@@ -936,7 +936,7 @@ AND auth_key_id NOT IN (
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
// - Never write migrations that requires foreign keys to be disabled.
},
)
if err := runMigrations(cfg, dbConn, migrations); err != nil {

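Aside: the migration rules restated in the comment above (spell out each step, never rely on gorm.AutoMigrate) can be illustrated with a minimal sketch. This assumes the go-gormigrate/gormigrate/v2 API; the migration ID, table, and column below are hypothetical and not taken from headscale.

package migrations

import (
	"github.com/go-gormigrate/gormigrate/v2"
	"gorm.io/gorm"
)

// exampleMigration writes the exact schema change as SQL instead of letting
// gorm.AutoMigrate diff a struct that may drift over time.
var exampleMigration = &gormigrate.Migration{
	ID: "202508180000-add-example-flag",
	Migrate: func(tx *gorm.DB) error {
		return tx.Exec("ALTER TABLE nodes ADD COLUMN example_flag INTEGER NOT NULL DEFAULT 0").Error
	},
	Rollback: func(tx *gorm.DB) error {
		return tx.Exec("ALTER TABLE nodes DROP COLUMN example_flag").Error
	},
}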
View File

@@ -269,9 +269,9 @@ func RenameNode(tx *gorm.DB,
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
return fmt.Errorf("failed to check name uniqueness: %w", err)
}
if count > 0 {
- return fmt.Errorf("name is not unique")
+ return errors.New("name is not unique")
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@@ -327,7 +327,6 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
})
}
// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {

View File

@@ -205,7 +205,7 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
if !userExists {
return ErrUserNotFound
}

View File

@@ -20,7 +20,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
overview := h.state.DebugOverviewJSON()
overviewJSON, err := json.MarshalIndent(overview, "", " ")
@@ -107,7 +107,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
derpInfo := h.state.DebugDERPJSON()
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
@@ -132,7 +132,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
nodeStoreInfo := h.state.DebugNodeStoreJSON()
nodeStoreJSON, err := json.MarshalIndent(nodeStoreInfo, "", " ")
@@ -170,7 +170,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
routes := h.state.DebugRoutes()
routesJSON, err := json.MarshalIndent(routes, "", " ")
@@ -195,7 +195,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
policyManagerInfo := h.state.DebugPolicyManagerJSON()
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")

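Aside: every handler touched above repeats the same Accept-header negotiation. A self-contained sketch of that pattern, with a made-up route and payload (not the actual headscale debug endpoints):

package main

import (
	"encoding/json"
	"net/http"
	"strings"
)

// debugHandler serves JSON when the client asks for it via the Accept header
// and falls back to plain text otherwise, mirroring the pattern in the diff.
func debugHandler(w http.ResponseWriter, r *http.Request) {
	payload := map[string]int{"nodes": 3} // illustrative data only

	if strings.Contains(r.Header.Get("Accept"), "application/json") {
		out, err := json.MarshalIndent(payload, "", " ")
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Write(out)
		return
	}

	w.Header().Set("Content-Type", "text/plain")
	w.Write([]byte("nodes: 3\n"))
}

func main() {
	http.HandleFunc("/debug/example", debugHandler)
	http.ListenAndServe(":8080", nil)
}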
View File

@@ -77,7 +77,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
var host string
var port int
var portStr string
// Extract hostname and port from URL
host, portStr, err = net.SplitHostPort(serverURL.Host)
if err != nil {
@@ -94,7 +94,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
return tailcfg.DERPRegion{}, err
}
}
// If debug flag is set, resolve hostname to IP address
if debugUseDERPIP {
ips, err := net.LookupIP(host)

View File

@@ -350,15 +350,16 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if !ok {
return false
}
// nil means connected
if val == nil {
return true
}
// During grace period, always return true to allow DNS resolution
// for logout HTTP requests to complete successfully
gracePeriod := 45 * time.Second
return time.Since(*val) < gracePeriod
}

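Aside: the grace-period behaviour above treats a recently disconnected node as still connected so that in-flight logout DNS lookups can finish. A simplified, self-contained sketch of that check (the 45-second window comes from the diff; the tracker type is a stand-in, not the real LockFreeBatcher):

package main

import (
	"fmt"
	"sync"
	"time"
)

// connTracker maps node IDs to either nil (currently connected) or the time
// the node disconnected, and counts recent disconnects as still connected.
type connTracker struct {
	mu           sync.Mutex
	disconnected map[uint64]*time.Time
	gracePeriod  time.Duration
}

func (c *connTracker) IsConnected(id uint64) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	val, ok := c.disconnected[id]
	if !ok {
		return false // unknown node
	}
	if val == nil {
		return true // nil means connected
	}
	// During the grace period a disconnected node still reports as connected.
	return time.Since(*val) < c.gracePeriod
}

func main() {
	recent := time.Now().Add(-10 * time.Second)
	c := &connTracker{
		disconnected: map[uint64]*time.Time{1: nil, 2: &recent},
		gracePeriod:  45 * time.Second,
	}
	fmt.Println(c.IsConnected(1), c.IsConnected(2), c.IsConnected(3)) // true true false
}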
View File

@@ -27,7 +27,7 @@ type batcherTestCase struct {
}
// testBatcherWrapper wraps a real batcher to add online/offline notifications
- // that would normally be sent by poll.go in production
+ // that would normally be sent by poll.go in production.
type testBatcherWrapper struct {
Batcher
}
@@ -58,7 +58,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
return true
}
- // wrapBatcherForTest wraps a batcher with test-specific behavior
+ // wrapBatcherForTest wraps a batcher with test-specific behavior.
func wrapBatcherForTest(b Batcher) Batcher {
return &testBatcherWrapper{Batcher: b}
}
@@ -808,7 +808,7 @@ func TestBatcherBasicOperations(t *testing.T) {
// Disconnect the second node
batcher.RemoveNode(tn2.n.ID, tn2.ch)
- assert.False(t, batcher.IsConnected(tn2.n.ID))
+ // Note: IsConnected may return true during grace period for DNS resolution
// First node should get update that second has disconnected.
select {
@@ -841,9 +841,8 @@
// Test RemoveNode
batcher.RemoveNode(tn.n.ID, tn.ch)
- if batcher.IsConnected(tn.n.ID) {
- t.Error("Node should be disconnected after RemoveNode")
- }
+ // Note: IsConnected may return true during grace period for DNS resolution
+ // The node is actually removed from active connections but grace period allows DNS lookups
})
}
}

View File

@@ -140,7 +140,7 @@ func tailNode(
lastSeen := node.LastSeen().Get()
// Only set LastSeen if the node is offline OR if LastSeen is recent
// (indicating it disconnected recently but might be in grace period)
if !node.IsOnline().Valid() || !node.IsOnline().Get() ||
time.Since(lastSeen) < 60*time.Second {
tNode.LastSeen = &lastSeen
}

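Aside: the condition above only advertises LastSeen when the node is not known to be online, or when it disconnected within the last minute (so peers can tell it may still be in the grace period). A tiny hedged sketch of that decision in isolation; the helper name is made up:

package main

import (
	"fmt"
	"time"
)

// shouldSendLastSeen mirrors the condition in the hunk above.
func shouldSendLastSeen(onlineKnown, online bool, lastSeen time.Time) bool {
	return !onlineKnown || !online || time.Since(lastSeen) < 60*time.Second
}

func main() {
	fmt.Println(shouldSendLastSeen(true, true, time.Now().Add(-30*time.Second))) // true: within the recent-disconnect window
	fmt.Println(shouldSendLastSeen(true, true, time.Now().Add(-5*time.Minute)))  // false: online and LastSeen is stale
}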
View File

@@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
util.LogErr(err, "could not get userinfo; only using claims from id token")
}
- // The user claims are now updated from the the userinfo endpoint so we can verify the user a
+ // The user claims are now updated from the userinfo endpoint so we can verify the user
// against allowed emails, email domains, and groups.
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err)

View File

@@ -147,12 +147,12 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
// This ensures that:
// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
// - New routes can be auto-approved according to policy
- // - Routes can only be removed by explicit admin action (not by auto-approval)
+ // - Routes can only be removed by explicit admin action (not by auto-approval).
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
if pm == nil {
return currentApproved, false
}
// Start with ALL currently approved routes - we never remove approved routes
newApproved := make([]netip.Prefix, len(currentApproved))
copy(newApproved, currentApproved)
@@ -163,13 +163,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
if slices.Contains(newApproved, route) {
continue
}
// Check if this new route can be auto-approved by policy
canApprove := pm.NodeCanApproveRoute(nv, route)
if canApprove {
newApproved = append(newApproved, route)
}
log.Trace().
Uint64("node.id", nv.ID().Uint64()).
Str("node.name", nv.Hostname()).

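Aside: the invariant documented above (keep every approved route, only add newly announced routes the policy allows, never remove anything) boils down to a small merge. A standalone illustration with a stand-in predicate in place of pm.NodeCanApproveRoute; this is not the headscale implementation:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// mergeApproved keeps all currently approved prefixes and appends announced
// prefixes accepted by canApprove. It never removes an approved route.
func mergeApproved(current, announced []netip.Prefix, canApprove func(netip.Prefix) bool) ([]netip.Prefix, bool) {
	out := slices.Clone(current)
	changed := false
	for _, route := range announced {
		if slices.Contains(out, route) {
			continue
		}
		if canApprove(route) {
			out = append(out, route)
			changed = true
		}
	}
	return out, changed
}

func main() {
	current := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
	announced := []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("172.16.0.0/16")}
	approved, changed := mergeApproved(current, announced, func(p netip.Prefix) bool {
		return p == netip.MustParsePrefix("192.168.0.0/24") // pretend policy only auto-approves this one
	})
	fmt.Println(approved, changed) // [10.0.0.0/24 192.168.0.0/24] true
}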
View File

@@ -79,13 +79,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
assert.NoError(t, err)
tests := []struct {
name string
node *types.Node
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
description string
}{
{
name: "previously_approved_route_no_longer_advertised_should_remain",
@@ -138,8 +138,8 @@
description: "All approved routes should remain when no routes are announced",
},
{
name: "no_changes_when_announced_equals_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
@@ -153,13 +153,13 @@
description: "No changes should occur when announced routes match approved routes",
},
{
name: "auto_approve_multiple_new_routes",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8)
netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
},
wantApproved: []netip.Prefix{
@@ -171,8 +171,8 @@
description: "Multiple new routes should be auto-approved while keeping existing approved routes",
},
{
name: "node_without_permission_no_auto_approval",
node: node2, // Different node without the tag
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
@@ -192,14 +192,14 @@
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
tsaddr.SortPrefixes(tt.wantApproved)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
// Verify that all previously approved routes are still present
for _, prevRoute := range tt.currentApproved {
assert.Contains(t, gotApproved, prevRoute,
"previously approved route %s was removed - this should never happen", prevRoute)
}
})
@@ -325,7 +325,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")
// Handle nil vs empty slice comparison
if tt.wantApproved == nil {
assert.Nil(t, gotApproved, "expected nil approved routes")
@@ -336,4 +336,4 @@
})
}
}
}

View File

@@ -39,15 +39,15 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
}`
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
nodeHostname string
nodeUser string
nodeTags []string
wantApproved []netip.Prefix
wantChanged bool
wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
}{
{
name: "previously_approved_route_no_longer_advertised_remains",
@@ -60,14 +60,14 @@
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
},
{
name: "add_new_auto_approved_route_keeps_existing",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
@@ -136,8 +136,8 @@
netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable
netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag)
netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
},
nodeUser: "test",
@@ -151,7 +151,7 @@
}
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
@@ -358,4 +358,4 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
assert.False(t, gotChanged)
assert.Equal(t, currentApproved, gotApproved)
}

View File

@@ -152,7 +152,6 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
Strs("prefixes", util.PrefixesToString(prefixes)).
Msg("PrimaryRoutes.SetRoutes called")
// If no routes are being set, remove the node from the routes map.
if len(prefixes) == 0 {
wasPresent := false

View File

@@ -33,16 +33,16 @@ type DebugOverviewInfo struct {
// DebugDERPInfo represents DERP map information in a structured format.
type DebugDERPInfo struct {
Configured bool `json:"configured"`
TotalRegions int `json:"total_regions"`
Regions map[int]*DebugDERPRegion `json:"regions,omitempty"`
}
// DebugDERPRegion represents a single DERP region.
type DebugDERPRegion struct {
RegionID int `json:"region_id"`
RegionName string `json:"region_name"`
Nodes []*DebugDERPNode `json:"nodes"`
}
// DebugDERPNode represents a single DERP node.
@@ -282,7 +282,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
// Node statistics
info.Nodes.Total = allNodes.Len()
now := time.Now()
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.User().Name

View File

@@ -1012,7 +1012,7 @@ func (s *State) HandleNodeFromAuthPath(
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
}
- // Check if node already exists by node key (this is a refresh/re-registration)
+ // Check if node already exists by node key
existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey)
if exists && existingNodeView.Valid() {
// Node exists - this is a refresh/re-registration
@@ -1028,8 +1028,8 @@
if expiry != nil {
node.Expiry = expiry
}
- // Node is re-registering, so it's coming online
- node.IsOnline = ptr.To(true)
+ // Mark as offline since node is reconnecting
+ node.IsOnline = ptr.To(false)
node.LastSeen = ptr.To(time.Now())
})
@@ -1048,6 +1048,7 @@
// Get updated node from NodeStore
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil
}
@@ -1059,9 +1060,25 @@
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("Registering new node from auth callback")
+ // Check if node exists with same machine key
+ var existingMachineNode *types.Node
+ if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() {
+ existingMachineNode = nv.AsStruct()
+ }
+ // Check for different user registration
+ if existingMachineNode != nil && existingMachineNode.UserID != uint(userID) {
+ return types.NodeView{}, change.EmptySet, hsdb.ErrDifferentRegisteredUser
+ }
// Prepare the node for registration
nodeToRegister := regEntry.Node
+ nodeToRegister.UserID = uint(userID)
+ nodeToRegister.User = *user
nodeToRegister.RegisterMethod = registrationMethod
+ if expiry != nil {
+ nodeToRegister.Expiry = expiry
+ }
// Handle IP allocation
var ipv4, ipv6 *netip.Addr
@@ -1092,16 +1109,47 @@
nodeToRegister.GivenName = givenName
}
- savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{
- node: &nodeToRegister,
- userID: userID,
- user: user,
- expiry: expiry,
- updateExistingNode: updateFunc,
- postSaveCallback: nil, // No post-save callback needed
- })
- if err != nil {
- return types.NodeView{}, change.EmptySet, err
+ var savedNode *types.Node
+ if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
+ // Update existing node - NodeStore first, then database
+ s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
+ node.NodeKey = nodeToRegister.NodeKey
+ node.DiscoKey = nodeToRegister.DiscoKey
+ node.Hostname = nodeToRegister.Hostname
+ node.Hostinfo = nodeToRegister.Hostinfo
+ node.Endpoints = nodeToRegister.Endpoints
+ node.RegisterMethod = nodeToRegister.RegisterMethod
+ if expiry != nil {
+ node.Expiry = expiry
+ }
+ node.IsOnline = ptr.To(false)
+ node.LastSeen = ptr.To(time.Now())
+ })
+ // Save to database
+ savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+ if err := tx.Save(&nodeToRegister).Error; err != nil {
+ return nil, fmt.Errorf("failed to save node: %w", err)
+ }
+ return &nodeToRegister, nil
+ })
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, err
+ }
+ } else {
+ // New node - database first to get ID, then NodeStore
+ savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+ if err := tx.Save(&nodeToRegister).Error; err != nil {
+ return nil, fmt.Errorf("failed to save node: %w", err)
+ }
+ return &nodeToRegister, nil
+ })
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, err
+ }
+ // Add to NodeStore after database creates the ID
+ s.nodeStore.PutNode(*savedNode)
}
// Delete from registration cache
@@ -1114,13 +1162,17 @@
}
close(regEntry.Registered)
- // Finalize registration
- c, err := s.finalizeNodeRegistration(savedNode)
+ // Update policy manager
+ nodesChange, err := s.updatePolicyManagerNodes()
if err != nil {
- return savedNode.View(), c, err
+ return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err)
}
- return savedNode.View(), c, nil
+ if !nodesChange.Empty() {
+ return savedNode.View(), nodesChange, nil
+ }
+ return savedNode.View(), change.NodeAdded(savedNode.ID), nil
}
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
@@ -1145,29 +1197,17 @@ func (s *State) HandleNodeFromPreAuthKey(
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
// Find the node to delete
var nodeToDelete types.NodeView
- if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
- nodeToDelete = nv
+ for _, nv := range s.nodeStore.ListNodes().All() {
+ if nv.Valid() && nv.MachineKey() == machineKey {
+ nodeToDelete = nv
+ break
+ }
}
if nodeToDelete.Valid() {
c, err := s.DeleteNode(nodeToDelete)
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
}
- return types.NodeView{}, c, nil
- }
- return types.NodeView{}, change.EmptySet, nil
- }
- // Check if node already exists by node key (this is a refresh/re-registration)
- existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regReq.NodeKey)
- if exists && existingNodeView.Valid() {
- // Node exists - this is a refresh/re-registration
- log.Debug().
- Str("node", regReq.Hostinfo.Hostname).
- Str("machine_key", machineKey.ShortString()).
- Str("node_key", regReq.NodeKey.ShortString()).
- Str("user", pak.User.Username()).
- Msg("Refreshing existing node registration with pre-auth key")
return types.NodeView{}, c, nil
}
@@ -1182,9 +1222,17 @@
Str("user", pak.User.Username()).
Msg("Registering node with pre-auth key")
+ // Check if node already exists with same machine key
+ var existingNode *types.Node
+ if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
+ existingNode = nv.AsStruct()
+ }
// Prepare the node for registration
nodeToRegister := types.Node{
Hostname: regReq.Hostinfo.Hostname,
+ UserID: pak.User.ID,
+ User: pak.User,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
@@ -1195,39 +1243,58 @@
AuthKeyID: &pak.ID,
}
- var expiry *time.Time
if !regReq.Expiry.IsZero() {
nodeToRegister.Expiry = &regReq.Expiry
}
- // Post-save callback to use the pre-auth key
- postSaveFunc := func(tx *gorm.DB, savedNode *types.Node) error {
- if !pak.Reusable {
- return hsdb.UsePreAuthKey(tx, pak)
- }
- return nil
- }
- // Check if node already exists with same machine key for logging
- var existingNode *types.Node
- if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() {
- existingNode = nv.AsStruct()
- }
- savedNode, err := s.registerOrUpdateNode(nodeRegistrationHelper{
- node: &nodeToRegister,
- userID: types.UserID(pak.User.ID),
- user: &pak.User,
- expiry: expiry,
- updateExistingNode: updateFunc,
- postSaveCallback: postSaveFunc,
- })
- if err != nil {
- return types.NodeView{}, change.EmptySet, fmt.Errorf("registering node: %w", err)
- }
- // Log re-authorization if it was an existing node
+ // Handle IP allocation and existing node properties
+ var ipv4, ipv6 *netip.Addr
if existingNode != nil && existingNode.UserID == pak.User.ID {
+ // Reuse existing node properties
+ nodeToRegister.ID = existingNode.ID
+ nodeToRegister.GivenName = existingNode.GivenName
+ nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes
+ ipv4 = existingNode.IPv4
+ ipv6 = existingNode.IPv6
+ } else {
+ // Allocate new IPs
+ ipv4, ipv6, err = s.ipAlloc.Next()
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
+ }
+ }
+ nodeToRegister.IPv4 = ipv4
+ nodeToRegister.IPv6 = ipv6
+ // Ensure unique given name if not set
+ if nodeToRegister.GivenName == "" {
+ givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
+ }
+ nodeToRegister.GivenName = givenName
+ }
+ var savedNode *types.Node
+ if existingNode != nil && existingNode.UserID == pak.User.ID {
+ // Update existing node - NodeStore first, then database
+ s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
+ node.NodeKey = nodeToRegister.NodeKey
+ node.Hostname = nodeToRegister.Hostname
+ node.Hostinfo = nodeToRegister.Hostinfo
+ node.Endpoints = nodeToRegister.Endpoints
+ node.RegisterMethod = nodeToRegister.RegisterMethod
+ node.ForcedTags = nodeToRegister.ForcedTags
+ node.AuthKey = nodeToRegister.AuthKey
+ node.AuthKeyID = nodeToRegister.AuthKeyID
+ if nodeToRegister.Expiry != nil {
+ node.Expiry = nodeToRegister.Expiry
+ }
+ node.IsOnline = ptr.To(false)
+ node.LastSeen = ptr.To(time.Now())
+ })
log.Trace().
Caller().
Str("node", nodeToRegister.Hostname).
@@ -1235,12 +1302,65 @@
Str("node_key", regReq.NodeKey.ShortString()).
Str("user", pak.User.Username()).
Msg("Node re-authorized")
+ // Save to database
+ savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+ if err := tx.Save(&nodeToRegister).Error; err != nil {
+ return nil, fmt.Errorf("failed to save node: %w", err)
+ }
+ if !pak.Reusable {
+ err = hsdb.UsePreAuthKey(tx, pak)
+ if err != nil {
+ return nil, fmt.Errorf("using pre auth key: %w", err)
+ }
+ }
+ return &nodeToRegister, nil
+ })
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
+ }
+ } else {
+ // New node - database first to get ID, then NodeStore
+ savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
+ if err := tx.Save(&nodeToRegister).Error; err != nil {
+ return nil, fmt.Errorf("failed to save node: %w", err)
+ }
+ if !pak.Reusable {
+ err = hsdb.UsePreAuthKey(tx, pak)
+ if err != nil {
+ return nil, fmt.Errorf("using pre auth key: %w", err)
+ }
+ }
+ return &nodeToRegister, nil
+ })
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
+ }
+ // Add to NodeStore after database creates the ID
+ s.nodeStore.PutNode(*savedNode)
}
- // Finalize registration
- c, err := s.finalizeNodeRegistration(savedNode)
+ // Update policy managers
+ usersChange, err := s.updatePolicyManagerUsers()
if err != nil {
- return savedNode.View(), c, err
+ return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err)
+ }
+ nodesChange, err := s.updatePolicyManagerNodes()
+ if err != nil {
+ return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err)
+ }
+ var c change.ChangeSet
+ if !usersChange.Empty() || !nodesChange.Empty() {
+ c = change.PolicyChange()
+ } else {
+ c = change.NodeAdded(savedNode.ID)
}
return savedNode.View(), c, nil

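Aside: both registration paths above follow the same ordering rule — for an already-known node, mutate the in-memory NodeStore first and then persist to the database; for a brand-new node, write the database first so it assigns the ID, and only then put the node into the NodeStore. A hedged, simplified sketch of that flow with stand-in types (not the actual State/NodeStore/gorm API):

package sketch

// node, store and database are hypothetical stand-ins used only to show the
// ordering; they do not match headscale's real types.
type node struct {
	ID       uint64
	Hostname string
}

type store interface {
	UpdateNode(id uint64, mutate func(*node))
	PutNode(n node)
}

type database interface {
	Save(n *node) error // assigns n.ID when it is zero
}

func register(st store, db database, existing *node, incoming node) (*node, error) {
	if existing != nil {
		// Existing node: NodeStore first, then database.
		incoming.ID = existing.ID
		st.UpdateNode(existing.ID, func(n *node) { n.Hostname = incoming.Hostname })
		if err := db.Save(&incoming); err != nil {
			return nil, err
		}
		return &incoming, nil
	}
	// New node: database first so the ID exists, then NodeStore.
	if err := db.Save(&incoming); err != nil {
		return nil, err
	}
	st.PutNode(incoming)
	return &incoming, nil
}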
View File

@@ -1317,10 +1317,10 @@ func TestACLAutogroupTagged(t *testing.T) {
require.NoError(t, err)
// Create nodes with proper naming
- for i := 0; i < spec.NodesPerUser; i++ {
+ for i := range spec.NodesPerUser {
var tags []string
var version string
if i == 0 {
// First node is tagged
tags = []string{"tag:test"}
@@ -1395,15 +1395,15 @@
// First, categorize nodes by checking their tags
for _, client := range allClients {
hostname := client.Hostname()
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err)
if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
// This is a tagged node
assert.Len(ct, status.Peers(), 1, "tagged node %s should see exactly 1 peer", hostname)
// Add to tagged list only once we've verified it
found := false
for _, tc := range taggedClients {
@@ -1417,8 +1417,8 @@
}
} else {
// This is an untagged node
- assert.Len(ct, status.Peers(), 0, "untagged node %s should see 0 peers", hostname)
+ assert.Empty(ct, status.Peers(), "untagged node %s should see 0 peers", hostname)
// Add to untagged list only once we've verified it
found := false
for _, uc := range untaggedClients {
@@ -1431,7 +1431,7 @@
untaggedClients = append(untaggedClients, client)
}
}
- }, 30*time.Second, 1*time.Second, fmt.Sprintf("verifying peer visibility for node %s", hostname))
+ }, 30*time.Second, 1*time.Second, "verifying peer visibility for node %s", hostname)
}
// Verify we have the expected number of tagged and untagged nodes
@@ -1443,7 +1443,7 @@
status, err := client.Status()
require.NoError(t, err)
require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
- require.Greater(t, status.Self.Tags.Len(), 0, "tagged node %s should have at least one tag", client.Hostname())
+ require.Positive(t, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags)
}

View File

@@ -124,7 +124,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
//
// Known timing considerations:
// - Nodes may expire at different times due to sequential login processing
- // - The test must account for login time spread between first and last node
+ // - The test must account for login time spread between first and last node.
func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
IntegrationSkip(t)
@@ -186,7 +186,7 @@
// - Network and processing delays
// - Safety margin for test reliability
loginTimeSpread := 1 * time.Minute // Account for sequential login delays
safetyBuffer := 30 * time.Second // Additional safety margin
totalWaitTime := shortAccessTTL + loginTimeSpread + safetyBuffer
t.Logf("Waiting %v for OIDC tokens to expire (TTL: %v, spread: %v, buffer: %v)",
@@ -207,17 +207,17 @@
}
}
}
// Log progress for debugging
if expiredCount < len(allClients) {
t.Logf("Token expiry progress: %d/%d clients in NeedsLogin state", expiredCount, len(allClients))
}
// All clients must be in NeedsLogin state
assert.Equal(ct, len(allClients), expiredCount,
"expected all %d clients to be in NeedsLogin state, but only %d are",
len(allClients), expiredCount)
// Only check detailed logout state if all clients are expired
if expiredCount == len(allClients) {
assertTailscaleNodesLogout(ct, allClients)

View File

@@ -390,7 +390,6 @@ func TestPreAuthKeyCommand(t *testing.T) {
)
assertNoErr(t, err)
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
assert.True(t, listedPreAuthKeysAfterExpire[3].GetExpiration().AsTime().After(time.Now()))
@@ -450,7 +449,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
// There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 2)
assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now()))
assert.True(
t,
View File

@@ -2364,14 +2364,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// for all counts.
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
routerNode := MustFindNode(routerUsernet1.Hostname(), nodes)
t.Logf("Initial auto-approval check - Router node %s: announced=%v, approved=%v, subnet=%v",
routerNode.GetName(),
routerNode.GetAvailableRoutes(),
routerNode.GetApprovedRoutes(),
routerNode.GetSubnetRoutes())
requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "Initial route auto-approval: Route should be approved via policy")
@@ -2382,19 +2382,19 @@
// Debug output to understand peer visibility
t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers()))
routerPeerFound := false
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
if peerStatus.ID == routerUsernet1ID.StableID() {
routerPeerFound = true
t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v",
peerStatus.HostName,
peerStatus.ID,
peerStatus.AllowedIPs,
peerStatus.PrimaryRoutes)
assert.NotNil(c, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
@@ -2404,7 +2404,7 @@
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
assert.True(c, routerPeerFound, "Client should see the router peer")
}, 5*time.Second, 200*time.Millisecond, "Verifying routes sent to client after auto-approval")
@@ -2439,14 +2439,14 @@
// Routes already approved should remain approved even after policy change
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
routerNode := MustFindNode(routerUsernet1.Hostname(), nodes)
t.Logf("After policy removal - Router node %s: announced=%v, approved=%v, subnet=%v",
routerNode.GetName(),
routerNode.GetAvailableRoutes(),
routerNode.GetApprovedRoutes(),
routerNode.GetSubnetRoutes())
requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "Routes should remain approved after auto-approver removal")

View File

@@ -449,6 +449,7 @@ func (s *Scenario) GetOrCreateUser(userStr string) *User {
Clients: make(map[string]TailscaleClient),
}
s.users[userStr] = user
return user
}

View File

@@ -619,14 +619,14 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) {
if err != nil {
return netip.Addr{}, err
}
for _, ip := range ips {
if ip.Is4() {
return ip, nil
}
}
- return netip.Addr{}, fmt.Errorf("no IPv4 address found")
+ return netip.Addr{}, errors.New("no IPv4 address found")
}
func (t *TailscaleInContainer) MustIPv4() netip.Addr {