diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml index d8f8a0b7..70f1a146 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yaml +++ b/.github/ISSUE_TEMPLATE/feature_request.yaml @@ -16,15 +16,13 @@ body: - type: textarea attributes: label: Description - description: - A clear and precise description of what new or changed feature you want. + description: A clear and precise description of what new or changed feature you want. validations: required: true - type: checkboxes attributes: label: Contribution - description: - Are you willing to contribute to the implementation of this feature? + description: Are you willing to contribute to the implementation of this feature? options: - label: I can write the design doc for this feature required: false @@ -33,7 +31,6 @@ body: - type: textarea attributes: label: How can it be implemented? - description: - Free text for your ideas on how this feature could be implemented. + description: Free text for your ideas on how this feature could be implemented. validations: required: false diff --git a/hscontrol/app.go b/hscontrol/app.go index 47b38c83..6f669d4a 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -146,12 +146,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { policyChanged, err := app.state.DeleteNode(node) if err != nil { - log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node") + log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed") return } app.Change(policyChanged) - log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node") + log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached") }) app.ephemeralGC = ephemeralGC @@ -384,53 +384,49 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler log.Trace(). Caller(). Str("client_address", req.RemoteAddr). - Msg(`missing "Bearer " prefix in "Authorization" header`) - writer.WriteHeader(http.StatusUnauthorized) - _, err := writer.Write([]byte("Unauthorized")) + Msg("HTTP authentication invoked") + + authHeader := req.Header.Get("Authorization") + + if !strings.HasPrefix(authHeader, AuthPrefix) { + log.Error(). + Caller(). + Str("client_address", req.RemoteAddr). + Msg(`missing "Bearer " prefix in "Authorization" header`) + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + return err + } + + valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if err != nil { log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Str("client_address", req.RemoteAddr). + Msg("failed to validate token") + + writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Unauthorized")) + return err } - return - } + if !valid { + log.Info(). + Str("client_address", req.RemoteAddr). + Msg("invalid token") - valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) - if err != nil { + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + return err + } + + return nil + }(); err != nil { log.Error(). Caller(). Err(err). - Str("client_address", req.RemoteAddr). - Msg("failed to validate token") - - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - log.Error(). 
- Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - - if !valid { - log.Info(). - Str("client_address", req.RemoteAddr). - Msg("invalid token") - - writer.WriteHeader(http.StatusUnauthorized) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + Msg("Failed to write HTTP response") return } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index d2f39ff0..e18f2e5d 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -260,7 +260,7 @@ func NewHeadscaleDatabase( log.Error().Err(err).Msg("Error creating route") } else { log.Info(). - Uint64("node_id", route.NodeID). + Uint64("node.id", route.NodeID). Str("prefix", prefix.String()). Msg("Route migrated") } @@ -870,23 +870,23 @@ AND auth_key_id NOT IN ( // Copy data directly using SQL dataCopySQL := []string{ `INSERT INTO users (id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at) - SELECT id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at + SELECT id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at FROM users_old`, `INSERT INTO pre_auth_keys (id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at) - SELECT id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at + SELECT id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at FROM pre_auth_keys_old`, `INSERT INTO api_keys (id, prefix, hash, expiration, last_seen, created_at) - SELECT id, prefix, hash, expiration, last_seen, created_at + SELECT id, prefix, hash, expiration, last_seen, created_at FROM api_keys_old`, `INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at) - SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at + SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at FROM nodes_old`, `INSERT INTO policies (id, data, created_at, updated_at, deleted_at) - SELECT id, data, created_at, updated_at, deleted_at + SELECT id, data, created_at, updated_at, deleted_at FROM policies_old`, } @@ -1131,7 +1131,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig } for _, migrationID := range migrationIDs { - log.Trace().Str("migration_id", migrationID).Msg("Running migration") + log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration") needsFKDisabled := migrationsRequiringFKDisabled[migrationID] if needsFKDisabled { diff --git a/hscontrol/db/ip.go b/hscontrol/db/ip.go index 63130c4c..3fddcfd2 100644 --- a/hscontrol/db/ip.go +++ b/hscontrol/db/ip.go @@ -275,7 +275,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { return errors.New("backfilling IPs: ip allocator was nil") } - log.Trace().Msgf("starting to backfill IPs") + log.Trace().Caller().Msgf("starting to backfill IPs") nodes, err := ListNodes(tx) if err != nil { @@ -283,7 +283,7 @@ 
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { } for _, node := range nodes { - log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill") + log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database") changed := false // IPv4 prefix is set, but node ip is missing, alloc diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 3531fc49..f899ddd3 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -34,9 +34,6 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") - ErrDifferentRegisteredUser = errors.New( - "node was previously registered with a different user", - ) ) // ListPeers returns peers of node, regardless of any Policy or if the node is expired. diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 32c837f1..629b7be1 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/arl/statsviz" + "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/prometheus/client_golang/prometheus/promhttp" "tailscale.com/tsweb" @@ -239,6 +240,34 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Write(resJSON) })) + // Batcher endpoint + debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + batcherInfo := h.debugBatcherJSON() + + batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ") + if err != nil { + httpError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(batcherJSON) + } else { + // Default to text/plain for backward compatibility + batcherInfo := h.debugBatcher() + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(batcherInfo)) + } + })) + err := statsviz.Register(debugMux) if err == nil { debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)") @@ -256,3 +285,124 @@ func (h *Headscale) debugHTTPServer() *http.Server { return debugHTTPServer } + +// debugBatcher returns debug information about the batcher's connected nodes. 
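+// Nodes are gathered from the LockFreeBatcher's Debug() data when available (with a ConnectedMap fallback), sorted by node ID, and rendered as one plain-text line per node followed by a connected/total summary.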
+func (h *Headscale) debugBatcher() string { + var sb strings.Builder + sb.WriteString("=== Batcher Connected Nodes ===\n\n") + + totalNodes := 0 + connectedCount := 0 + + // Collect nodes and sort them by ID + type nodeStatus struct { + id types.NodeID + connected bool + activeConnections int + } + + var nodes []nodeStatus + + // Try to get detailed debug info if we have a LockFreeBatcher + if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { + debugInfo := batcher.Debug() + for nodeID, info := range debugInfo { + nodes = append(nodes, nodeStatus{ + id: nodeID, + connected: info.Connected, + activeConnections: info.ActiveConnections, + }) + totalNodes++ + if info.Connected { + connectedCount++ + } + } + } else { + // Fallback to basic connection info + connectedMap := h.mapBatcher.ConnectedMap() + connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { + nodes = append(nodes, nodeStatus{ + id: nodeID, + connected: connected, + activeConnections: 0, + }) + totalNodes++ + if connected { + connectedCount++ + } + return true + }) + } + + // Sort by node ID + for i := 0; i < len(nodes); i++ { + for j := i + 1; j < len(nodes); j++ { + if nodes[i].id > nodes[j].id { + nodes[i], nodes[j] = nodes[j], nodes[i] + } + } + } + + // Output sorted nodes + for _, node := range nodes { + status := "disconnected" + if node.connected { + status = "connected" + } + + if node.activeConnections > 0 { + sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections)) + } else { + sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status)) + } + } + + sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes)) + + return sb.String() +} + +// DebugBatcherInfo represents batcher connection information in a structured format. +type DebugBatcherInfo struct { + ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"` // NodeID -> node connection info + TotalNodes int `json:"total_nodes"` +} + +// DebugBatcherNodeInfo represents connection information for a single node. +type DebugBatcherNodeInfo struct { + Connected bool `json:"connected"` + ActiveConnections int `json:"active_connections"` +} + +// debugBatcherJSON returns structured debug information about the batcher's connected nodes. +func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { + info := DebugBatcherInfo{ + ConnectedNodes: make(map[string]DebugBatcherNodeInfo), + TotalNodes: 0, + } + + // Try to get detailed debug info if we have a LockFreeBatcher + if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { + debugInfo := batcher.Debug() + for nodeID, debugData := range debugInfo { + info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ + Connected: debugData.Connected, + ActiveConnections: debugData.ActiveConnections, + } + info.TotalNodes++ + } + } else { + // Fallback to basic connection info + connectedMap := h.mapBatcher.ConnectedMap() + connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { + info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + info.TotalNodes++ + return true + }) + } + + return info +} diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index c679b3dc..da261304 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -161,7 +161,7 @@ func (d *DERPServer) DERPHandler( log.Error(). Caller(). Err(err). 
- Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -199,7 +199,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -229,7 +229,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) { log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -245,7 +245,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) { log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -284,7 +284,7 @@ func DERPProbeHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } } } @@ -330,7 +330,7 @@ func DERPBootstrapDNSHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } } } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 1b1a22e2..6663b44a 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -237,6 +237,7 @@ func (api headscaleV1APIServer) RegisterNode( request *v1.RegisterNodeRequest, ) (*v1.RegisterNodeResponse, error) { log.Trace(). + Caller(). Str("user", request.GetUser()). Str("registration_id", request.GetKey()). Msg("Registering node") @@ -525,7 +526,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs( ctx context.Context, request *v1.BackfillNodeIPsRequest, ) (*v1.BackfillNodeIPsResponse, error) { - log.Trace().Msg("Backfill called") + log.Trace().Caller().Msg("Backfill called") if !request.Confirmed { return nil, errors.New("not confirmed, aborting") @@ -709,6 +710,10 @@ func (api headscaleV1APIServer) SetPolicy( UpdatedAt: timestamppb.New(updated.UpdatedAt), } + log.Debug(). + Caller(). + Msg("gRPC SetPolicy completed successfully because response prepared") + return response, nil } @@ -731,7 +736,7 @@ func (api headscaleV1APIServer) DebugCreateNode( Caller(). Interface("route-prefix", routes). Interface("route-str", request.GetRoutes()). - Msg("") + Msg("Creating routes for node") hostinfo := tailcfg.Hostinfo{ RoutableIPs: routes, @@ -760,6 +765,7 @@ func (api headscaleV1APIServer) DebugCreateNode( } log.Debug(). + Caller(). Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index cac4ff0f..f9f9115a 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -197,7 +197,7 @@ func (h *Headscale) RobotsHandler( log.Error(). Caller(). Err(err). 
- Msg("Failed to write response") + Msg("Failed to write HTTP response") } } diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 1299ed54..91564a3a 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -9,6 +9,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/types/ptr" ) @@ -23,7 +24,7 @@ type Batcher interface { RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] - AddWork(c change.ChangeSet) + AddWork(c ...change.ChangeSet) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) } @@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB // The size of this channel is arbitrary chosen, the sizing should be revisited. workCh: make(chan work, workers*200), - nodes: xsync.NewMap[types.NodeID, *nodeConn](), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](), } @@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher { m := newMapper(cfg, state) b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m) m.batcher = b + return b } @@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) } - var mapResp *tailcfg.MapResponse - var err error + var ( + mapResp *tailcfg.MapResponse + err error + ) switch c.Change { case change.DERP: @@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // TODO(kradalby): This can potentially be a peer update of the old and new subnet router. 
mapResp, err = mapper.fullMapResponse(nodeID, version) } else { + // CRITICAL FIX: Read actual online status from NodeStore when available, + // fall back to deriving from change type for unit tests or when NodeStore is empty + var onlineStatus bool + if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() { + // Use actual NodeStore status when available (production case) + onlineStatus = node.IsOnline().Get() + } else { + // Fall back to deriving from change type (unit test case or initial setup) + onlineStatus = c.Change == change.NodeCameOnline + } + mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ { NodeID: c.NodeID.NodeID(), - Online: ptr.To(c.Change == change.NodeCameOnline), + Online: ptr.To(onlineStatus), }, }) } @@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err } nodeID := nc.nodeID() - data, err := generateMapResponse(nodeID, nc.version(), mapper, c) + + log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received") + + var data *tailcfg.MapResponse + var err error + data, err = generateMapResponse(nodeID, nc.version(), mapper, c) if err != nil { return fmt.Errorf("generating map response for node %d: %w", nodeID, err) } @@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err } // Send the map response - if err := nc.send(data); err != nil { + err = nc.send(data) + if err != nil { return fmt.Errorf("sending map response to node %d: %w", nodeID, err) } diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 7476b72f..aaa58f2f 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -2,6 +2,7 @@ package mapper import ( "context" + "crypto/rand" "fmt" "sync" "sync/atomic" @@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse version: version, created: now, } + // Initialize last used timestamp + newEntry.lastUsed.Store(now.Unix()) - // Only after validation succeeds, create or update node connection - newConn := newNodeConn(id, c, version, b.mapper) + // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection + nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper)) if !loaded { b.totalNodes.Add(1) - conn = newConn } - b.connected.Store(id, nil) // nil = connected + // Add connection to the list (lock-free) + nodeConn.addConnection(newEntry) + + // Use the worker pool for controlled concurrency instead of direct generation + initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) if err != nil { log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") @@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse return fmt.Errorf("failed to send initial map to node %d: timeout", id) } + // Update connection status + b.connected.Store(id, nil) // nil = connected + + // Node will automatically receive updates through the normal flow + // The initial full map already contains all current state + + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)). + Int("active.connections", nodeConn.getActiveConnectionCount()). 
+ Msg("Node connection established in batcher because AddNode completed successfully") + return nil } @@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo return false } - // Mark the connection as closed to prevent further sends - if connData := existing.connData.Load(); connData != nil { - connData.closed.Store(true) - } + // Remove specific connection + removed := nodeConn.removeConnectionByChannel(c) + if !removed { + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid") + return false } // Check if node has any remaining active connections @@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo return true // Node still has active connections } - // Remove node and mark disconnected atomically - b.nodes.Delete(id) + // No active connections - keep the node entry alive for rapid reconnections + // The node will get a fresh full map when it reconnects + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection") b.connected.Store(id, ptr.To(time.Now())) - b.totalNodes.Add(-1) return false } // AddWork queues a change to be processed by the batcher. -// Critical changes are processed immediately, while others are batched for efficiency. -func (b *LockFreeBatcher) AddWork(c change.ChangeSet) { - b.addWork(c) +func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) { + b.addWork(c...) } func (b *LockFreeBatcher) Start() { @@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() { func (b *LockFreeBatcher) Close() { if b.cancel != nil { b.cancel() + b.cancel = nil // Prevent multiple calls + } + + // Only close workCh once + select { + case <-b.workCh: + // Channel is already closed + default: + close(b.workCh) } - close(b.workCh) } func (b *LockFreeBatcher) doWork() { - log.Debug().Msg("batcher doWork loop started") - defer log.Debug().Msg("batcher doWork loop stopped") for i := range b.workers { go b.worker(i + 1) } + // Create a cleanup ticker for removing truly disconnected nodes + cleanupTicker := time.NewTicker(5 * time.Minute) + defer cleanupTicker.Stop() + for { select { case <-b.tick.C: // Process batched changes b.processBatchedChanges() + case <-cleanupTicker.C: + // Clean up nodes that have been offline for too long + b.cleanupOfflineNodes() case <-b.ctx.Done(): return } @@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() { } func (b *LockFreeBatcher) worker(workerID int) { - log.Debug().Int("workerID", workerID).Msg("batcher worker started") - defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped") for { select { @@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) { return } - startTime := time.Now() b.workProcessed.Add(1) // If the resultCh is set, it means that this is a work request @@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) { if w.resultCh != nil { var result workResult if nc, exists := b.nodes.Load(w.nodeID); exists { - result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c) + var err error + result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c) + result.err = err if result.err != nil { b.workErrors.Add(1) log.Error().Err(result.err). 
@@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) { } } else { result.err = fmt.Errorf("node %d not found", w.nodeID) + b.workErrors.Add(1) log.Error().Err(result.err). Int("workerID", workerID). @@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) { }) return } -} + all, self := change.SplitAllAndSelf(c) + + for _, changeSet := range self { + changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{}) + changes = append(changes, changeSet) + b.pendingChanges.Store(changeSet.NodeID, changes) return } - b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { - if c.NodeID == nodeID && !c.AlsoSelf() { - return true - } + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + rel := change.RemoveUpdatesForSelf(nodeID, all) changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) - changes = append(changes, c) + changes = append(changes, rel...) b.pendingChanges.Store(nodeID, changes) return true @@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() { }) } -// IsConnected is lock-free read. +// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. +func (b *LockFreeBatcher) cleanupOfflineNodes() { + cleanupThreshold := 15 * time.Minute + now := time.Now() + + var nodesToCleanup []types.NodeID + + // Find nodes that have been offline for too long + b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { + if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { + // Double-check the node doesn't have active connections + if nodeConn, exists := b.nodes.Load(nodeID); exists { + if !nodeConn.hasActiveConnections() { + nodesToCleanup = append(nodesToCleanup, nodeID) + } + } + } + return true + }) + + // Clean up the identified nodes + for _, nodeID := range nodesToCleanup { + log.Info().Uint64("node.id", nodeID.Uint64()). + Dur("offline_duration", cleanupThreshold). + Msg("Cleaning up node that has been offline for too long") + + b.nodes.Delete(nodeID) + b.connected.Delete(nodeID) + b.totalNodes.Add(-1) + } + + if len(nodesToCleanup) > 0 { + log.Info().Int("cleaned_nodes", len(nodesToCleanup)). + Msg("Completed cleanup of long-offline nodes") + } +} + +// IsConnected is lock-free read that checks if a node has any active connections. func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { // First check if we have active connections for this node if nodeConn, exists := b.nodes.Load(id); exists { @@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change } } -// connectionData holds the channel and connection parameters. -type connectionData struct { - c chan<- *tailcfg.MapResponse - version tailcfg.CapabilityVersion - closed atomic.Bool // Track if this connection has been closed +// connectionEntry represents a single connection to a node. +type connectionEntry struct { + id string // unique connection ID + c chan<- *tailcfg.MapResponse + version tailcfg.CapabilityVersion + created time.Time + lastUsed atomic.Int64 // Unix timestamp of last successful send } -// nodeConn described the node connection and its associated data. -type nodeConn struct { +// multiChannelNodeConn manages multiple concurrent connections for a single node. 
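+// The connection list is guarded by an RWMutex and may hold several channels for the same node during rapid reconnection.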
+type multiChannelNodeConn struct { id types.NodeID mapper *mapper - // Atomic pointer to connection data - allows lock-free updates - connData atomic.Pointer[connectionData] + mutex sync.RWMutex + connections []*connectionEntry updateCount atomic.Int64 } -func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn { - nc := &nodeConn{ +// generateConnectionID generates a unique connection identifier. +func generateConnectionID() string { + bytes := make([]byte, 8) + rand.Read(bytes) + return fmt.Sprintf("%x", bytes) +} + +// newMultiChannelNodeConn creates a new multi-channel node connection. +func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn { + return &multiChannelNodeConn{ id: id, mapper: mapper, } - - // Initialize connection data - data := &connectionData{ - c: c, - version: version, - } - nc.connData.Store(data) - - return nc } -// updateConnection atomically updates connection parameters. -func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) { - newData := &connectionData{ - c: c, - version: version, - } - nc.connData.Store(newData) +// addConnection adds a new connection. +func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { + mutexWaitStart := time.Now() + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). + Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") + + mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() + + mc.connections = append(mc.connections, entry) + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). + Int("total_connections", len(mc.connections)). + Dur("mutex_wait_time", mutexWaitDur). + Msg("Successfully added connection after mutex wait") } -// matchesChannel checks if the given channel matches current connection. -func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool { - data := nc.connData.Load() - if data == nil { - return false +// removeConnectionByChannel removes a connection by matching channel pointer. +func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for i, entry := range mc.connections { + if entry.c == c { + // Remove this connection + mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)). + Int("remaining_connections", len(mc.connections)). + Msg("Successfully removed connection") + return true + } } - // Compare channel pointers directly - return data.c == c + return false } -// compressAndVersion atomically reads connection settings. -func (nc *nodeConn) version() tailcfg.CapabilityVersion { - data := nc.connData.Load() - if data == nil { +// hasActiveConnections checks if the node has any active connections. +func (mc *multiChannelNodeConn) hasActiveConnections() bool { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) > 0 +} + +// getActiveConnectionCount returns the number of active connections. +func (mc *multiChannelNodeConn) getActiveConnectionCount() int { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) +} + +// send broadcasts data to all active connections for the node. 
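+// It returns nil if the node currently has no connections (e.g. during rapid reconnection) or if at least one connection accepted the response; connections whose sends fail are pruned from the list.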
+func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + if len(mc.connections) == 0 { + // During rapid reconnection, nodes may temporarily have no active connections + // This is not an error - the node will receive a full map when it reconnects + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error + } + + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Int("total_connections", len(mc.connections)). + Msg("send: broadcasting to all connections") + + var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal + + // Send to all connections + for i, conn := range mc.connections { + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). + Msg("send: attempting to send to connection") + + if err := conn.send(data); err != nil { + lastErr = err + failedConnections = append(failedConnections, i) + log.Warn().Err(err). + Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). + Msg("send: connection send failed") + } else { + successCount++ + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). + Msg("send: successfully sent to connection") + } + } + + // Remove failed connections (in reverse order to maintain indices) + for i := len(failedConnections) - 1; i >= 0; i-- { + idx := failedConnections[i] + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Str("conn.id", mc.connections[idx].id). + Msg("send: removing failed connection") + mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...) + } + + mc.updateCount.Add(1) + + log.Info().Uint64("node.id", mc.id.Uint64()). + Int("successful_sends", successCount). + Int("failed_connections", len(failedConnections)). + Int("remaining_connections", len(mc.connections)). + Msg("send: completed broadcast") + + // Success if at least one send succeeded + if successCount > 0 { + return nil + } + + return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr) +} + +// send sends data to a single connection entry with timeout-based stale connection detection. +func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { + // Use a short timeout to detect stale connections where the client isn't reading the channel. + // This is critical for detecting Docker containers that are forcefully terminated + // but still have channels that appear open. + select { + case entry.c <- data: + // Update last used timestamp on successful send + entry.lastUsed.Store(time.Now().Unix()) + return nil + case <-time.After(50 * time.Millisecond): + // Connection is likely stale - client isn't reading from channel + // This catches the case where Docker containers are killed but channels remain open + return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id) + } +} + +// nodeID returns the node ID. +func (mc *multiChannelNodeConn) nodeID() types.NodeID { + return mc.id +} + +// version returns the capability version from the first active connection. +// All connections for a node should have the same version in practice. 
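+// It returns 0 when the node has no active connections.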
+func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + if len(mc.connections) == 0 { return 0 } - return data.version + return mc.connections[0].version } -func (nc *nodeConn) nodeID() types.NodeID { - return nc.id +// change applies a change to all active connections for the node. +func (mc *multiChannelNodeConn) change(c change.ChangeSet) error { + return handleNodeChange(mc, mc.mapper, c) } -func (nc *nodeConn) change(c change.ChangeSet) error { - return handleNodeChange(nc, nc.mapper, c) +// DebugNodeInfo contains debug information about a node's connections. +type DebugNodeInfo struct { + Connected bool `json:"connected"` + ActiveConnections int `json:"active_connections"` } -// send sends data to the node's channel. -// The node will pick it up and send it to the HTTP handler. -func (nc *nodeConn) send(data *tailcfg.MapResponse) error { - connData := nc.connData.Load() - if connData == nil { - return fmt.Errorf("node %d: no connection data", nc.id) - } +// Debug returns a pre-baked map of node debug information for the debug interface. +func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { + result := make(map[types.NodeID]DebugNodeInfo) - // Check if connection has been closed - if connData.closed.Load() { - return fmt.Errorf("node %d: connection closed", nc.id) - } + // Get all nodes with their connection status using immediate connection logic + // (no grace period) for debug purposes + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + nodeConn.mutex.RLock() + activeConnCount := len(nodeConn.connections) + nodeConn.mutex.RUnlock() + + // Use immediate connection status: if active connections exist, node is connected + // If not, check the connected map for nil (connected) vs timestamp (disconnected) + connected := false + if activeConnCount > 0 { + connected = true + } else { + // Check connected map for immediate status + if val, ok := b.connected.Load(id); ok && val == nil { + connected = true + } + } + + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: activeConnCount, + } + return true + }) // Add all entries from the connected map to capture both connected and disconnected nodes b.connected.Range(func(id types.NodeID, val *time.Time) bool { diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 6cf63dca..efc96f98 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -209,6 +209,7 @@ func setupBatcherWithTestData( // Create test users and nodes in the database users := database.CreateUsersForTest(userCount, "testuser") + allNodes := make([]node, 0, userCount*nodesPerUser) for _, user := range users { dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") @@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b if len(resp.PeersChangedPatch) > 0 { require.Len(t, resp.PeersChangedPatch, 1) assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online) + return } @@ -412,6 +414,7 @@ func (n *node) start() { n.maxPeersCount = info.PeerCount } } + if info.IsPatch { atomic.AddInt64(&n.patchCount, 1) // For patches, we track how many patch items @@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Reduce verbose application logging for cleaner test output originalLevel := zerolog.GlobalLevel() defer zerolog.SetGlobalLevel(originalLevel) + zerolog.SetGlobalLevel(zerolog.ErrorLevel) // Test cases: different 
node counts to stress test the all-to-all connectivity @@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Join all nodes as fast as possible t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) + for i := range allNodes { node := &allNodes[i] batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) @@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { if stats.MaxPeersSeen > maxPeersGlobal { maxPeersGlobal = stats.MaxPeersSeen } + if stats.MaxPeersSeen < minPeersSeen { minPeersSeen = stats.MaxPeersSeen } @@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Show sample of node details if len(nodeDetails) > 0 { t.Logf(" Node sample:") + for _, detail := range nodeDetails[:min(5, len(nodeDetails))] { t.Logf(" %s", detail) } + if len(nodeDetails) > 5 { t.Logf(" ... (%d more nodes)", len(nodeDetails)-5) } @@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Show details of failed nodes for debugging if len(nodeDetails) > 5 { t.Logf("Failed nodes details:") + for _, detail := range nodeDetails[5:] { if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) { t.Logf(" %s", detail) @@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) { func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { count := 0 + timer := time.NewTimer(timeout) defer timer.Stop() @@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // Collect updates with timeout updateCount := 0 timeout := time.After(200 * time.Millisecond) + for { select { case data := <-ch: updateCount++ + receivedUpdates = append(receivedUpdates, data) // Validate update content @@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // Validate that all updates have valid content validUpdates := 0 + for _, data := range receivedUpdates { if data != nil { if valid, _ := validateUpdateContent(data); valid { @@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) { batcher := testData.Batcher testNode := testData.Nodes[0] - var channelIssues int - var mutex sync.Mutex + + var ( + channelIssues int + mutex sync.Mutex + ) // Run rapid connect/disconnect cycles with real updates to test channel closing + for i := range 100 { var wg sync.WaitGroup // First connection ch1 := make(chan *tailcfg.MapResponse, 1) + wg.Add(1) + go func() { defer wg.Done() @@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) { // Rapid second connection - should replace ch1 ch2 := make(chan *tailcfg.MapResponse, 1) + wg.Add(1) + go func() { defer wg.Done() + time.Sleep(1 * time.Microsecond) batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }() // Remove second connection wg.Add(1) + go func() { defer wg.Done() + time.Sleep(2 * time.Microsecond) batcher.RemoveNode(testNode.n.ID, ch2) }() @@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) { case <-time.After(1 * time.Millisecond): // If no data received, increment issues counter mutex.Lock() + channelIssues++ + mutex.Unlock() } @@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { batcher := testData.Batcher testNode := testData.Nodes[0] - var panics int - var channelErrors int - var invalidData int - var mutex sync.Mutex + + var ( + panics int + channelErrors int + invalidData int + mutex sync.Mutex + ) // Test rapid connect/disconnect with work generation + for i := range 50 { func() { defer func() { if r := recover(); r != 
nil { mutex.Lock() + panics++ + mutex.Unlock() t.Logf("Panic caught: %v", r) } @@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { defer func() { if r := recover(); r != nil { mutex.Lock() + channelErrors++ + mutex.Unlock() t.Logf("Channel consumer panic: %v", r) } @@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Validate the data we received if valid, reason := validateUpdateContent(data); !valid { mutex.Lock() + invalidData++ + mutex.Unlock() t.Logf("Invalid data received: %s", reason) } @@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { if panics > 0 { t.Errorf("Worker channel safety failed with %d panics", panics) } + if channelErrors > 0 { t.Errorf("Channel handling failed with %d channel errors", channelErrors) } + if invalidData > 0 { t.Errorf("Data validation failed with %d invalid data packets", invalidData) } @@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) { // Use remaining nodes for connection churn testing churningNodes := allNodes[len(allNodes)/2:] churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) + var churningChannelsMutex sync.Mutex // Protect concurrent map access var wg sync.WaitGroup + numCycles := 10 // Reduced for simpler test panicCount := 0 + var panicMutex sync.Mutex // Track deadlock with timeout done := make(chan struct{}) + go func() { defer close(done) @@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) { defer func() { if r := recover(); r != nil { panicMutex.Lock() + panicCount++ + panicMutex.Unlock() t.Logf("Panic in churning connect: %v", r) } + wg.Done() }() ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) + churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) @@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) { defer func() { if r := recover(); r != nil { panicMutex.Lock() + panicCount++ + panicMutex.Unlock() t.Logf("Panic in churning disconnect: %v", r) } + wg.Done() }() time.Sleep(time.Duration(i%5) * time.Millisecond) churningChannelsMutex.Lock() + ch, exists := churningChannels[nodeID] + churningChannelsMutex.Unlock() + if exists { batcher.RemoveNode(nodeID, ch) } @@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) { // DERP changes batcher.AddWork(change.DERPSet) } + if i%5 == 0 { // Full updates using real node data batcher.AddWork(change.FullSet) } + if i%7 == 0 && len(allNodes) > 0 { // Node-specific changes using real nodes node := allNodes[i%len(allNodes)] @@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) { // Validate results panicMutex.Lock() + finalPanicCount := panicCount + panicMutex.Unlock() allStats := tracker.getAllStats() @@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) { // Reduce verbose application logging for cleaner test output originalLevel := zerolog.GlobalLevel() defer zerolog.SetGlobalLevel(originalLevel) + zerolog.SetGlobalLevel(zerolog.ErrorLevel) // Full test matrix for scalability testing @@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) { batcher := testData.Batcher allNodes := testData.Nodes + t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description) t.Logf( " Cycles: %d, Buffer Size: %d, Chaos Type: %s", @@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) { // Connect all nodes first so they can see each other as peers connectedNodes := 
make(map[types.NodeID]bool) + var connectedNodesMutex sync.RWMutex + for i := range testNodes { node := &testNodes[i] batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() + connectedNodes[node.n.ID] = true + connectedNodesMutex.Unlock() } @@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) { go func() { defer close(done) + var wg sync.WaitGroup t.Logf( @@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) { // For chaos testing, only disconnect/reconnect a subset of nodes // This ensures some nodes stay connected to continue receiving updates startIdx := cycle % len(testNodes) + endIdx := startIdx + len(testNodes)/4 if endIdx > len(testNodes) { endIdx = len(testNodes) } + if startIdx >= endIdx { startIdx = 0 endIdx = min(len(testNodes)/4, len(testNodes)) } + chaosNodes := testNodes[startIdx:endIdx] if len(chaosNodes) == 0 { chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos @@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) { if r := recover(); r != nil { atomic.AddInt64(&panicCount, 1) } + wg.Done() }() connectedNodesMutex.RLock() + isConnected := connectedNodes[nodeID] + connectedNodesMutex.RUnlock() if isConnected { batcher.RemoveNode(nodeID, channel) connectedNodesMutex.Lock() + connectedNodes[nodeID] = false + connectedNodesMutex.Unlock() } }( @@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) { if r := recover(); r != nil { atomic.AddInt64(&panicCount, 1) } + wg.Done() }() @@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) { tailcfg.CapabilityVersion(100), ) connectedNodesMutex.Lock() + connectedNodes[nodeID] = true + connectedNodesMutex.Unlock() // Add work to create load @@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) { updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count for i := range updateCount { wg.Add(1) + go func(index int) { defer func() { if r := recover(); r != nil { atomic.AddInt64(&panicCount, 1) } + wg.Done() }() @@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) { deadlockDetected = true // Collect diagnostic information allStats := tracker.getAllStats() + totalUpdates := 0 for _, stats := range allStats { totalUpdates += stats.TotalUpdates } + interimPanics := atomic.LoadInt64(&panicCount) + t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT) t.Logf( " Progress at timeout: %d total updates, %d panics", @@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) { stats := node.cleanup() totalUpdates += stats.TotalUpdates totalPatches += stats.PatchUpdates + totalFull += stats.FullUpdates if stats.MaxPeersSeen > maxPeersGlobal { maxPeersGlobal = stats.MaxPeersSeen @@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) { // Legacy tracker comparison (optional) allStats := tracker.getAllStats() + legacyTotalUpdates := 0 for _, stats := range allStats { legacyTotalUpdates += stats.TotalUpdates } + if legacyTotalUpdates != int(totalUpdates) { t.Logf( "Note: Legacy tracker mismatch - legacy: %d, new: %d", @@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) { // Validation based on expectation testPassed := true + if tc.expectBreak { // For tests expected to break, we're mainly checking that we don't crash if finalPanicCount > 0 { @@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) { // For tests expected to pass, validate proper operation if finalPanicCount > 0 { t.Errorf("Scalability test failed with %d panics", 
finalPanicCount) + testPassed = false } + if deadlockDetected { t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes)) + testPassed = false } + if totalUpdates == 0 { t.Error("No updates received - system may be completely stalled") + testPassed = false } } @@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Read all available updates for each node for i := range allNodes { nodeUpdates := 0 + t.Logf("Reading updates for node %d:", i) // Read up to 10 updates per node or until timeout/no more data @@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { if len(data.Peers) > 0 { t.Logf(" Full peer list with %d peers", len(data.Peers)) + for j, peer := range data.Peers[:min(3, len(data.Peers))] { t.Logf( " Peer %d: NodeID=%d, Online=%v", @@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) { ) } } + if len(data.PeersChangedPatch) > 0 { t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch)) + for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] { t.Logf( " Patch %d: NodeID=%d, Online=%v", @@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { case <-time.After(500 * time.Millisecond): } } + t.Logf("Node %d received %d updates", i, nodeUpdates) } @@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) { } } -// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items. -func TestBatcherWorkQueueTracing(t *testing.T) { +// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID +// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected. +// This specifically tests the multi-channel batcher implementation issue. 
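+// The test connects every node, removes all of their connections, reconnects them with fresh channels, then inspects the Debug() output and verifies update delivery on the new channels.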
+func TestBatcherRapidReconnection(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("=== RAPID RECONNECTION TEST ===") + t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes)) + + // Phase 1: Connect all nodes initially + t.Logf("Phase 1: Connecting all nodes...") + for i, node := range allNodes { + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node %d: %v", i, err) + } + } + + time.Sleep(100 * time.Millisecond) // Let connections settle + + // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) + t.Logf("Phase 2: Rapid disconnect all nodes...") + for i, node := range allNodes { + removed := batcher.RemoveNode(node.n.ID, node.ch) + t.Logf("Node %d RemoveNode result: %t", i, removed) + } + + // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) + t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) + for i, node := range allNodes { + newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) + if err != nil { + t.Errorf("Failed to reconnect node %d: %v", i, err) + } + } + + time.Sleep(100 * time.Millisecond) // Let reconnections settle + + // Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR + t.Logf("Phase 4: Checking debug status...") + + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + disconnectedCount := 0 + + for i, node := range allNodes { + if info, exists := debugInfo[node.n.ID]; exists { + t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info) + + // Check if the debug info shows the node as connected + if infoMap, ok := info.(map[string]any); ok { + if connected, ok := infoMap["connected"].(bool); ok && !connected { + disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) + } + } + } else { + disconnectedCount++ + t.Logf("Node %d missing from debug info entirely", i) + } + + // Also check IsConnected method + if !batcher.IsConnected(node.n.ID) { + t.Logf("Node %d IsConnected() returns false", i) + } + } + + if disconnectedCount > 0 { + t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes)) + // This is expected behavior for multi-channel batcher according to user + // "it has never worked with the multi" + } else { + t.Logf("All nodes show as connected - working correctly") + } + } else { + t.Logf("Batcher does not implement Debug() method") + } + + // Phase 5: Test if "disconnected" nodes can actually receive updates + t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...") + + // Send a change that should reach all nodes + batcher.AddWork(change.DERPChange()) + + receivedCount := 0 + timeout := time.After(500 * time.Millisecond) + + for i := 0; i < len(allNodes); i++ { + select { + case update := <-newChannels[i]: + if update != nil { + receivedCount++ + t.Logf("Node %d received update successfully", i) + } + case <-timeout: + t.Logf("Node %d timed out waiting for update", i) + goto done + } + } + + done: + t.Logf("Update delivery test: %d/%d nodes 
received updates", receivedCount, len(allNodes)) + + if receivedCount < len(allNodes) { + t.Logf("Some nodes failed to receive updates - confirming the issue") + } + }) + } +} + +func TestBatcherMultiConnection(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10) defer cleanup() - batcher := testData.Batcher - nodes := testData.Nodes - - t.Logf("=== WORK QUEUE TRACING TEST ===") - - time.Sleep(100 * time.Millisecond) // Let connections settle - - // Wait for initial NodeCameOnline to be processed - time.Sleep(200 * time.Millisecond) - - // Drain any initial updates - drainedCount := 0 - for { - select { - case <-nodes[0].ch: - drainedCount++ - case <-time.After(100 * time.Millisecond): - goto drained - } - } - drained: - t.Logf("Drained %d initial updates", drainedCount) - - // Now send a single FullSet update and trace it closely - t.Logf("Sending change.FullSet work item...") - batcher.AddWork(change.FullSet) - - // Give short time for processing - time.Sleep(100 * time.Millisecond) - - // Check if any update was received - select { - case data := <-nodes[0].ch: - t.Logf("SUCCESS: Received update after FullSet!") - - if data != nil { - // Detailed analysis of the response - data is already a MapResponse - t.Logf("Response details:") - t.Logf(" Peers: %d", len(data.Peers)) - t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch)) - t.Logf(" PeersChanged: %d", len(data.PeersChanged)) - t.Logf(" PeersRemoved: %d", len(data.PeersRemoved)) - t.Logf(" DERPMap: %v", data.DERPMap != nil) - t.Logf(" KeepAlive: %v", data.KeepAlive) - t.Logf(" Node: %v", data.Node != nil) - - if len(data.Peers) > 0 { - t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers)) - } else if len(data.PeersChangedPatch) > 0 { - t.Errorf("ERROR: Received patch update instead of full update!") - } else if data.DERPMap != nil { - t.Logf("Received DERP map update") - } else if data.Node != nil { - t.Logf("Received self node update") - } else { - t.Errorf("ERROR: Received unknown update type!") - } - batcher := testData.Batcher node1 := testData.Nodes[0] node2 := testData.Nodes[1] @@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) { } } } - } else { - t.Errorf("Response data is nil") } - case <-time.After(2 * time.Second): - t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!") - t.Errorf("This indicates FullSet work items are not being processed at all") + } + + // Send another update and verify remaining connections still work + clearChannel(node1.ch) + clearChannel(thirdChannel) + + testChangeSet2 := change.ChangeSet{ + NodeID: node2.n.ID, + Change: change.NodeNewOrUpdate, + SelfUpdateOnly: false, + } + + batcher.AddWork(testChangeSet2) + time.Sleep(100 * time.Millisecond) + + // Verify remaining connections still receive updates + remaining1Received := false + remaining3Received := false + + select { + case mapResp := <-node1.ch: + remaining1Received = (mapResp != nil) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 1 did not receive update after removal") + } + + select { + case mapResp := <-thirdChannel: + remaining3Received = (mapResp != nil) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 3 did not receive update after removal") + } + + if remaining1Received && remaining3Received { + t.Logf("SUCCESS: Remaining connections still receive updates after removal") + } else { + 
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t", + remaining1Received, remaining3Received) + } + + // Verify second channel no longer receives updates (should be closed/removed) + select { + case <-secondChannel: + t.Errorf("Removed connection still received update - this should not happen") + case <-time.After(100 * time.Millisecond): + t.Logf("SUCCESS: Removed connection correctly no longer receives updates") } }) } diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index dc43b933..819d23a3 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -20,6 +20,8 @@ type MapResponseBuilder struct { nodeID types.NodeID capVer tailcfg.CapabilityVersion errs []error + + debugType debugType } type debugType string diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index bb8340d0..5e9b9a13 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { func (m *mapper) fullMapResponse( nodeID types.NodeID, capVer tailcfg.CapabilityVersion, - messages ...string, ) (*tailcfg.MapResponse, error) { peers := m.state.ListPeers(nodeID) return m.NewMapResponseBuilder(nodeID). + WithDebugType(fullResponseDebug). WithCapabilityVersion(capVer). WithSelfNode(). WithDERPMap(). @@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse( nodeID types.NodeID, ) (*tailcfg.MapResponse, error) { return m.NewMapResponseBuilder(nodeID). + WithDebugType(derpResponseDebug). WithDERPMap(). Build() } @@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse( changed []*tailcfg.PeerChange, ) (*tailcfg.MapResponse, error) { return m.NewMapResponseBuilder(nodeID). + WithDebugType(patchResponseDebug). WithPeerChangedPatch(changed). Build() } @@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse( peers := m.state.ListPeers(nodeID, changedNodeID) return m.NewMapResponseBuilder(nodeID). + WithDebugType(changeResponseDebug). WithCapabilityVersion(capVer). WithSelfNode(). WithUserProfiles(peers). @@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse( removedNodeID types.NodeID, ) (*tailcfg.MapResponse, error) { return m.NewMapResponseBuilder(nodeID). + WithDebugType(removeResponseDebug). WithPeersRemoved(removedNodeID). 
Build() } @@ -214,7 +218,7 @@ func writeDebugMapResponse( } perms := fs.FileMode(debugMapResponsePerm) - mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID)) + mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -224,7 +228,7 @@ func writeDebugMapResponse( mapResponsePath := path.Join( mPath, - fmt.Sprintf("%s.json", now), + fmt.Sprintf("%s-%s.json", now, t), ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) @@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er return nil, nil } - nodes, err := os.ReadDir(debugDumpMapResponsePath) + return ReadMapResponsesFromDirectory(debugDumpMapResponsePath) +} + +func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) { + nodes, err := os.ReadDir(dir) if err != nil { return nil, err } @@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er nodeID := types.NodeID(nodeIDu) - files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name())) + files, err := os.ReadDir(path.Join(dir, node.Name())) if err != nil { log.Error().Err(err).Msgf("Reading dir %s", node.Name()) continue @@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er continue } - body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name())) + body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name())) if err != nil { log.Error().Err(err).Msgf("Reading file %s", file.Name()) continue diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index c699943f..ac96028e 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) { Tags: []string{}, - LastSeen: &lastSeen, MachineAuthorized: true, CapMap: tailcfg.NodeCapMap{ diff --git a/hscontrol/noise.go b/hscontrol/noise.go index bb59fea6..fa5eb1dd 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -175,8 +175,8 @@ func rejectUnsupported( Int("client_cap_ver", int(version)). Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)). Str("client_version", capver.TailscaleVersion(version)). - Str("node_key", nkey.ShortString()). - Str("machine_key", mkey.ShortString()). + Str("node.key", nkey.ShortString()). + Str("machine.key", mkey.ShortString()). Msg("unsupported client connected") http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest) @@ -282,7 +282,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( writer.WriteHeader(http.StatusOK) if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { - log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") + log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") return } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 021a6272..55f917d7 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -181,7 +181,7 @@ func (a *AuthProviderOIDC) RegisterHandler( a.registrationCache.Set(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) 
- log.Debug().Msgf("Redirecting to %s for authentication", authURL) + log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL) http.Redirect(writer, req, authURL, http.StatusFound) } @@ -311,7 +311,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( log.Error(). Caller(). Err(werr). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -349,7 +349,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) if _, err := writer.Write(content.Bytes()); err != nil { - util.LogErr(err, "Failed to write response") + util.LogErr(err, "Failed to write HTTP response") } return diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index ecd8f83e..338e513b 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -34,7 +34,7 @@ func (pol *Policy) compileFilterRules( srcIPs, err := acl.Sources.Resolve(pol, users, nodes) if err != nil { - log.Trace().Err(err).Msgf("resolving source ips") + log.Trace().Caller().Err(err).Msgf("resolving source ips") } if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { @@ -52,11 +52,11 @@ func (pol *Policy) compileFilterRules( for _, dest := range acl.Destinations { ips, err := dest.Resolve(pol, users, nodes) if err != nil { - log.Trace().Err(err).Msgf("resolving destination ips") + log.Trace().Caller().Err(err).Msgf("resolving destination ips") } if ips == nil { - log.Debug().Msgf("destination resolved to nil ips: %v", dest) + log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest) continue } @@ -106,7 +106,7 @@ func (pol *Policy) compileSSHPolicy( return nil, nil } - log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname()) + log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname()) var rules []*tailcfg.SSHRule @@ -115,7 +115,7 @@ func (pol *Policy) compileSSHPolicy( for _, src := range rule.Destinations { ips, err := src.Resolve(pol, users, nodes) if err != nil { - log.Trace().Err(err).Msgf("resolving destination ips") + log.Trace().Caller().Err(err).Msgf("resolving destination ips") } dest.AddSet(ips) } @@ -142,7 +142,7 @@ func (pol *Policy) compileSSHPolicy( var principals []*tailcfg.SSHPrincipal srcIPs, err := rule.Sources.Resolve(pol, users, nodes) if err != nil { - log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) + log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) continue // Skip this rule if we can't resolve sources } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 5e7aa34b..4215485a 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" "go4.org/netipx" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -79,6 +80,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) { filterHash := deephash.Hash(&filter) filterChanged := filterHash != pm.filterHash + if filterChanged { + log.Debug(). + Str("filter.hash.old", pm.filterHash.String()[:8]). + Str("filter.hash.new", filterHash.String()[:8]). + Int("filter.rules", len(pm.filter)). + Int("filter.rules.new", len(filter)). 
+ Msg("Policy filter hash changed") + } pm.filter = filter pm.filterHash = filterHash if filterChanged { @@ -95,6 +104,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) { tagOwnerMapHash := deephash.Hash(&tagMap) tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash + if tagOwnerChanged { + log.Debug(). + Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]). + Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]). + Int("tagOwners.old", len(pm.tagOwnerMap)). + Int("tagOwners.new", len(tagMap)). + Msg("Tag owner hash changed") + } pm.tagOwnerMap = tagMap pm.tagOwnerMapHash = tagOwnerMapHash @@ -105,19 +122,42 @@ func (pm *PolicyManager) updateLocked() (bool, error) { autoApproveMapHash := deephash.Hash(&autoMap) autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash + if autoApproveChanged { + log.Debug(). + Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]). + Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]). + Int("autoApprovers.old", len(pm.autoApproveMap)). + Int("autoApprovers.new", len(autoMap)). + Msg("Auto-approvers hash changed") + } pm.autoApproveMap = autoMap pm.autoApproveMapHash = autoApproveMapHash - exitSetHash := deephash.Hash(&autoMap) + exitSetHash := deephash.Hash(&exitSet) exitSetChanged := exitSetHash != pm.exitSetHash + if exitSetChanged { + log.Debug(). + Str("exitSet.hash.old", pm.exitSetHash.String()[:8]). + Str("exitSet.hash.new", exitSetHash.String()[:8]). + Msg("Exit node set hash changed") + } pm.exitSet = exitSet pm.exitSetHash = exitSetHash // If neither of the calculated values changed, no need to update nodes if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged { + log.Trace(). + Msg("Policy evaluation detected no changes - all hashes match") return false, nil } + log.Debug(). + Bool("filter.changed", filterChanged). + Bool("tagOwners.changed", tagOwnerChanged). + Bool("autoApprovers.changed", autoApproveChanged). + Bool("exitNodes.changed", exitSetChanged). + Msg("Policy changes require node updates") + return true, nil } @@ -151,6 +191,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() + // Log policy metadata for debugging + log.Debug(). + Int("policy.bytes", len(polB)). + Int("acls.count", len(pol.ACLs)). + Int("groups.count", len(pol.Groups)). + Int("hosts.count", len(pol.Hosts)). + Int("tagOwners.count", len(pol.TagOwners)). + Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)). + Msg("Policy parsed successfully") + pm.pol = pol return pm.updateLocked() diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 4809257b..cfe89b1a 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -216,6 +216,21 @@ func (m *mapSession) serveLongPoll() { m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) + // TODO(kradalby): Redo the comments here + // Add node to batcher so it can receive updates, + // adding this before connecting it to the state ensure that + // it does not miss any updates that might be sent in the split + // time between the node connecting and the batcher being ready. 
+ if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { + m.errf(err, "failed to add node to batcher") + log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") + return + } + log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher") + + m.h.Change(mapReqChange) + m.h.Change(connectChanges...) + // Loop through updates and continuously send them to the // client. for { @@ -227,7 +242,7 @@ func (m *mapSession) serveLongPoll() { return case <-ctx.Done(): - m.tracef("poll context done") + m.tracef("poll context done chan:%p", m.ch) mapResponseEnded.WithLabelValues("done").Inc() return @@ -295,7 +310,15 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { } } - log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") + log.Trace(). + Caller(). + Str("node.name", m.node.Hostname). + Uint64("node.id", m.node.ID.Uint64()). + Str("chan", fmt.Sprintf("%p", m.ch)). + TimeDiff("timeSpent", time.Now(), startWrite). + Str("machine.key", m.node.MachineKey.String()). + Bool("keepalive", msg.KeepAlive). + Msgf("finished writing mapresp to node chan(%p)", m.ch) return nil } @@ -305,14 +328,14 @@ var keepAlive = tailcfg.MapResponse{ } func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) { - trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname) + trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname) if peerChange.Key != nil { - trace = trace.Str("node_key", peerChange.Key.ShortString()) + trace = trace.Str("node.key", peerChange.Key.ShortString()) } if peerChange.DiscoKey != nil { - trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString()) + trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString()) } if peerChange.Online != nil { @@ -349,7 +372,7 @@ func logPollFunc( Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). + Str("node.name", node.Hostname). Msgf(msg, a...) }, func(msg string, a ...any) { @@ -358,7 +381,7 @@ func logPollFunc( Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). + Str("node.name", node.Hostname). Msgf(msg, a...) }, func(msg string, a ...any) { @@ -367,7 +390,7 @@ func logPollFunc( Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). + Str("node.name", node.Hostname). Msgf(msg, a...) }, func(err error, msg string, a ...any) { @@ -376,7 +399,7 @@ func logPollFunc( Bool("omitPeers", mapRequest.OmitPeers). Bool("stream", mapRequest.Stream). Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). + Str("node.name", node.Hostname). Err(err). Msgf(msg, a...) 
} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e137116a..d74814b0 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1430,7 +1430,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) } - log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users") + log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected") changed, err := s.polMan.SetUsers(users) if err != nil { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index e38a98f6..5c5ea8b8 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -97,6 +97,35 @@ func (c ChangeSet) IsFull() bool { return c.Change == Full || c.Change == Policy } +func HasFull(cs []ChangeSet) bool { + for _, c := range cs { + if c.IsFull() { + return true + } + } + return false +} + +func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) { + for _, c := range cs { + if c.SelfUpdateOnly { + self = append(self, c) + } else { + all = append(all, c) + } + } + return all, self +} + +func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) { + for _, c := range cs { + if c.NodeID != id || c.Change.AlsoSelf() { + ret = append(ret, c) + } + } + return ret +} + func (c ChangeSet) AlsoSelf() bool { // If NodeID is 0, it means this ChangeSet is not related to a specific node, // so we consider it as a change that should be sent to all nodes. diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index f23b75e8..4a0a366e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -489,6 +489,7 @@ func derpConfig() DERPConfig { urlAddr, err := url.Parse(urlStr) if err != nil { log.Error(). + Caller(). Str("url", urlStr). Err(err). Msg("Failed to parse url, ignoring...") @@ -561,6 +562,7 @@ func logConfig() LogConfig { logFormat = TextLogFormat default: log.Error(). + Caller(). Str("func", "GetLogConfig"). Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 959572a2..1d0b6cc3 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -54,6 +54,20 @@ func (id NodeID) String() string { return strconv.FormatUint(id.Uint64(), util.Base10) } +func ParseNodeID(s string) (NodeID, error) { + id, err := strconv.ParseUint(s, util.Base10, 64) + return NodeID(id), err +} + +func MustParseNodeID(s string) NodeID { + id, err := ParseNodeID(s) + if err != nil { + panic(err) + } + + return id +} + // Node is a Headscale client. type Node struct { ID NodeID `gorm:"primary_key"` diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 46329c12..659e0a76 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -61,6 +61,7 @@ func (pak *PreAuthKey) Validate() error { } log.Debug(). + Caller(). Str("key", pak.Key). Bool("hasExpiration", pak.Expiration != nil). 
Time("expiration", func() time.Time { diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index b48495ea..131e8019 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -321,7 +321,7 @@ func (u *User) FromClaim(claims *OIDCClaims) { if err == nil { u.Name = claims.Username } else { - log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username) + log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username) } if claims.EmailVerified { diff --git a/integration/acl_test.go b/integration/acl_test.go index 6a6d245c..2d59ac43 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1160,57 +1160,61 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { err = headscale.SetPolicy(&p) require.NoError(t, err) - // Get the current policy and check - // if it is the same as the one we set. - var output *policyv2.Policy - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "policy", - "get", - "--output", - "json", - }, - &output, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Get the current policy and check + // if it is the same as the one we set. + var output *policyv2.Policy + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "policy", + "get", + "--output", + "json", + }, + &output, + ) + assert.NoError(ct, err) - assert.Len(t, output.ACLs, 1) + assert.Len(t, output.ACLs, 1) - if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" { - t.Errorf("unexpected policy(-want +got):\n%s", diff) - } - - // Test that user1 can visit all user2 - for _, client := range user1Clients { - for _, peer := range user2Clients { - fqdn, err := peer.FQDN() - require.NoError(t, err) - - url := fmt.Sprintf("http://%s/etc/hostname", fqdn) - t.Logf("url from %s to %s", client.Hostname(), url) - - result, err := client.Curl(url) - assert.Len(t, result, 13) - require.NoError(t, err) + if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" { + ct.Errorf("unexpected policy(-want +got):\n%s", diff) } - } + }, 30*time.Second, 1*time.Second, "verifying that the new policy took place") - // Test that user2 _cannot_ visit user1 - for _, client := range user2Clients { - for _, peer := range user1Clients { - fqdn, err := peer.FQDN() - require.NoError(t, err) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Test that user1 can visit all user2 + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + assert.NoError(ct, err) - url := fmt.Sprintf("http://%s/etc/hostname", fqdn) - t.Logf("url from %s to %s", client.Hostname(), url) + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Empty(t, result) - require.Error(t, err) + result, err := client.Curl(url) + assert.Len(ct, result, 13) + assert.NoError(ct, err) + } } - } + + // Test that user2 _cannot_ visit user1 + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + assert.NoError(ct, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + result, err := client.Curl(url) + assert.Empty(ct, result) + assert.Error(ct, err) + } + } + }, 30*time.Second, 1*time.Second, "new policy did not get propagated to nodes") } func TestACLAutogroupMember(t 
*testing.T) { diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 019b85f4..26c6becf 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -9,6 +9,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/samber/lo" @@ -53,6 +54,18 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + expectedNodes := make([]types.NodeID, 0, len(allClients)) + for _, client := range allClients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + assertNoErr(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second) + // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) @@ -64,9 +77,6 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { clientIPs[client] = ips } - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - listNodes, err := headscale.ListNodes() assert.Len(t, allClients, len(listNodes)) nodeCountBeforeLogout := len(listNodes) @@ -86,6 +96,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { err = scenario.WaitForTailscaleLogout() assertNoErrLogout(t, err) + // After taking down all nodes, verify all systems show nodes offline + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second) + t.Logf("all clients logged out") assert.EventuallyWithT(t, func(ct *assert.CollectT) { diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 6c784586..0fe1fe12 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -481,10 +481,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - listUsers, err := headscale.ListUsers() - assertNoErr(t, err) - assert.Empty(t, listUsers) - ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) assertNoErr(t, err) @@ -494,26 +490,28 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { _, err = doLoginURL(ts.Hostname(), u) assertNoErr(t, err) - listUsers, err = headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 1) - wantUsers := []*v1.User{ - { - Id: 1, - Name: "user1", - Email: "user1@headscale.net", - Provider: "oidc", - ProviderId: scenario.mockOIDC.Issuer() + "/user1", - }, - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }) + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - t.Fatalf("unexpected users: %s", diff) - } + if diff := cmp.Diff(wantUsers, 
listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating users after first login") listNodes, err := headscale.ListNodes() assertNoErr(t, err) @@ -525,19 +523,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { err = ts.Logout() assertNoErr(t, err) + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + assertNoErr(t, err) + // Wait for logout to complete and then do second logout assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check that the first logout completed status, err := ts.Status() assert.NoError(ct, err) assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 5*time.Second, 1*time.Second) - - // TODO(kradalby): Not sure why we need to logout twice, but it fails and - // logs in immediately after the first logout and I cannot reproduce it - // manually. - err = ts.Logout() - assertNoErr(t, err) + }, 30*time.Second, 1*time.Second) u, err = ts.LoginWithURL(headscale.GetEndpoint()) assertNoErr(t, err) @@ -545,56 +543,56 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { _, err = doLoginURL(ts.Hostname(), u) assertNoErr(t, err) - listUsers, err = headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 2) - wantUsers = []*v1.User{ - { - Id: 1, - Name: "user1", - Email: "user1@headscale.net", - Provider: "oidc", - ProviderId: scenario.mockOIDC.Issuer() + "/user1", - }, - { - Id: 2, - Name: "user2", - Email: "user2@headscale.net", - Provider: "oidc", - ProviderId: scenario.mockOIDC.Issuer() + "/user2", - }, - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 2) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", + }, + } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }) + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - t.Fatalf("unexpected users: %s", diff) - } + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("unexpected users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating users after new user login") - listNodesAfterNewUserLogin, err := headscale.ListNodes() - assertNoErr(t, err) - assert.Len(t, listNodesAfterNewUserLogin, 2) + var listNodesAfterNewUserLogin []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterNewUserLogin, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, listNodesAfterNewUserLogin, 2) - // Machine key is the same as the "machine" has not changed, - // but Node key is not as it is a new node - assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - 
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) + // Machine key is the same as the "machine" has not changed, + // but Node key is not as it is a new node + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) + }, 30*time.Second, 1*time.Second, "listing nodes after new user login") // Log out user2, and log into user1, no new node should be created, // the node should now "become" node1 again err = ts.Logout() assertNoErr(t, err) - // Wait for logout to complete and then do second logout - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - // Check that the first logout completed - status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 5*time.Second, 1*time.Second) + t.Logf("Logged out take one") + t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it @@ -602,65 +600,92 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { err = ts.Logout() assertNoErr(t, err) + t.Logf("Logged out take two") + t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") + + // Wait for logout to complete and then do second logout + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState) + }, 30*time.Second, 1*time.Second) + + // We do not actually "change" the user here, it is done by logging in again + // as the OIDC mock server is kind of like a stack, and the next user is + // prepared and ready to go. 
u, err = ts.LoginWithURL(headscale.GetEndpoint()) assertNoErr(t, err) _, err = doLoginURL(ts.Hostname(), u) assertNoErr(t, err) - listUsers, err = headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 2) - wantUsers = []*v1.User{ - { - Id: 1, - Name: "user1", - Email: "user1@headscale.net", - Provider: "oidc", - ProviderId: scenario.mockOIDC.Issuer() + "/user1", - }, - { - Id: 2, - Name: "user2", - Email: "user2@headscale.net", - Provider: "oidc", - ProviderId: scenario.mockOIDC.Issuer() + "/user2", - }, - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState) + }, 30*time.Second, 1*time.Second) - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }) + t.Logf("Logged back in") + t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - t.Fatalf("unexpected users: %s", diff) - } + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err) + assert.Len(ct, listUsers, 2) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", + }, + } - listNodesAfterLoggingBackIn, err := headscale.ListNodes() - assertNoErr(t, err) - assert.Len(t, listNodesAfterLoggingBackIn, 2) + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) - // Validate that the machine we had when we logged in the first time, has the same - // machine key, but a different ID than the newly logged in version of the same - // machine. - assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) - assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) - assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) - assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("unexpected users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") - // Even tho we are logging in again with the same user, the previous key has been expired - // and a new one has been generated. The node entry in the database should be the same - // as the user + machinekey still matches. 
- assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) - assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) - assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterLoggingBackIn, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, listNodesAfterLoggingBackIn, 2) - // The "logged back in" machine should have the same machinekey but a different nodekey - // than the version logged in with a different user. - assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) - assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) + // Validate that the machine we had when we logged in the first time, has the same + // machine key, but a different ID than the newly logged in version of the same + // machine. + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) + assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) + assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) + assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) + + // Even tho we are logging in again with the same user, the previous key has been expired + // and a new one has been generated. The node entry in the database should be the same + // as the user + machinekey still matches. + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) + assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) + + // The "logged back in" machine should have the same machinekey but a different nodekey + // than the version logged in with a different user. 
+ assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) + assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) + }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") } // assertTailscaleNodesLogout verifies that all provided Tailscale clients diff --git a/integration/control.go b/integration/control.go index 3994a4a5..773ddeb8 100644 --- a/integration/control.go +++ b/integration/control.go @@ -4,6 +4,7 @@ import ( "net/netip" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" diff --git a/integration/general_test.go b/integration/general_test.go index 0610ec36..cb6d83dd 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -10,18 +10,21 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" - "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -59,13 +62,15 @@ func TestPingAllByIP(t *testing.T) { hs, err := scenario.Headscale() require.NoError(t, err) - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - all, err := hs.GetAllMapReponses() - assert.NoError(ct, err) - - onlineMap := buildExpectedOnlineMap(all) - assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap) - }, 30*time.Second, 2*time.Second) + // Extract node IDs for validation + expectedNodes := make([]types.NodeID, 0, len(allClients)) + for _, client := range allClients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err, "failed to parse node ID") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second) // assertClientsState(t, allClients) @@ -73,6 +78,14 @@ func TestPingAllByIP(t *testing.T) { return x.String() }) + // Get headscale instance for batcher debug check + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + // Test our DebugBatcher functionality + t.Logf("Testing DebugBatcher functionality...") + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to the batcher", 30*time.Second) + success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } @@ -962,9 +975,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { ) assertNoErrHeadscaleEnv(t, err) - hs, err := scenario.Headscale() - require.NoError(t, err) - allClients, err := scenario.ListTailscaleClients() assertNoErrListClients(t, err) @@ -980,14 +990,31 @@ func TestPingAllByIPManyUpDown(t *testing.T) { return x.String() }) + // Get headscale instance for batcher debug checks + headscale, 
err := scenario.Headscale() + assertNoErr(t, err) + + // Initial check: all nodes should be connected to batcher + // Extract node IDs for validation + expectedNodes := make([]types.NodeID, 0, len(allClients)) + for _, client := range allClients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + assertNoErr(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second) + success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - wg, _ := errgroup.WithContext(context.Background()) - for run := range 3 { t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999")) + // Create fresh errgroup with timeout for each run + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + wg, _ := errgroup.WithContext(ctx) + for _, client := range allClients { c := client wg.Go(func() error { @@ -1001,6 +1028,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999")) + // After taking down all nodes, verify all systems show nodes offline + requireAllClientsOnline(t, headscale, expectedNodes, false, fmt.Sprintf("Run %d: all nodes should be offline after Down()", run+1), 120*time.Second) + for _, client := range allClients { c := client wg.Go(func() error { @@ -1014,22 +1044,22 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999")) + // After bringing up all nodes, verify batcher shows all reconnected + requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all nodes should be reconnected after Up()", run+1), 120*time.Second) + // Wait for sync and successful pings after nodes come back up err = scenario.WaitForTailscaleSync() assert.NoError(t, err) t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999")) - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - all, err := hs.GetAllMapReponses() - assert.NoError(ct, err) - - onlineMap := buildExpectedOnlineMap(all) - assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap) - }, 60*time.Second, 2*time.Second) + requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all systems should show nodes online after reconnection", run+1), 60*time.Second) success := pingAllHelper(t, allClients, allAddrs) assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Clean up context for this run + cancel() } } @@ -1141,51 +1171,158 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) } -func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool { - res := make(map[types.NodeID]map[types.NodeID]bool) - for nid, mrs := range all { - res[nid] = make(map[types.NodeID]bool) - for _, mr := range mrs { - for _, peer := range mr.Peers { - if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online - } - } - - for _, peer := range mr.PeersChanged { - if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online - } - } - - for _, peer := range mr.PeersChangedPatch { - if peer.Online != nil { - 
res[nid][types.NodeID(peer.NodeID)] = *peer.Online - } - } - } - } - return res +// NodeSystemStatus represents the online status of a node across different systems +type NodeSystemStatus struct { + Batcher bool + BatcherConnCount int + MapResponses bool + NodeStore bool } -func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) { - for nid, peers := range onlineMap { - onlineCount := 0 - for _, online := range peers { - if online { - onlineCount++ +// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore +func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format("2006-01-02T15:04:05.000"), message) + + var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get batcher state + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + // Get map responses + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate node counts first + expectedCount := len(expectedNodes) + assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch") + assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") + + // Check that we have map responses for expected nodes + mapResponseCount := len(mapResponses) + assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch") + + // Build status map for each node + nodeStatus := make(map[types.NodeID]NodeSystemStatus) + + // Initialize all expected nodes + for _, nodeID := range expectedNodes { + nodeStatus[nodeID] = NodeSystemStatus{} + } + + // Check batcher state + for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes { + nodeID := types.MustParseNodeID(nodeIDStr) + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = nodeInfo.Connected + status.BatcherConnCount = nodeInfo.ActiveConnections + nodeStatus[nodeID] = status } } - assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid) - if expectedPeerCount != onlineCount { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid)) - for pid, online := range peers { - sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online)) + + // Check map responses using buildExpectedOnlineMap + onlineFromMaps := make(map[types.NodeID]bool) + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + for nodeID := range nodeStatus { + NODE_STATUS: + for id, peerMap := range onlineMap { + if id == nodeID { + continue + } + + online := peerMap[nodeID] + // If the node is offline in any map response, we consider it offline + if !online { + onlineFromMaps[nodeID] = false + continue NODE_STATUS + } + + onlineFromMaps[nodeID] = true } - sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") - sb.WriteString("expected all peers to be online.") - t.Errorf("%s", sb.String()) } - } + 
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") + + // Update status with map response data + for nodeID, online := range onlineFromMaps { + if status, exists := nodeStatus[nodeID]; exists { + status.MapResponses = online + nodeStatus[nodeID] = status + } + } + + // Check nodestore state + for nodeID, node := range nodeStore { + if status, exists := nodeStatus[nodeID]; exists { + // Check if node is online in nodestore + status.NodeStore = node.IsOnline != nil && *node.IsOnline + nodeStatus[nodeID] = status + } + } + + // Verify all systems show nodes in expected state and report failures + allMatch := true + var failureReport strings.Builder + + ids := types.NodeIDs(maps.Keys(nodeStatus)) + slices.Sort(ids) + for _, nodeID := range ids { + status := nodeStatus[nodeID] + systemsMatch := (status.Batcher == expectedOnline) && + (status.MapResponses == expectedOnline) && + (status.NodeStore == expectedOnline) + + if !systemsMatch { + allMatch = false + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr)) + failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher)) + failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) + failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses)) + failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore)) + } + } + + if !allMatch { + if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { + t.Log("Diff between reports:") + t.Logf("Prev report: \n%s\n", prevReport) + t.Logf("New report: \n%s\n", failureReport.String()) + t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") + prevReport = failureReport.String() + } + + failureReport.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") + + assert.Fail(c, failureReport.String()) + } + + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr)) + }, timeout, 2*time.Second, message) + + endTime := time.Now() + duration := endTime.Sub(startTime) + t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format("2006-01-02T15:04:05.000"), duration, message) } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index b38677b4..9c28dc00 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -22,6 +22,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 336bf73a..4ddc7ae9 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -19,6 +19,7 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "tailscale.com/tailcfg" ) // PeerSyncTimeout returns the timeout for peer synchronization based on environment: @@ -199,3 +200,30 @@ func CreateCertificate(hostname string) ([]byte, []byte, error) { return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil } + +func 
BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool { + res := make(map[types.NodeID]map[types.NodeID]bool) + for nid, mrs := range all { + res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { + for _, peer := range mr.Peers { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChanged { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChangedPatch { + if peer.Online != nil { + res[nid][types.NodeID(peer.NodeID)] = *peer.Online + } + } + } + } + return res +} diff --git a/tools/capver/main.go b/tools/capver/main.go index 37bab0bc..1e4512c1 100644 --- a/tools/capver/main.go +++ b/tools/capver/main.go @@ -5,7 +5,9 @@ package main import ( "encoding/json" "fmt" + "go/format" "io" + "log" "net/http" "os" "regexp" @@ -61,14 +63,14 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) { rawURL := fmt.Sprintf(rawFileURL, version) resp, err := http.Get(rawURL) if err != nil { - fmt.Printf("Error fetching raw file for version %s: %v\n", version, err) + log.Printf("Error fetching raw file for version %s: %v\n", version, err) continue } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - fmt.Printf("Error reading raw file for version %s: %v\n", version, err) + log.Printf("Error reading raw file for version %s: %v\n", version, err) continue } @@ -79,7 +81,7 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) { capabilityVersion, _ := strconv.Atoi(capabilityVersionStr) versions[version] = tailcfg.CapabilityVersion(capabilityVersion) } else { - fmt.Printf("Version: %s, CurrentCapabilityVersion not found\n", version) + log.Printf("Version: %s, CurrentCapabilityVersion not found\n", version) } } @@ -87,29 +89,23 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) { } func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion) error { - // Open the output file - file, err := os.Create(outputFile) - if err != nil { - return fmt.Errorf("error creating file: %w", err) - } - defer file.Close() - - // Write the package declaration and variable - file.WriteString("package capver\n\n") - file.WriteString("//Generated DO NOT EDIT\n\n") - file.WriteString(`import "tailscale.com/tailcfg"`) - file.WriteString("\n\n") - file.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n") + // Generate the Go code as a string + var content strings.Builder + content.WriteString("package capver\n\n") + content.WriteString("// Generated DO NOT EDIT\n\n") + content.WriteString(`import "tailscale.com/tailcfg"`) + content.WriteString("\n\n") + content.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n") sortedVersions := xmaps.Keys(versions) sort.Strings(sortedVersions) for _, version := range sortedVersions { - fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version]) + fmt.Fprintf(&content, "\t\"%s\": %d,\n", version, versions[version]) } - file.WriteString("}\n") + content.WriteString("}\n") - file.WriteString("\n\n") - file.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n") + content.WriteString("\n\n") + content.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n") capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string) for _, v := range sortedVersions { @@ -129,9 +125,21 @@ func 
writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion return capsSorted[i] < capsSorted[j] }) for _, capVer := range capsSorted { - fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]) + fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]) + } + content.WriteString("}\n") + + // Format the generated code + formatted, err := format.Source([]byte(content.String())) + if err != nil { + return fmt.Errorf("error formatting Go code: %w", err) + } + + // Write to file + err = os.WriteFile(outputFile, formatted, 0644) + if err != nil { + return fmt.Errorf("error writing file: %w", err) } - file.WriteString("}\n") return nil } @@ -139,15 +147,15 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion func main() { versions, err := getCapabilityVersions() if err != nil { - fmt.Println("Error:", err) + log.Println("Error:", err) return } err = writeCapabilityVersionsToFile(versions) if err != nil { - fmt.Println("Error writing to file:", err) + log.Println("Error writing to file:", err) return } - fmt.Println("Capability versions written to", outputFile) + log.Println("Capability versions written to", outputFile) }
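
The rewritten capver generator above builds the Go source in memory, runs it through gofmt via go/format.Source, and only then writes it to disk, so malformed generated code fails fast instead of landing in the repository. A minimal, standalone sketch of that generate-then-format pattern follows for reference; the file name, map name, and sample data are illustrative only and are not part of the capver tool.

package main

import (
	"fmt"
	"go/format"
	"log"
	"os"
	"sort"
	"strings"
)

func main() {
	// Example data only; stands in for the versions map the tool builds.
	versions := map[string]int{"1.80.0": 106, "1.82.0": 113}

	// Build the generated file in memory first.
	var b strings.Builder
	b.WriteString("package capver\n\n// Generated DO NOT EDIT\n\n")
	b.WriteString("var exampleToCapVer = map[string]int{\n")

	keys := make([]string, 0, len(versions))
	for k := range versions {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		fmt.Fprintf(&b, "\t%q: %d,\n", k, versions[k])
	}
	b.WriteString("}\n")

	// format.Source rejects anything that does not parse as Go,
	// catching template mistakes before the file is written.
	formatted, err := format.Source([]byte(b.String()))
	if err != nil {
		log.Fatalf("formatting generated code: %v", err)
	}

	if err := os.WriteFile("capver_generated_example.go", formatted, 0o644); err != nil {
		log.Fatalf("writing generated file: %v", err)
	}
}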