
lint and leftover

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-09-05 16:32:46 +02:00 committed by Kristoffer Dalby
parent 39443184d6
commit 233dffc186
34 changed files with 1429 additions and 506 deletions


@@ -16,15 +16,13 @@ body:
   - type: textarea
     attributes:
       label: Description
-      description:
-        A clear and precise description of what new or changed feature you want.
+      description: A clear and precise description of what new or changed feature you want.
     validations:
       required: true
   - type: checkboxes
     attributes:
       label: Contribution
-      description:
-        Are you willing to contribute to the implementation of this feature?
+      description: Are you willing to contribute to the implementation of this feature?
       options:
         - label: I can write the design doc for this feature
           required: false
@@ -33,7 +31,6 @@ body:
   - type: textarea
     attributes:
       label: How can it be implemented?
-      description:
-        Free text for your ideas on how this feature could be implemented.
+      description: Free text for your ideas on how this feature could be implemented.
     validations:
       required: false


@@ -146,12 +146,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		policyChanged, err := app.state.DeleteNode(node)
 		if err != nil {
-			log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
+			log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
 			return
 		}
 
 		app.Change(policyChanged)
-		log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
+		log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
 	})
 
 	app.ephemeralGC = ephemeralGC
@ -382,19 +382,20 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
) { ) {
if err := func() error { if err := func() error {
log.Trace(). log.Trace().
Caller().
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")
authHeader := req.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller(). Caller().
Str("client_address", req.RemoteAddr). Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
writer.WriteHeader(http.StatusUnauthorized) writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized")) _, err := writer.Write([]byte("Unauthorized"))
if err != nil { return err
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
} }
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
@ -407,14 +408,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Unauthorized")) _, err := writer.Write([]byte("Unauthorized"))
if err != nil { return err
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
} }
if !valid { if !valid {
@ -424,13 +418,15 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writer.WriteHeader(http.StatusUnauthorized) writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized")) _, err := writer.Write([]byte("Unauthorized"))
if err != nil { return err
}
return nil
}(); err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to write response") Msg("Failed to write HTTP response")
}
return return
} }


@@ -260,7 +260,7 @@ func NewHeadscaleDatabase(
 					log.Error().Err(err).Msg("Error creating route")
 				} else {
 					log.Info().
-						Uint64("node_id", route.NodeID).
+						Uint64("node.id", route.NodeID).
 						Str("prefix", prefix.String()).
 						Msg("Route migrated")
 				}
@@ -1131,7 +1131,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
 	}
 
 	for _, migrationID := range migrationIDs {
-		log.Trace().Str("migration_id", migrationID).Msg("Running migration")
+		log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
 		needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
 
 		if needsFKDisabled {


@@ -275,7 +275,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
 			return errors.New("backfilling IPs: ip allocator was nil")
 		}
 
-		log.Trace().Msgf("starting to backfill IPs")
+		log.Trace().Caller().Msgf("starting to backfill IPs")
 
 		nodes, err := ListNodes(tx)
 		if err != nil {
@@ -283,7 +283,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
 		}
 
 		for _, node := range nodes {
-			log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
+			log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
 
 			changed := false
 			// IPv4 prefix is set, but node ip is missing, alloc


@@ -34,9 +34,6 @@ var (
 		"node not found in registration cache",
 	)
 	ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
-	ErrDifferentRegisteredUser      = errors.New(
-		"node was previously registered with a different user",
-	)
 )
 
 // ListPeers returns peers of node, regardless of any Policy or if the node is expired.


@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/arl/statsviz" "github.com/arl/statsviz"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"tailscale.com/tsweb" "tailscale.com/tsweb"
@ -239,6 +240,34 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(resJSON) w.Write(resJSON)
})) }))
// Batcher endpoint
debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
batcherInfo := h.debugBatcherJSON()
batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(batcherJSON)
} else {
// Default to text/plain for backward compatibility
batcherInfo := h.debugBatcher()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(batcherInfo))
}
}))
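The handler above picks its output format from the request's Accept header: application/json yields the structured DebugBatcherInfo, anything else falls back to the plain-text listing. A minimal client sketch follows; the /debug/batcher path (tsweb debug handlers are mounted under /debug/) and the 127.0.0.1:9090 address are assumptions for illustration, not something this diff pins down.

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed address and path of the headscale debug listener; adjust to your config.
	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:9090/debug/batcher", nil)
	if err != nil {
		panic(err)
	}
	// Ask for the JSON variant; drop this header to get the plain-text table instead.
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Header.Get("Content-Type"))
	fmt.Println(string(body))
}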
err := statsviz.Register(debugMux) err := statsviz.Register(debugMux)
if err == nil { if err == nil {
debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)") debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)")
@ -256,3 +285,124 @@ func (h *Headscale) debugHTTPServer() *http.Server {
return debugHTTPServer return debugHTTPServer
} }
// debugBatcher returns debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcher() string {
var sb strings.Builder
sb.WriteString("=== Batcher Connected Nodes ===\n\n")
totalNodes := 0
connectedCount := 0
// Collect nodes and sort them by ID
type nodeStatus struct {
id types.NodeID
connected bool
activeConnections int
}
var nodes []nodeStatus
// Try to get detailed debug info if we have a LockFreeBatcher
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
debugInfo := batcher.Debug()
for nodeID, info := range debugInfo {
nodes = append(nodes, nodeStatus{
id: nodeID,
connected: info.Connected,
activeConnections: info.ActiveConnections,
})
totalNodes++
if info.Connected {
connectedCount++
}
}
} else {
// Fallback to basic connection info
connectedMap := h.mapBatcher.ConnectedMap()
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
nodes = append(nodes, nodeStatus{
id: nodeID,
connected: connected,
activeConnections: 0,
})
totalNodes++
if connected {
connectedCount++
}
return true
})
}
// Sort by node ID
for i := 0; i < len(nodes); i++ {
for j := i + 1; j < len(nodes); j++ {
if nodes[i].id > nodes[j].id {
nodes[i], nodes[j] = nodes[j], nodes[i]
}
}
}
// Output sorted nodes
for _, node := range nodes {
status := "disconnected"
if node.connected {
status = "connected"
}
if node.activeConnections > 0 {
sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections))
} else {
sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status))
}
}
sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes))
return sb.String()
}
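The ordering above is done with a hand-rolled swap loop over the small nodeStatus slice. For reference only (not part of this commit), the same ordering can be expressed with sort.Slice from the standard library; the sketch below uses a plain uint64 in place of types.NodeID so it stays self-contained.

package main

import (
	"fmt"
	"sort"
)

// nodeStatus mirrors the helper struct used in debugBatcher above.
type nodeStatus struct {
	id        uint64
	connected bool
}

func main() {
	nodes := []nodeStatus{{id: 3}, {id: 1, connected: true}, {id: 2}}

	// Same ascending-by-ID order as the nested swap loop in debugBatcher.
	sort.Slice(nodes, func(i, j int) bool {
		return nodes[i].id < nodes[j].id
	})

	fmt.Println(nodes) // [{1 true} {2 false} {3 false}]
}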
// DebugBatcherInfo represents batcher connection information in a structured format.
type DebugBatcherInfo struct {
ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"` // NodeID -> node connection info
TotalNodes int `json:"total_nodes"`
}
// DebugBatcherNodeInfo represents connection information for a single node.
type DebugBatcherNodeInfo struct {
Connected bool `json:"connected"`
ActiveConnections int `json:"active_connections"`
}
// debugBatcherJSON returns structured debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
info := DebugBatcherInfo{
ConnectedNodes: make(map[string]DebugBatcherNodeInfo),
TotalNodes: 0,
}
// Try to get detailed debug info if we have a LockFreeBatcher
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
debugInfo := batcher.Debug()
for nodeID, debugData := range debugInfo {
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
Connected: debugData.Connected,
ActiveConnections: debugData.ActiveConnections,
}
info.TotalNodes++
}
} else {
// Fallback to basic connection info
connectedMap := h.mapBatcher.ConnectedMap()
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
Connected: connected,
ActiveConnections: 0,
}
info.TotalNodes++
return true
})
}
return info
}
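For reference, the JSON shape that debugBatcherJSON produces follows directly from the struct tags above. The sketch below uses local copies of the two structs and a made-up node ID purely to show the encoded output; the real types live alongside the handler.

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the debug structs, only to illustrate the encoded shape.
type DebugBatcherNodeInfo struct {
	Connected         bool `json:"connected"`
	ActiveConnections int  `json:"active_connections"`
}

type DebugBatcherInfo struct {
	ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"`
	TotalNodes     int                             `json:"total_nodes"`
}

func main() {
	info := DebugBatcherInfo{
		ConnectedNodes: map[string]DebugBatcherNodeInfo{
			"1": {Connected: true, ActiveConnections: 2},
		},
		TotalNodes: 1,
	}

	out, _ := json.MarshalIndent(info, "", "  ")
	fmt.Println(string(out))
	// Output:
	// {
	//   "connected_nodes": {
	//     "1": {
	//       "connected": true,
	//       "active_connections": 2
	//     }
	//   },
	//   "total_nodes": 1
	// }
}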


@@ -161,7 +161,7 @@ func (d *DERPServer) DERPHandler(
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 
 		return
@@ -199,7 +199,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 
 		return
@@ -229,7 +229,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 
 		return
@@ -245,7 +245,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 
 		return
@@ -284,7 +284,7 @@ func DERPProbeHandler(
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 	}
 }
@@ -330,7 +330,7 @@ func DERPBootstrapDNSHandler(
 			log.Error().
 				Caller().
 				Err(err).
-				Msg("Failed to write response")
+				Msg("Failed to write HTTP response")
 		}
 	}
 }


@@ -237,6 +237,7 @@ func (api headscaleV1APIServer) RegisterNode(
 	request *v1.RegisterNodeRequest,
 ) (*v1.RegisterNodeResponse, error) {
 	log.Trace().
+		Caller().
 		Str("user", request.GetUser()).
 		Str("registration_id", request.GetKey()).
 		Msg("Registering node")
@@ -525,7 +526,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
 	ctx context.Context,
 	request *v1.BackfillNodeIPsRequest,
 ) (*v1.BackfillNodeIPsResponse, error) {
-	log.Trace().Msg("Backfill called")
+	log.Trace().Caller().Msg("Backfill called")
 
 	if !request.Confirmed {
 		return nil, errors.New("not confirmed, aborting")
@@ -709,6 +710,10 @@ func (api headscaleV1APIServer) SetPolicy(
 		UpdatedAt: timestamppb.New(updated.UpdatedAt),
 	}
 
+	log.Debug().
+		Caller().
+		Msg("gRPC SetPolicy completed successfully because response prepared")
+
 	return response, nil
 }
@@ -731,7 +736,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
 		Caller().
 		Interface("route-prefix", routes).
 		Interface("route-str", request.GetRoutes()).
-		Msg("")
+		Msg("Creating routes for node")
 
 	hostinfo := tailcfg.Hostinfo{
 		RoutableIPs: routes,
@@ -760,6 +765,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
 	}
 
 	log.Debug().
+		Caller().
 		Str("registration_id", registrationId.String()).
 		Msg("adding debug machine via CLI, appending to registration cache")


@@ -197,7 +197,7 @@ func (h *Headscale) RobotsHandler(
 		log.Error().
 			Caller().
 			Err(err).
-			Msg("Failed to write response")
+			Msg("Failed to write HTTP response")
 	}
 }


@@ -9,6 +9,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/puzpuzpuz/xsync/v4"
+	"github.com/rs/zerolog/log"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/ptr"
 )
@@ -23,7 +24,7 @@ type Batcher interface {
 	RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
 	IsConnected(id types.NodeID) bool
 	ConnectedMap() *xsync.Map[types.NodeID, bool]
-	AddWork(c change.ChangeSet)
+	AddWork(c ...change.ChangeSet)
 	MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
 	DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
 }
@@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
 		// The size of this channel is arbitrary chosen, the sizing should be revisited.
 		workCh:         make(chan work, workers*200),
-		nodes:          xsync.NewMap[types.NodeID, *nodeConn](),
+		nodes:          xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
 		connected:      xsync.NewMap[types.NodeID, *time.Time](),
 		pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
 	}
@@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
 	m := newMapper(cfg, state)
 	b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
 	m.batcher = b
+
 	return b
 }
@@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
 		return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
 	}
 
-	var mapResp *tailcfg.MapResponse
-	var err error
+	var (
+		mapResp *tailcfg.MapResponse
+		err     error
+	)
 
 	switch c.Change {
 	case change.DERP:
@@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
 			// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
 			mapResp, err = mapper.fullMapResponse(nodeID, version)
 		} else {
+			// CRITICAL FIX: Read actual online status from NodeStore when available,
+			// fall back to deriving from change type for unit tests or when NodeStore is empty
+			var onlineStatus bool
+			if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
+				// Use actual NodeStore status when available (production case)
+				onlineStatus = node.IsOnline().Get()
+			} else {
+				// Fall back to deriving from change type (unit test case or initial setup)
+				onlineStatus = c.Change == change.NodeCameOnline
+			}
+
 			mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
 				{
 					NodeID: c.NodeID.NodeID(),
-					Online: ptr.To(c.Change == change.NodeCameOnline),
+					Online: ptr.To(onlineStatus),
 				},
 			})
 		}
@@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
 	}
 
 	nodeID := nc.nodeID()
-	data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
+
+	log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
+
+	var data *tailcfg.MapResponse
+	var err error
+	data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
 	if err != nil {
 		return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
 	}
@@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
 	}
 
 	// Send the map response
-	if err := nc.send(data); err != nil {
+	err = nc.send(data)
+	if err != nil {
 		return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
 	}

@ -2,6 +2,7 @@ package mapper
import ( import (
"context" "context"
"crypto/rand"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
version: version, version: version,
created: now, created: now,
} }
// Initialize last used timestamp
newEntry.lastUsed.Store(now.Unix())
// Only after validation succeeds, create or update node connection // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
newConn := newNodeConn(id, c, version, b.mapper) nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
if !loaded { if !loaded {
b.totalNodes.Add(1) b.totalNodes.Add(1)
conn = newConn
} }
b.connected.Store(id, nil) // nil = connected // Add connection to the list (lock-free)
nodeConn.addConnection(newEntry)
// Use the worker pool for controlled concurrency instead of direct generation
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
if err != nil { if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
return fmt.Errorf("failed to send initial map to node %d: timeout", id) return fmt.Errorf("failed to send initial map to node %d: timeout", id)
} }
// Update connection status
b.connected.Store(id, nil) // nil = connected
// Node will automatically receive updates through the normal flow
// The initial full map already contains all current state
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection established in batcher because AddNode completed successfully")
return nil return nil
} }
@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return false return false
} }
// Mark the connection as closed to prevent further sends // Remove specific connection
if connData := existing.connData.Load(); connData != nil { removed := nodeConn.removeConnectionByChannel(c)
connData.closed.Store(true) if !removed {
} log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
return false
} }
// Check if node has any remaining active connections // Check if node has any remaining active connections
@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return true // Node still has active connections return true // Node still has active connections
} }
// Remove node and mark disconnected atomically // No active connections - keep the node entry alive for rapid reconnections
b.nodes.Delete(id) // The node will get a fresh full map when it reconnects
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
b.connected.Store(id, ptr.To(time.Now())) b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)
return false return false
} }
// AddWork queues a change to be processed by the batcher. // AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency. func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) { b.addWork(c...)
b.addWork(c)
} }
func (b *LockFreeBatcher) Start() { func (b *LockFreeBatcher) Start() {
@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() {
func (b *LockFreeBatcher) Close() { func (b *LockFreeBatcher) Close() {
if b.cancel != nil { if b.cancel != nil {
b.cancel() b.cancel()
b.cancel = nil // Prevent multiple calls
} }
// Only close workCh once
select {
case <-b.workCh:
// Channel is already closed
default:
close(b.workCh) close(b.workCh)
}
} }
func (b *LockFreeBatcher) doWork() { func (b *LockFreeBatcher) doWork() {
log.Debug().Msg("batcher doWork loop started")
defer log.Debug().Msg("batcher doWork loop stopped")
for i := range b.workers { for i := range b.workers {
go b.worker(i + 1) go b.worker(i + 1)
} }
// Create a cleanup ticker for removing truly disconnected nodes
cleanupTicker := time.NewTicker(5 * time.Minute)
defer cleanupTicker.Stop()
for { for {
select { select {
case <-b.tick.C: case <-b.tick.C:
// Process batched changes // Process batched changes
b.processBatchedChanges() b.processBatchedChanges()
case <-cleanupTicker.C:
// Clean up nodes that have been offline for too long
b.cleanupOfflineNodes()
case <-b.ctx.Done(): case <-b.ctx.Done():
return return
} }
@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() {
} }
func (b *LockFreeBatcher) worker(workerID int) { func (b *LockFreeBatcher) worker(workerID int) {
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
for { for {
select { select {
@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
return return
} }
startTime := time.Now()
b.workProcessed.Add(1) b.workProcessed.Add(1)
// If the resultCh is set, it means that this is a work request // If the resultCh is set, it means that this is a work request
@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) {
if w.resultCh != nil { if w.resultCh != nil {
var result workResult var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists { if nc, exists := b.nodes.Load(w.nodeID); exists {
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c) var err error
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
result.err = err
if result.err != nil { if result.err != nil {
b.workErrors.Add(1) b.workErrors.Add(1)
log.Error().Err(result.err). log.Error().Err(result.err).
@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
} }
} else { } else {
result.err = fmt.Errorf("node %d not found", w.nodeID) result.err = fmt.Errorf("node %d not found", w.nodeID)
b.workErrors.Add(1) b.workErrors.Add(1)
log.Error().Err(result.err). log.Error().Err(result.err).
Int("workerID", workerID). Int("workerID", workerID).
@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
}) })
return return
} }
}
all, self := change.SplitAllAndSelf(c)
for _, changeSet := range self {
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
changes = append(changes, changeSet)
b.pendingChanges.Store(changeSet.NodeID, changes)
return return
} }
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() { rel := change.RemoveUpdatesForSelf(nodeID, all)
return true
}
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c) changes = append(changes, rel...)
b.pendingChanges.Store(nodeID, changes) b.pendingChanges.Store(nodeID, changes)
return true return true
@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() {
}) })
} }
// IsConnected is lock-free read. // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
func (b *LockFreeBatcher) cleanupOfflineNodes() {
cleanupThreshold := 15 * time.Minute
now := time.Now()
var nodesToCleanup []types.NodeID
// Find nodes that have been offline for too long
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
// Double-check the node doesn't have active connections
if nodeConn, exists := b.nodes.Load(nodeID); exists {
if !nodeConn.hasActiveConnections() {
nodesToCleanup = append(nodesToCleanup, nodeID)
}
}
}
return true
})
// Clean up the identified nodes
for _, nodeID := range nodesToCleanup {
log.Info().Uint64("node.id", nodeID.Uint64()).
Dur("offline_duration", cleanupThreshold).
Msg("Cleaning up node that has been offline for too long")
b.nodes.Delete(nodeID)
b.connected.Delete(nodeID)
b.totalNodes.Add(-1)
}
if len(nodesToCleanup) > 0 {
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
Msg("Completed cleanup of long-offline nodes")
}
}
// IsConnected is lock-free read that checks if a node has any active connections.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
// First check if we have active connections for this node // First check if we have active connections for this node
if nodeConn, exists := b.nodes.Load(id); exists { if nodeConn, exists := b.nodes.Load(id); exists {
@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
} }
} }
// connectionData holds the channel and connection parameters. // connectionEntry represents a single connection to a node.
type connectionData struct { type connectionEntry struct {
id string // unique connection ID
c chan<- *tailcfg.MapResponse c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion version tailcfg.CapabilityVersion
closed atomic.Bool // Track if this connection has been closed created time.Time
lastUsed atomic.Int64 // Unix timestamp of last successful send
} }
// nodeConn described the node connection and its associated data. // multiChannelNodeConn manages multiple concurrent connections for a single node.
type nodeConn struct { type multiChannelNodeConn struct {
id types.NodeID id types.NodeID
mapper *mapper mapper *mapper
// Atomic pointer to connection data - allows lock-free updates mutex sync.RWMutex
connData atomic.Pointer[connectionData] connections []*connectionEntry
updateCount atomic.Int64 updateCount atomic.Int64
} }
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn { // generateConnectionID generates a unique connection identifier.
nc := &nodeConn{ func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
}
// newMultiChannelNodeConn creates a new multi-channel node connection.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
return &multiChannelNodeConn{
id: id, id: id,
mapper: mapper, mapper: mapper,
} }
// Initialize connection data
data := &connectionData{
c: c,
version: version,
}
nc.connData.Store(data)
return nc
} }
// updateConnection atomically updates connection parameters. // addConnection adds a new connection.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) { func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
newData := &connectionData{ mutexWaitStart := time.Now()
c: c, log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
version: version, Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
}
nc.connData.Store(newData) mc.mutex.Lock()
mutexWaitDur := time.Since(mutexWaitStart)
defer mc.mutex.Unlock()
mc.connections = append(mc.connections, entry)
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
Int("total_connections", len(mc.connections)).
Dur("mutex_wait_time", mutexWaitDur).
Msg("Successfully added connection after mutex wait")
} }
// matchesChannel checks if the given channel matches current connection. // removeConnectionByChannel removes a connection by matching channel pointer.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool { func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
data := nc.connData.Load() mc.mutex.Lock()
if data == nil { defer mc.mutex.Unlock()
for i, entry := range mc.connections {
if entry.c == c {
// Remove this connection
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
Int("remaining_connections", len(mc.connections)).
Msg("Successfully removed connection")
return true
}
}
return false return false
}
// Compare channel pointers directly
return data.c == c
} }
// compressAndVersion atomically reads connection settings. // hasActiveConnections checks if the node has any active connections.
func (nc *nodeConn) version() tailcfg.CapabilityVersion { func (mc *multiChannelNodeConn) hasActiveConnections() bool {
data := nc.connData.Load() mc.mutex.RLock()
if data == nil { defer mc.mutex.RUnlock()
return len(mc.connections) > 0
}
// getActiveConnectionCount returns the number of active connections.
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
mc.mutex.RLock()
defer mc.mutex.RUnlock()
return len(mc.connections)
}
// send broadcasts data to all active connections for the node.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
mc.mutex.Lock()
defer mc.mutex.Unlock()
if len(mc.connections) == 0 {
// During rapid reconnection, nodes may temporarily have no active connections
// This is not an error - the node will receive a full map when it reconnects
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
}
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Int("total_connections", len(mc.connections)).
Msg("send: broadcasting to all connections")
var lastErr error
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
for i, conn := range mc.connections {
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: attempting to send to connection")
if err := conn.send(data); err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
log.Warn().Err(err).
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: connection send failed")
} else {
successCount++
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: successfully sent to connection")
}
}
// Remove failed connections (in reverse order to maintain indices)
for i := len(failedConnections) - 1; i >= 0; i-- {
idx := failedConnections[i]
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Str("conn.id", mc.connections[idx].id).
Msg("send: removing failed connection")
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
}
mc.updateCount.Add(1)
log.Info().Uint64("node.id", mc.id.Uint64()).
Int("successful_sends", successCount).
Int("failed_connections", len(failedConnections)).
Int("remaining_connections", len(mc.connections)).
Msg("send: completed broadcast")
// Success if at least one send succeeded
if successCount > 0 {
return nil
}
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
}
// send sends data to a single connection entry with timeout-based stale connection detection.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
// Use a short timeout to detect stale connections where the client isn't reading the channel.
// This is critical for detecting Docker containers that are forcefully terminated
// but still have channels that appear open.
select {
case entry.c <- data:
// Update last used timestamp on successful send
entry.lastUsed.Store(time.Now().Unix())
return nil
case <-time.After(50 * time.Millisecond):
// Connection is likely stale - client isn't reading from channel
// This catches the case where Docker containers are killed but channels remain open
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
}
}
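The 50 ms send timeout above only works if the receiving side keeps draining its channel; a reader that stalls (for example a client whose container was killed) leaves the buffered channel full and is dropped as stale on the next broadcast. The sketch below is not the actual headscale poll handler, just an illustration of the consumer loop such a channel expects.

package main

import (
	"context"

	"tailscale.com/tailcfg"
)

// consume drains a node's map-response channel until it is closed or the
// context ends. If this loop stops receiving, sends from the batcher start
// hitting the 50 ms timeout and the connection is treated as stale.
func consume(ctx context.Context, ch <-chan *tailcfg.MapResponse) {
	for {
		select {
		case resp, ok := <-ch:
			if !ok {
				return
			}
			_ = resp // the real handler would stream this to the client
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ch := make(chan *tailcfg.MapResponse, 1)
	ctx, cancel := context.WithCancel(context.Background())

	go consume(ctx, ch)
	ch <- &tailcfg.MapResponse{}
	cancel()
}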
// nodeID returns the node ID.
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
return mc.id
}
// version returns the capability version from the first active connection.
// All connections for a node should have the same version in practice.
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
mc.mutex.RLock()
defer mc.mutex.RUnlock()
if len(mc.connections) == 0 {
return 0 return 0
} }
return data.version return mc.connections[0].version
} }
func (nc *nodeConn) nodeID() types.NodeID { // change applies a change to all active connections for the node.
return nc.id func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
return handleNodeChange(mc, mc.mapper, c)
} }
func (nc *nodeConn) change(c change.ChangeSet) error { // DebugNodeInfo contains debug information about a node's connections.
return handleNodeChange(nc, nc.mapper, c) type DebugNodeInfo struct {
Connected bool `json:"connected"`
ActiveConnections int `json:"active_connections"`
} }
// send sends data to the node's channel. // Debug returns a pre-baked map of node debug information for the debug interface.
// The node will pick it up and send it to the HTTP handler. func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
func (nc *nodeConn) send(data *tailcfg.MapResponse) error { result := make(map[types.NodeID]DebugNodeInfo)
connData := nc.connData.Load()
if connData == nil { // Get all nodes with their connection status using immediate connection logic
return fmt.Errorf("node %d: no connection data", nc.id) // (no grace period) for debug purposes
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
nodeConn.mutex.RLock()
activeConnCount := len(nodeConn.connections)
nodeConn.mutex.RUnlock()
// Use immediate connection status: if active connections exist, node is connected
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
connected := false
if activeConnCount > 0 {
connected = true
} else {
// Check connected map for immediate status
if val, ok := b.connected.Load(id); ok && val == nil {
connected = true
}
} }
// Check if connection has been closed result[id] = DebugNodeInfo{
if connData.closed.Load() { Connected: connected,
return fmt.Errorf("node %d: connection closed", nc.id) ActiveConnections: activeConnCount,
} }
return true
})
// Add all entries from the connected map to capture both connected and disconnected nodes // Add all entries from the connected map to capture both connected and disconnected nodes
b.connected.Range(func(id types.NodeID, val *time.Time) bool { b.connected.Range(func(id types.NodeID, val *time.Time) bool {


@ -209,6 +209,7 @@ func setupBatcherWithTestData(
// Create test users and nodes in the database // Create test users and nodes in the database
users := database.CreateUsersForTest(userCount, "testuser") users := database.CreateUsersForTest(userCount, "testuser")
allNodes := make([]node, 0, userCount*nodesPerUser) allNodes := make([]node, 0, userCount*nodesPerUser)
for _, user := range users { for _, user := range users {
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b
if len(resp.PeersChangedPatch) > 0 { if len(resp.PeersChangedPatch) > 0 {
require.Len(t, resp.PeersChangedPatch, 1) require.Len(t, resp.PeersChangedPatch, 1)
assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online) assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
return return
} }
@ -412,6 +414,7 @@ func (n *node) start() {
n.maxPeersCount = info.PeerCount n.maxPeersCount = info.PeerCount
} }
} }
if info.IsPatch { if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1) atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items // For patches, we track how many patch items
@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Reduce verbose application logging for cleaner test output // Reduce verbose application logging for cleaner test output
originalLevel := zerolog.GlobalLevel() originalLevel := zerolog.GlobalLevel()
defer zerolog.SetGlobalLevel(originalLevel) defer zerolog.SetGlobalLevel(originalLevel)
zerolog.SetGlobalLevel(zerolog.ErrorLevel) zerolog.SetGlobalLevel(zerolog.ErrorLevel)
// Test cases: different node counts to stress test the all-to-all connectivity // Test cases: different node counts to stress test the all-to-all connectivity
@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Join all nodes as fast as possible // Join all nodes as fast as possible
t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
if stats.MaxPeersSeen > maxPeersGlobal { if stats.MaxPeersSeen > maxPeersGlobal {
maxPeersGlobal = stats.MaxPeersSeen maxPeersGlobal = stats.MaxPeersSeen
} }
if stats.MaxPeersSeen < minPeersSeen { if stats.MaxPeersSeen < minPeersSeen {
minPeersSeen = stats.MaxPeersSeen minPeersSeen = stats.MaxPeersSeen
} }
@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Show sample of node details // Show sample of node details
if len(nodeDetails) > 0 { if len(nodeDetails) > 0 {
t.Logf(" Node sample:") t.Logf(" Node sample:")
for _, detail := range nodeDetails[:min(5, len(nodeDetails))] { for _, detail := range nodeDetails[:min(5, len(nodeDetails))] {
t.Logf(" %s", detail) t.Logf(" %s", detail)
} }
if len(nodeDetails) > 5 { if len(nodeDetails) > 5 {
t.Logf(" ... (%d more nodes)", len(nodeDetails)-5) t.Logf(" ... (%d more nodes)", len(nodeDetails)-5)
} }
@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Show details of failed nodes for debugging // Show details of failed nodes for debugging
if len(nodeDetails) > 5 { if len(nodeDetails) > 5 {
t.Logf("Failed nodes details:") t.Logf("Failed nodes details:")
for _, detail := range nodeDetails[5:] { for _, detail := range nodeDetails[5:] {
if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) { if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) {
t.Logf(" %s", detail) t.Logf(" %s", detail)
@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) {
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
count := 0 count := 0
timer := time.NewTimer(timeout) timer := time.NewTimer(timeout)
defer timer.Stop() defer timer.Stop()
@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// Collect updates with timeout // Collect updates with timeout
updateCount := 0 updateCount := 0
timeout := time.After(200 * time.Millisecond) timeout := time.After(200 * time.Millisecond)
for { for {
select { select {
case data := <-ch: case data := <-ch:
updateCount++ updateCount++
receivedUpdates = append(receivedUpdates, data) receivedUpdates = append(receivedUpdates, data)
// Validate update content // Validate update content
@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// Validate that all updates have valid content // Validate that all updates have valid content
validUpdates := 0 validUpdates := 0
for _, data := range receivedUpdates { for _, data := range receivedUpdates {
if data != nil { if data != nil {
if valid, _ := validateUpdateContent(data); valid { if valid, _ := validateUpdateContent(data); valid {
@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
batcher := testData.Batcher batcher := testData.Batcher
testNode := testData.Nodes[0] testNode := testData.Nodes[0]
var channelIssues int
var mutex sync.Mutex var (
channelIssues int
mutex sync.Mutex
)
// Run rapid connect/disconnect cycles with real updates to test channel closing // Run rapid connect/disconnect cycles with real updates to test channel closing
for i := range 100 { for i := range 100 {
var wg sync.WaitGroup var wg sync.WaitGroup
// First connection // First connection
ch1 := make(chan *tailcfg.MapResponse, 1) ch1 := make(chan *tailcfg.MapResponse, 1)
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
// Rapid second connection - should replace ch1 // Rapid second connection - should replace ch1
ch2 := make(chan *tailcfg.MapResponse, 1) ch2 := make(chan *tailcfg.MapResponse, 1)
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
time.Sleep(1 * time.Microsecond) time.Sleep(1 * time.Microsecond)
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
}() }()
// Remove second connection // Remove second connection
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
time.Sleep(2 * time.Microsecond) time.Sleep(2 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch2) batcher.RemoveNode(testNode.n.ID, ch2)
}() }()
@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
case <-time.After(1 * time.Millisecond): case <-time.After(1 * time.Millisecond):
// If no data received, increment issues counter // If no data received, increment issues counter
mutex.Lock() mutex.Lock()
channelIssues++ channelIssues++
mutex.Unlock() mutex.Unlock()
} }
@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
batcher := testData.Batcher batcher := testData.Batcher
testNode := testData.Nodes[0] testNode := testData.Nodes[0]
var panics int
var channelErrors int var (
var invalidData int panics int
var mutex sync.Mutex channelErrors int
invalidData int
mutex sync.Mutex
)
// Test rapid connect/disconnect with work generation // Test rapid connect/disconnect with work generation
for i := range 50 { for i := range 50 {
func() { func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
mutex.Lock() mutex.Lock()
panics++ panics++
mutex.Unlock() mutex.Unlock()
t.Logf("Panic caught: %v", r) t.Logf("Panic caught: %v", r)
} }
@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
mutex.Lock() mutex.Lock()
channelErrors++ channelErrors++
mutex.Unlock() mutex.Unlock()
t.Logf("Channel consumer panic: %v", r) t.Logf("Channel consumer panic: %v", r)
} }
@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Validate the data we received // Validate the data we received
if valid, reason := validateUpdateContent(data); !valid { if valid, reason := validateUpdateContent(data); !valid {
mutex.Lock() mutex.Lock()
invalidData++ invalidData++
mutex.Unlock() mutex.Unlock()
t.Logf("Invalid data received: %s", reason) t.Logf("Invalid data received: %s", reason)
} }
@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
if panics > 0 { if panics > 0 {
t.Errorf("Worker channel safety failed with %d panics", panics) t.Errorf("Worker channel safety failed with %d panics", panics)
} }
if channelErrors > 0 { if channelErrors > 0 {
t.Errorf("Channel handling failed with %d channel errors", channelErrors) t.Errorf("Channel handling failed with %d channel errors", channelErrors)
} }
if invalidData > 0 { if invalidData > 0 {
t.Errorf("Data validation failed with %d invalid data packets", invalidData) t.Errorf("Data validation failed with %d invalid data packets", invalidData)
} }
@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Use remaining nodes for connection churn testing // Use remaining nodes for connection churn testing
churningNodes := allNodes[len(allNodes)/2:] churningNodes := allNodes[len(allNodes)/2:]
churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
var churningChannelsMutex sync.Mutex // Protect concurrent map access var churningChannelsMutex sync.Mutex // Protect concurrent map access
var wg sync.WaitGroup var wg sync.WaitGroup
numCycles := 10 // Reduced for simpler test numCycles := 10 // Reduced for simpler test
panicCount := 0 panicCount := 0
var panicMutex sync.Mutex var panicMutex sync.Mutex
// Track deadlock with timeout // Track deadlock with timeout
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
panicMutex.Lock() panicMutex.Lock()
panicCount++ panicCount++
panicMutex.Unlock() panicMutex.Unlock()
t.Logf("Panic in churning connect: %v", r) t.Logf("Panic in churning connect: %v", r)
} }
wg.Done() wg.Done()
}() }()
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock() churningChannelsMutex.Lock()
churningChannels[nodeID] = ch churningChannels[nodeID] = ch
churningChannelsMutex.Unlock() churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
panicMutex.Lock() panicMutex.Lock()
panicCount++ panicCount++
panicMutex.Unlock() panicMutex.Unlock()
t.Logf("Panic in churning disconnect: %v", r) t.Logf("Panic in churning disconnect: %v", r)
} }
wg.Done() wg.Done()
}() }()
time.Sleep(time.Duration(i%5) * time.Millisecond) time.Sleep(time.Duration(i%5) * time.Millisecond)
churningChannelsMutex.Lock() churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID] ch, exists := churningChannels[nodeID]
churningChannelsMutex.Unlock() churningChannelsMutex.Unlock()
if exists { if exists {
batcher.RemoveNode(nodeID, ch) batcher.RemoveNode(nodeID, ch)
} }
@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
// DERP changes // DERP changes
batcher.AddWork(change.DERPSet) batcher.AddWork(change.DERPSet)
} }
if i%5 == 0 { if i%5 == 0 {
// Full updates using real node data // Full updates using real node data
batcher.AddWork(change.FullSet) batcher.AddWork(change.FullSet)
} }
if i%7 == 0 && len(allNodes) > 0 { if i%7 == 0 && len(allNodes) > 0 {
// Node-specific changes using real nodes // Node-specific changes using real nodes
node := allNodes[i%len(allNodes)] node := allNodes[i%len(allNodes)]
@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Validate results // Validate results
panicMutex.Lock() panicMutex.Lock()
finalPanicCount := panicCount finalPanicCount := panicCount
panicMutex.Unlock() panicMutex.Unlock()
allStats := tracker.getAllStats() allStats := tracker.getAllStats()
@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) {
// Reduce verbose application logging for cleaner test output // Reduce verbose application logging for cleaner test output
originalLevel := zerolog.GlobalLevel() originalLevel := zerolog.GlobalLevel()
defer zerolog.SetGlobalLevel(originalLevel) defer zerolog.SetGlobalLevel(originalLevel)
zerolog.SetGlobalLevel(zerolog.ErrorLevel) zerolog.SetGlobalLevel(zerolog.ErrorLevel)
// Full test matrix for scalability testing // Full test matrix for scalability testing
@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) {
batcher := testData.Batcher batcher := testData.Batcher
allNodes := testData.Nodes allNodes := testData.Nodes
t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description) t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description)
t.Logf( t.Logf(
" Cycles: %d, Buffer Size: %d, Chaos Type: %s", " Cycles: %d, Buffer Size: %d, Chaos Type: %s",
@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) {
// Connect all nodes first so they can see each other as peers // Connect all nodes first so they can see each other as peers
connectedNodes := make(map[types.NodeID]bool) connectedNodes := make(map[types.NodeID]bool)
var connectedNodesMutex sync.RWMutex var connectedNodesMutex sync.RWMutex
for i := range testNodes { for i := range testNodes {
node := &testNodes[i] node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock() connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true connectedNodes[node.n.ID] = true
connectedNodesMutex.Unlock() connectedNodesMutex.Unlock()
} }
@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) {
go func() { go func() {
defer close(done) defer close(done)
var wg sync.WaitGroup var wg sync.WaitGroup
t.Logf( t.Logf(
@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) {
// For chaos testing, only disconnect/reconnect a subset of nodes // For chaos testing, only disconnect/reconnect a subset of nodes
// This ensures some nodes stay connected to continue receiving updates // This ensures some nodes stay connected to continue receiving updates
startIdx := cycle % len(testNodes) startIdx := cycle % len(testNodes)
endIdx := startIdx + len(testNodes)/4 endIdx := startIdx + len(testNodes)/4
if endIdx > len(testNodes) { if endIdx > len(testNodes) {
endIdx = len(testNodes) endIdx = len(testNodes)
} }
if startIdx >= endIdx { if startIdx >= endIdx {
startIdx = 0 startIdx = 0
endIdx = min(len(testNodes)/4, len(testNodes)) endIdx = min(len(testNodes)/4, len(testNodes))
} }
chaosNodes := testNodes[startIdx:endIdx] chaosNodes := testNodes[startIdx:endIdx]
if len(chaosNodes) == 0 { if len(chaosNodes) == 0 {
chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos
@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) {
if r := recover(); r != nil { if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1) atomic.AddInt64(&panicCount, 1)
} }
wg.Done() wg.Done()
}() }()
connectedNodesMutex.RLock() connectedNodesMutex.RLock()
isConnected := connectedNodes[nodeID] isConnected := connectedNodes[nodeID]
connectedNodesMutex.RUnlock() connectedNodesMutex.RUnlock()
if isConnected { if isConnected {
batcher.RemoveNode(nodeID, channel) batcher.RemoveNode(nodeID, channel)
connectedNodesMutex.Lock() connectedNodesMutex.Lock()
connectedNodes[nodeID] = false connectedNodes[nodeID] = false
connectedNodesMutex.Unlock() connectedNodesMutex.Unlock()
} }
}( }(
@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) {
if r := recover(); r != nil { if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1) atomic.AddInt64(&panicCount, 1)
} }
wg.Done() wg.Done()
}() }()
@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) {
tailcfg.CapabilityVersion(100), tailcfg.CapabilityVersion(100),
) )
connectedNodesMutex.Lock() connectedNodesMutex.Lock()
connectedNodes[nodeID] = true connectedNodes[nodeID] = true
connectedNodesMutex.Unlock() connectedNodesMutex.Unlock()
// Add work to create load // Add work to create load
@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) {
updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count
for i := range updateCount { for i := range updateCount {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1) atomic.AddInt64(&panicCount, 1)
} }
wg.Done() wg.Done()
}() }()
@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) {
deadlockDetected = true deadlockDetected = true
// Collect diagnostic information // Collect diagnostic information
allStats := tracker.getAllStats() allStats := tracker.getAllStats()
totalUpdates := 0 totalUpdates := 0
for _, stats := range allStats { for _, stats := range allStats {
totalUpdates += stats.TotalUpdates totalUpdates += stats.TotalUpdates
} }
interimPanics := atomic.LoadInt64(&panicCount) interimPanics := atomic.LoadInt64(&panicCount)
t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT) t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT)
t.Logf( t.Logf(
" Progress at timeout: %d total updates, %d panics", " Progress at timeout: %d total updates, %d panics",
@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) {
stats := node.cleanup() stats := node.cleanup()
totalUpdates += stats.TotalUpdates totalUpdates += stats.TotalUpdates
totalPatches += stats.PatchUpdates totalPatches += stats.PatchUpdates
totalFull += stats.FullUpdates totalFull += stats.FullUpdates
if stats.MaxPeersSeen > maxPeersGlobal { if stats.MaxPeersSeen > maxPeersGlobal {
maxPeersGlobal = stats.MaxPeersSeen maxPeersGlobal = stats.MaxPeersSeen
@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) {
// Legacy tracker comparison (optional) // Legacy tracker comparison (optional)
allStats := tracker.getAllStats() allStats := tracker.getAllStats()
legacyTotalUpdates := 0 legacyTotalUpdates := 0
for _, stats := range allStats { for _, stats := range allStats {
legacyTotalUpdates += stats.TotalUpdates legacyTotalUpdates += stats.TotalUpdates
} }
if legacyTotalUpdates != int(totalUpdates) { if legacyTotalUpdates != int(totalUpdates) {
t.Logf( t.Logf(
"Note: Legacy tracker mismatch - legacy: %d, new: %d", "Note: Legacy tracker mismatch - legacy: %d, new: %d",
@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) {
// Validation based on expectation // Validation based on expectation
testPassed := true testPassed := true
if tc.expectBreak { if tc.expectBreak {
// For tests expected to break, we're mainly checking that we don't crash // For tests expected to break, we're mainly checking that we don't crash
if finalPanicCount > 0 { if finalPanicCount > 0 {
@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) {
// For tests expected to pass, validate proper operation // For tests expected to pass, validate proper operation
if finalPanicCount > 0 { if finalPanicCount > 0 {
t.Errorf("Scalability test failed with %d panics", finalPanicCount) t.Errorf("Scalability test failed with %d panics", finalPanicCount)
testPassed = false testPassed = false
} }
if deadlockDetected { if deadlockDetected {
t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes)) t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes))
testPassed = false testPassed = false
} }
if totalUpdates == 0 { if totalUpdates == 0 {
t.Error("No updates received - system may be completely stalled") t.Error("No updates received - system may be completely stalled")
testPassed = false testPassed = false
} }
} }
@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Read all available updates for each node // Read all available updates for each node
for i := range allNodes { for i := range allNodes {
nodeUpdates := 0 nodeUpdates := 0
t.Logf("Reading updates for node %d:", i) t.Logf("Reading updates for node %d:", i)
// Read up to 10 updates per node or until timeout/no more data // Read up to 10 updates per node or until timeout/no more data
@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
if len(data.Peers) > 0 { if len(data.Peers) > 0 {
t.Logf(" Full peer list with %d peers", len(data.Peers)) t.Logf(" Full peer list with %d peers", len(data.Peers))
for j, peer := range data.Peers[:min(3, len(data.Peers))] { for j, peer := range data.Peers[:min(3, len(data.Peers))] {
t.Logf( t.Logf(
" Peer %d: NodeID=%d, Online=%v", " Peer %d: NodeID=%d, Online=%v",
@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
) )
} }
} }
if len(data.PeersChangedPatch) > 0 { if len(data.PeersChangedPatch) > 0 {
t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch)) t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch))
for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] { for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] {
t.Logf( t.Logf(
" Patch %d: NodeID=%d, Online=%v", " Patch %d: NodeID=%d, Online=%v",
@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
} }
} }
t.Logf("Node %d received %d updates", i, nodeUpdates) t.Logf("Node %d received %d updates", i, nodeUpdates)
} }
@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
} }
} }
// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items.
func TestBatcherWorkQueueTracing(t *testing.T) {
// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID
// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected.
// This specifically tests the multi-channel batcher implementation issue.
func TestBatcherRapidReconnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10)
defer cleanup()
batcher := testData.Batcher
allNodes := testData.Nodes
t.Logf("=== RAPID RECONNECTION TEST ===")
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node %d: %v", i, err)
}
}
time.Sleep(100 * time.Millisecond) // Let connections settle
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
}
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
}
}
time.Sleep(100 * time.Millisecond) // Let reconnections settle
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
t.Logf("Phase 4: Checking debug status...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
disconnectedCount := 0
for i, node := range allNodes {
if info, exists := debugInfo[node.n.ID]; exists {
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
// Check if the debug info shows the node as connected
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}
// Also check IsConnected method
if !batcher.IsConnected(node.n.ID) {
t.Logf("Node %d IsConnected() returns false", i)
}
}
if disconnectedCount > 0 {
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
// This is expected behavior for multi-channel batcher according to user
// "it has never worked with the multi"
} else {
t.Logf("All nodes show as connected - working correctly")
}
} else {
t.Logf("Batcher does not implement Debug() method")
}
// Phase 5: Test if "disconnected" nodes can actually receive updates
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())
receivedCount := 0
timeout := time.After(500 * time.Millisecond)
for i := 0; i < len(allNodes); i++ {
select {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
t.Logf("Node %d timed out waiting for update", i)
goto done
}
}
done:
t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes))
if receivedCount < len(allNodes) {
t.Logf("Some nodes failed to receive updates - confirming the issue")
}
})
}
}
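The Phase 4 check above type-asserts the batcher against an optional Debug() interface and inspects a "connected" key in the returned map. Below is a minimal sketch of that pattern pulled out into a helper; the package name and the map shape are assumptions taken from what the test observes, not code from this commit.

package hscontrol

import "github.com/juanfont/headscale/hscontrol/types"

// debugReportsConnected reports whether a batcher's debug output claims the
// node is connected. The second result is false when no answer could be
// derived: no Debug() method, node missing, or an unexpected value shape.
func debugReportsConnected(batcher any, id types.NodeID) (connected, ok bool) {
	db, hasDebug := batcher.(interface {
		Debug() map[types.NodeID]any
	})
	if !hasDebug {
		return false, false
	}

	info, exists := db.Debug()[id]
	if !exists {
		return false, false
	}

	if m, isMap := info.(map[string]any); isMap {
		if c, isBool := m["connected"].(bool); isBool {
			return c, true
		}
	}

	return false, false
}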
func TestBatcherMultiConnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions { for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) { t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10) testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10)
defer cleanup() defer cleanup()
batcher := testData.Batcher
nodes := testData.Nodes
t.Logf("=== WORK QUEUE TRACING TEST ===")
time.Sleep(100 * time.Millisecond) // Let connections settle
// Wait for initial NodeCameOnline to be processed
time.Sleep(200 * time.Millisecond)
// Drain any initial updates
drainedCount := 0
for {
select {
case <-nodes[0].ch:
drainedCount++
case <-time.After(100 * time.Millisecond):
goto drained
}
}
drained:
t.Logf("Drained %d initial updates", drainedCount)
// Now send a single FullSet update and trace it closely
t.Logf("Sending change.FullSet work item...")
batcher.AddWork(change.FullSet)
// Give short time for processing
time.Sleep(100 * time.Millisecond)
// Check if any update was received
select {
case data := <-nodes[0].ch:
t.Logf("SUCCESS: Received update after FullSet!")
if data != nil {
// Detailed analysis of the response - data is already a MapResponse
t.Logf("Response details:")
t.Logf(" Peers: %d", len(data.Peers))
t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch))
t.Logf(" PeersChanged: %d", len(data.PeersChanged))
t.Logf(" PeersRemoved: %d", len(data.PeersRemoved))
t.Logf(" DERPMap: %v", data.DERPMap != nil)
t.Logf(" KeepAlive: %v", data.KeepAlive)
t.Logf(" Node: %v", data.Node != nil)
if len(data.Peers) > 0 {
t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers))
} else if len(data.PeersChangedPatch) > 0 {
t.Errorf("ERROR: Received patch update instead of full update!")
} else if data.DERPMap != nil {
t.Logf("Received DERP map update")
} else if data.Node != nil {
t.Logf("Received self node update")
} else {
t.Errorf("ERROR: Received unknown update type!")
}
batcher := testData.Batcher batcher := testData.Batcher
node1 := testData.Nodes[0] node1 := testData.Nodes[0]
node2 := testData.Nodes[1] node2 := testData.Nodes[1]
@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
} }
} }
} }
} else {
t.Errorf("Response data is nil")
} }
case <-time.After(2 * time.Second): }
t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!")
t.Errorf("This indicates FullSet work items are not being processed at all") // Send another update and verify remaining connections still work
clearChannel(node1.ch)
clearChannel(thirdChannel)
testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
batcher.AddWork(testChangeSet2)
time.Sleep(100 * time.Millisecond)
// Verify remaining connections still receive updates
remaining1Received := false
remaining3Received := false
select {
case mapResp := <-node1.ch:
remaining1Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 1 did not receive update after removal")
}
select {
case mapResp := <-thirdChannel:
remaining3Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 3 did not receive update after removal")
}
if remaining1Received && remaining3Received {
t.Logf("SUCCESS: Remaining connections still receive updates after removal")
} else {
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t",
remaining1Received, remaining3Received)
}
// Verify second channel no longer receives updates (should be closed/removed)
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")
case <-time.After(100 * time.Millisecond):
t.Logf("SUCCESS: Removed connection correctly no longer receives updates")
} }
}) })
} }
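clearChannel is called above but its definition lies outside this diff. A minimal sketch of what such a drain helper typically looks like, assuming a buffered chan *tailcfg.MapResponse; this is a guess at the shape, not the repository's actual implementation.

package hscontrol

import "tailscale.com/tailcfg"

// clearChannel discards any MapResponses already queued on ch so later
// assertions only see updates produced after this point.
func clearChannel(ch chan *tailcfg.MapResponse) {
	for {
		select {
		case <-ch:
			// drop a stale response
		default:
			return
		}
	}
}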

View File

@ -20,6 +20,8 @@ type MapResponseBuilder struct {
nodeID types.NodeID nodeID types.NodeID
capVer tailcfg.CapabilityVersion capVer tailcfg.CapabilityVersion
errs []error errs []error
debugType debugType
} }
type debugType string type debugType string

View File

@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func (m *mapper) fullMapResponse( func (m *mapper) fullMapResponse(
nodeID types.NodeID, nodeID types.NodeID,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID) peers := m.state.ListPeers(nodeID)
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(fullResponseDebug).
WithCapabilityVersion(capVer). WithCapabilityVersion(capVer).
WithSelfNode(). WithSelfNode().
WithDERPMap(). WithDERPMap().
@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse(
nodeID types.NodeID, nodeID types.NodeID,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap(). WithDERPMap().
Build() Build()
} }
@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse(
changed []*tailcfg.PeerChange, changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed). WithPeerChangedPatch(changed).
Build() Build()
} }
@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse(
peers := m.state.ListPeers(nodeID, changedNodeID) peers := m.state.ListPeers(nodeID, changedNodeID)
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
WithCapabilityVersion(capVer). WithCapabilityVersion(capVer).
WithSelfNode(). WithSelfNode().
WithUserProfiles(peers). WithUserProfiles(peers).
@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse(
removedNodeID types.NodeID, removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID). WithPeersRemoved(removedNodeID).
Build() Build()
} }
@ -214,7 +218,7 @@ func writeDebugMapResponse(
} }
perms := fs.FileMode(debugMapResponsePerm) perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID)) mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms) err = os.MkdirAll(mPath, perms)
if err != nil { if err != nil {
panic(err) panic(err)
@ -224,7 +228,7 @@ func writeDebugMapResponse(
mapResponsePath := path.Join( mapResponsePath := path.Join(
mPath, mPath,
fmt.Sprintf("%s.json", now), fmt.Sprintf("%s-%s.json", now, t),
) )
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
return nil, nil return nil, nil
} }
nodes, err := os.ReadDir(debugDumpMapResponsePath) return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
}
func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) {
nodes, err := os.ReadDir(dir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
nodeID := types.NodeID(nodeIDu) nodeID := types.NodeID(nodeIDu)
files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name())) files, err := os.ReadDir(path.Join(dir, node.Name()))
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Reading dir %s", node.Name()) log.Error().Err(err).Msgf("Reading dir %s", node.Name())
continue continue
@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
continue continue
} }
body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name())) body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Reading file %s", file.Name()) log.Error().Err(err).Msgf("Reading file %s", file.Name())
continue continue

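ReadMapResponsesFromDirectory is now exported, so the per-node dump directories written by writeDebugMapResponse (files named <timestamp>-<debugType>.json after this change) can be loaded outside the mapper itself. A small hypothetical consumer follows; the mapper import path and the directory value are assumptions.

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/mapper"
)

func main() {
	// Load every dumped MapResponse, grouped by node ID.
	responses, err := mapper.ReadMapResponsesFromDirectory("/var/lib/headscale/mapresponses")
	if err != nil {
		panic(err)
	}

	for nodeID, resps := range responses {
		fmt.Printf("node %d: %d dumped map responses\n", nodeID, len(resps))
	}
}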
View File

@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) {
Tags: []string{}, Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true, MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{ CapMap: tailcfg.NodeCapMap{

View File

@ -175,8 +175,8 @@ func rejectUnsupported(
Int("client_cap_ver", int(version)). Int("client_cap_ver", int(version)).
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)). Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
Str("client_version", capver.TailscaleVersion(version)). Str("client_version", capver.TailscaleVersion(version)).
Str("node_key", nkey.ShortString()). Str("node.key", nkey.ShortString()).
Str("machine_key", mkey.ShortString()). Str("machine.key", mkey.ShortString()).
Msg("unsupported client connected") Msg("unsupported client connected")
http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest) http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest)
@ -282,7 +282,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return return
} }

View File

@ -181,7 +181,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
a.registrationCache.Set(state, registrationInfo) a.registrationCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...) authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound) http.Redirect(writer, req, authURL, http.StatusFound)
} }
@ -311,7 +311,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
log.Error(). log.Error().
Caller(). Caller().
Err(werr). Err(werr).
Msg("Failed to write response") Msg("Failed to write HTTP response")
} }
return return
@ -349,7 +349,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil { if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write response") util.LogErr(err, "Failed to write HTTP response")
} }
return return

View File

@ -34,7 +34,7 @@ func (pol *Policy) compileFilterRules(
srcIPs, err := acl.Sources.Resolve(pol, users, nodes) srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil { if err != nil {
log.Trace().Err(err).Msgf("resolving source ips") log.Trace().Caller().Err(err).Msgf("resolving source ips")
} }
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
@ -52,11 +52,11 @@ func (pol *Policy) compileFilterRules(
for _, dest := range acl.Destinations { for _, dest := range acl.Destinations {
ips, err := dest.Resolve(pol, users, nodes) ips, err := dest.Resolve(pol, users, nodes)
if err != nil { if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips") log.Trace().Caller().Err(err).Msgf("resolving destination ips")
} }
if ips == nil { if ips == nil {
log.Debug().Msgf("destination resolved to nil ips: %v", dest) log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest)
continue continue
} }
@ -106,7 +106,7 @@ func (pol *Policy) compileSSHPolicy(
return nil, nil return nil, nil
} }
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname()) log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
var rules []*tailcfg.SSHRule var rules []*tailcfg.SSHRule
@ -115,7 +115,7 @@ func (pol *Policy) compileSSHPolicy(
for _, src := range rule.Destinations { for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes) ips, err := src.Resolve(pol, users, nodes)
if err != nil { if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips") log.Trace().Caller().Err(err).Msgf("resolving destination ips")
} }
dest.AddSet(ips) dest.AddSet(ips)
} }
@ -142,7 +142,7 @@ func (pol *Policy) compileSSHPolicy(
var principals []*tailcfg.SSHPrincipal var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes) srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil { if err != nil {
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
continue // Skip this rule if we can't resolve sources continue // Skip this rule if we can't resolve sources
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -79,6 +80,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
filterHash := deephash.Hash(&filter) filterHash := deephash.Hash(&filter)
filterChanged := filterHash != pm.filterHash filterChanged := filterHash != pm.filterHash
if filterChanged {
log.Debug().
Str("filter.hash.old", pm.filterHash.String()[:8]).
Str("filter.hash.new", filterHash.String()[:8]).
Int("filter.rules", len(pm.filter)).
Int("filter.rules.new", len(filter)).
Msg("Policy filter hash changed")
}
pm.filter = filter pm.filter = filter
pm.filterHash = filterHash pm.filterHash = filterHash
if filterChanged { if filterChanged {
@ -95,6 +104,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
tagOwnerMapHash := deephash.Hash(&tagMap) tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
if tagOwnerChanged {
log.Debug().
Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]).
Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]).
Int("tagOwners.old", len(pm.tagOwnerMap)).
Int("tagOwners.new", len(tagMap)).
Msg("Tag owner hash changed")
}
pm.tagOwnerMap = tagMap pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash pm.tagOwnerMapHash = tagOwnerMapHash
@ -105,19 +122,42 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
autoApproveMapHash := deephash.Hash(&autoMap) autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
if autoApproveChanged {
log.Debug().
Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]).
Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]).
Int("autoApprovers.old", len(pm.autoApproveMap)).
Int("autoApprovers.new", len(autoMap)).
Msg("Auto-approvers hash changed")
}
pm.autoApproveMap = autoMap pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash pm.autoApproveMapHash = autoApproveMapHash
exitSetHash := deephash.Hash(&autoMap) exitSetHash := deephash.Hash(&exitSet)
exitSetChanged := exitSetHash != pm.exitSetHash exitSetChanged := exitSetHash != pm.exitSetHash
if exitSetChanged {
log.Debug().
Str("exitSet.hash.old", pm.exitSetHash.String()[:8]).
Str("exitSet.hash.new", exitSetHash.String()[:8]).
Msg("Exit node set hash changed")
}
pm.exitSet = exitSet pm.exitSet = exitSet
pm.exitSetHash = exitSetHash pm.exitSetHash = exitSetHash
// If neither of the calculated values changed, no need to update nodes // If neither of the calculated values changed, no need to update nodes
if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged { if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged {
log.Trace().
Msg("Policy evaluation detected no changes - all hashes match")
return false, nil return false, nil
} }
log.Debug().
Bool("filter.changed", filterChanged).
Bool("tagOwners.changed", tagOwnerChanged).
Bool("autoApprovers.changed", autoApproveChanged).
Bool("exitNodes.changed", exitSetChanged).
Msg("Policy changes require node updates")
return true, nil return true, nil
} }
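The new debug lines above all hang off the same pattern: hash the freshly compiled artifact with deephash, compare against the stored hash, and only fan updates out to nodes when something actually changed. A minimal sketch of that pattern in isolation; the rule type and values are illustrative, not the PolicyManager's real fields.

package main

import (
	"fmt"

	"tailscale.com/util/deephash"
)

// rule stands in for a compiled policy artifact such as a filter rule.
type rule struct{ Src, Dst string }

func main() {
	previous := []rule{{Src: "group:admin", Dst: "*"}}
	current := []rule{{Src: "group:admin", Dst: "*"}, {Src: "user1", Dst: "tag:web"}}

	prevHash := deephash.Hash(&previous)
	currHash := deephash.Hash(&current)

	if prevHash != currHash {
		// Mirrors the truncated-hash debug logging added in updateLocked.
		fmt.Printf("filter changed: %s -> %s (%d -> %d rules)\n",
			prevHash.String()[:8], currHash.String()[:8], len(previous), len(current))
	}
}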
@ -151,6 +191,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
// Log policy metadata for debugging
log.Debug().
Int("policy.bytes", len(polB)).
Int("acls.count", len(pol.ACLs)).
Int("groups.count", len(pol.Groups)).
Int("hosts.count", len(pol.Hosts)).
Int("tagOwners.count", len(pol.TagOwners)).
Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)).
Msg("Policy parsed successfully")
pm.pol = pol pm.pol = pol
return pm.updateLocked() return pm.updateLocked()

View File

@ -216,6 +216,21 @@ func (m *mapSession) serveLongPoll() {
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
// TODO(kradalby): Redo the comments here
// Add node to batcher so it can receive updates,
// adding it before connecting it to the state ensures that
// it does not miss any updates that might be sent in the window
// between the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
m.h.Change(mapReqChange)
m.h.Change(connectChanges...)
// Loop through updates and continuously send them to the // Loop through updates and continuously send them to the
// client. // client.
for { for {
@ -227,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
return return
case <-ctx.Done(): case <-ctx.Done():
m.tracef("poll context done") m.tracef("poll context done chan:%p", m.ch)
mapResponseEnded.WithLabelValues("done").Inc() mapResponseEnded.WithLabelValues("done").Inc()
return return
@ -295,7 +310,15 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
} }
} }
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") log.Trace().
Caller().
Str("node.name", m.node.Hostname).
Uint64("node.id", m.node.ID.Uint64()).
Str("chan", fmt.Sprintf("%p", m.ch)).
TimeDiff("timeSpent", time.Now(), startWrite).
Str("machine.key", m.node.MachineKey.String()).
Bool("keepalive", msg.KeepAlive).
Msgf("finished writing mapresp to node chan(%p)", m.ch)
return nil return nil
} }
@ -305,14 +328,14 @@ var keepAlive = tailcfg.MapResponse{
} }
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) { func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname) trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
if peerChange.Key != nil { if peerChange.Key != nil {
trace = trace.Str("node_key", peerChange.Key.ShortString()) trace = trace.Str("node.key", peerChange.Key.ShortString())
} }
if peerChange.DiscoKey != nil { if peerChange.DiscoKey != nil {
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString()) trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString())
} }
if peerChange.Online != nil { if peerChange.Online != nil {
@ -349,7 +372,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()). Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname). Str("node.name", node.Hostname).
Msgf(msg, a...) Msgf(msg, a...)
}, },
func(msg string, a ...any) { func(msg string, a ...any) {
@ -358,7 +381,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()). Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname). Str("node.name", node.Hostname).
Msgf(msg, a...) Msgf(msg, a...)
}, },
func(msg string, a ...any) { func(msg string, a ...any) {
@ -367,7 +390,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()). Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname). Str("node.name", node.Hostname).
Msgf(msg, a...) Msgf(msg, a...)
}, },
func(err error, msg string, a ...any) { func(err error, msg string, a ...any) {
@ -376,7 +399,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()). Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname). Str("node.name", node.Hostname).
Err(err). Err(err).
Msgf(msg, a...) Msgf(msg, a...)
} }

View File

@ -1430,7 +1430,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
} }
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users") log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")
changed, err := s.polMan.SetUsers(users) changed, err := s.polMan.SetUsers(users)
if err != nil { if err != nil {

View File

@ -97,6 +97,35 @@ func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy return c.Change == Full || c.Change == Policy
} }
func HasFull(cs []ChangeSet) bool {
for _, c := range cs {
if c.IsFull() {
return true
}
}
return false
}
func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
for _, c := range cs {
if c.SelfUpdateOnly {
self = append(self, c)
} else {
all = append(all, c)
}
}
return all, self
}
func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
for _, c := range cs {
if c.NodeID != id || c.Change.AlsoSelf() {
ret = append(ret, c)
}
}
return ret
}
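The three helpers added here compose naturally when fanning a batch of changes out: split off self-only updates, then strip a node's own self-directed entries from the shared batch before delivery. A minimal usage sketch, assuming the hscontrol/types/change import path and an illustrative batch.

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	batch := []change.ChangeSet{
		change.FullSet,
		{NodeID: 1, Change: change.NodeNewOrUpdate},
		{NodeID: 2, Change: change.NodeNewOrUpdate, SelfUpdateOnly: true},
	}

	// Self-only updates go to their own node; everything else is broadcast.
	all, self := change.SplitAllAndSelf(batch)
	fmt.Printf("broadcast: %d, self-only: %d\n", len(all), len(self))

	// When building node 1's payload, drop entries that would only tell
	// node 1 about itself, unless the change type also applies to the node.
	forNode1 := change.RemoveUpdatesForSelf(types.NodeID(1), all)
	fmt.Printf("node 1: %d changes, needs full map: %v\n", len(forNode1), change.HasFull(forNode1))
}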
func (c ChangeSet) AlsoSelf() bool { func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node, // If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes. // so we consider it as a change that should be sent to all nodes.

View File

@ -489,6 +489,7 @@ func derpConfig() DERPConfig {
urlAddr, err := url.Parse(urlStr) urlAddr, err := url.Parse(urlStr)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller().
Str("url", urlStr). Str("url", urlStr).
Err(err). Err(err).
Msg("Failed to parse url, ignoring...") Msg("Failed to parse url, ignoring...")
@ -561,6 +562,7 @@ func logConfig() LogConfig {
logFormat = TextLogFormat logFormat = TextLogFormat
default: default:
log.Error(). log.Error().
Caller().
Str("func", "GetLogConfig"). Str("func", "GetLogConfig").
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt) Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
} }

View File

@ -54,6 +54,20 @@ func (id NodeID) String() string {
return strconv.FormatUint(id.Uint64(), util.Base10) return strconv.FormatUint(id.Uint64(), util.Base10)
} }
func ParseNodeID(s string) (NodeID, error) {
id, err := strconv.ParseUint(s, util.Base10, 64)
return NodeID(id), err
}
func MustParseNodeID(s string) NodeID {
id, err := ParseNodeID(s)
if err != nil {
panic(err)
}
return id
}
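A brief usage sketch for the new parsing helpers: ParseNodeID for untrusted input such as an ID taken from a URL or CLI flag, MustParseNodeID where a bad value should be fatal. The values below are examples only.

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	id, err := types.ParseNodeID("42")
	if err != nil {
		fmt.Println("invalid node ID:", err)
		return
	}
	fmt.Println("parsed node", id)

	// MustParseNodeID panics on bad input, so it is best reserved for
	// constants and test fixtures where failure should abort immediately.
	_ = types.MustParseNodeID("7")
}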
// Node is a Headscale client. // Node is a Headscale client.
type Node struct { type Node struct {
ID NodeID `gorm:"primary_key"` ID NodeID `gorm:"primary_key"`

View File

@ -61,6 +61,7 @@ func (pak *PreAuthKey) Validate() error {
} }
log.Debug(). log.Debug().
Caller().
Str("key", pak.Key). Str("key", pak.Key).
Bool("hasExpiration", pak.Expiration != nil). Bool("hasExpiration", pak.Expiration != nil).
Time("expiration", func() time.Time { Time("expiration", func() time.Time {

View File

@ -321,7 +321,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
if err == nil { if err == nil {
u.Name = claims.Username u.Name = claims.Username
} else { } else {
log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username) log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
} }
if claims.EmailVerified { if claims.EmailVerified {

View File

@ -1160,6 +1160,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
err = headscale.SetPolicy(&p) err = headscale.SetPolicy(&p)
require.NoError(t, err) require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Get the current policy and check // Get the current policy and check
// if it is the same as the one we set. // if it is the same as the one we set.
var output *policyv2.Policy var output *policyv2.Policy
@ -1174,26 +1175,28 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
}, },
&output, &output,
) )
require.NoError(t, err) assert.NoError(ct, err)
assert.Len(t, output.ACLs, 1) assert.Len(t, output.ACLs, 1)
if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" { if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" {
t.Errorf("unexpected policy(-want +got):\n%s", diff) ct.Errorf("unexpected policy(-want +got):\n%s", diff)
} }
}, 30*time.Second, 1*time.Second, "verifying that the new policy took place")
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Test that user1 can visit all user2 // Test that user1 can visit all user2
for _, client := range user1Clients { for _, client := range user1Clients {
for _, peer := range user2Clients { for _, peer := range user2Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
require.NoError(t, err) assert.NoError(ct, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Len(t, result, 13) assert.Len(ct, result, 13)
require.NoError(t, err) assert.NoError(ct, err)
} }
} }
@ -1201,16 +1204,17 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
for _, client := range user2Clients { for _, client := range user2Clients {
for _, peer := range user1Clients { for _, peer := range user1Clients {
fqdn, err := peer.FQDN() fqdn, err := peer.FQDN()
require.NoError(t, err) assert.NoError(ct, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn) url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url) t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url) result, err := client.Curl(url)
assert.Empty(t, result) assert.Empty(ct, result)
require.Error(t, err) assert.Error(ct, err)
} }
} }
}, 30*time.Second, 1*time.Second, "new policy did not get propagated to nodes")
} }
func TestACLAutogroupMember(t *testing.T) { func TestACLAutogroupMember(t *testing.T) {

View File

@ -9,6 +9,7 @@ import (
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/samber/lo" "github.com/samber/lo"
@ -53,6 +54,18 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second)
// assertClientsState(t, allClients) // assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr) clientIPs := make(map[TailscaleClient][]netip.Addr)
@ -64,9 +77,6 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
clientIPs[client] = ips clientIPs[client] = ips
} }
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes)) assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes) nodeCountBeforeLogout := len(listNodes)
@ -86,6 +96,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleLogout() err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err) assertNoErrLogout(t, err)
// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second)
t.Logf("all clients logged out") t.Logf("all clients logged out")
assert.EventuallyWithT(t, func(ct *assert.CollectT) { assert.EventuallyWithT(t, func(ct *assert.CollectT) {

View File

@ -481,10 +481,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err) assertNoErr(t, err)
@ -494,7 +490,8 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u) _, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err) assertNoErr(t, err)
listUsers, err = headscale.ListUsers() assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, listUsers, 1) assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{ wantUsers := []*v1.User{
@ -514,6 +511,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff) t.Fatalf("unexpected users: %s", diff)
} }
}, 30*time.Second, 1*time.Second, "validating users after first login")
listNodes, err := headscale.ListNodes() listNodes, err := headscale.ListNodes()
assertNoErr(t, err) assertNoErr(t, err)
@ -525,19 +523,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout() err = ts.Logout()
assertNoErr(t, err) assertNoErr(t, err)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
// Wait for logout to complete and then do second logout // Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) { assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed // Check that the first logout completed
status, err := ts.Status() status, err := ts.Status()
assert.NoError(ct, err) assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState) assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second) }, 30*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
u, err = ts.LoginWithURL(headscale.GetEndpoint()) u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err) assertNoErr(t, err)
@ -545,10 +543,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u) _, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err) assertNoErr(t, err)
listUsers, err = headscale.ListUsers() assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, listUsers, 2) assert.Len(t, listUsers, 2)
wantUsers = []*v1.User{ wantUsers := []*v1.User{
{ {
Id: 1, Id: 1,
Name: "user1", Name: "user1",
@ -570,31 +569,30 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}) })
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff) ct.Errorf("unexpected users: %s", diff)
} }
}, 30*time.Second, 1*time.Second, "validating users after new user login")
listNodesAfterNewUserLogin, err := headscale.ListNodes() var listNodesAfterNewUserLogin []*v1.Node
assertNoErr(t, err) assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assert.Len(t, listNodesAfterNewUserLogin, 2) listNodesAfterNewUserLogin, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterNewUserLogin, 2)
// Machine key is the same as the "machine" has not changed, // Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node // but Node key is not as it is a new node
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "listing nodes after new user login")
// Log out user2, and log into user1, no new node should be created, // Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again // the node should now "become" node1 again
err = ts.Logout() err = ts.Logout()
assertNoErr(t, err) assertNoErr(t, err)
// Wait for logout to complete and then do second logout t.Logf("Logged out take one")
assert.EventuallyWithT(t, func(ct *assert.CollectT) { t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and // TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it // logs in immediately after the first logout and I cannot reproduce it
@ -602,16 +600,40 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout() err = ts.Logout()
assertNoErr(t, err) assertNoErr(t, err)
t.Logf("Logged out take two")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 30*time.Second, 1*time.Second)
// We do not actually "change" the user here; it happens by logging in again,
// as the OIDC mock server behaves like a stack and the next user is already
// prepared and ready to go.
u, err = ts.LoginWithURL(headscale.GetEndpoint()) u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err) assertNoErr(t, err)
_, err = doLoginURL(ts.Hostname(), u) _, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err) assertNoErr(t, err)
listUsers, err = headscale.ListUsers() assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assertNoErr(t, err) status, err := ts.Status()
assert.Len(t, listUsers, 2) assert.NoError(ct, err)
wantUsers = []*v1.User{ assert.Equal(ct, "Running", status.BackendState)
}, 30*time.Second, 1*time.Second)
t.Logf("Logged back in")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err)
assert.Len(ct, listUsers, 2)
wantUsers := []*v1.User{
{ {
Id: 1, Id: 1,
Name: "user1", Name: "user1",
@ -633,34 +655,37 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}) })
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff) ct.Errorf("unexpected users: %s", diff)
} }
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterLoggingBackIn, err := headscale.ListNodes() listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assertNoErr(t, err) assert.NoError(ct, err)
assert.Len(t, listNodesAfterLoggingBackIn, 2) assert.Len(ct, listNodesAfterLoggingBackIn, 2)
// Validate that the machine we had when we logged in the first time, has the same // Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same // machine key, but a different ID than the newly logged in version of the same
// machine. // machine.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
// Even tho we are logging in again with the same user, the previous key has been expired // Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same // and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches. // as the user + machinekey still matches.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
// The "logged back in" machine should have the same machinekey but a different nodekey // The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user. // than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
} }
// assertTailscaleNodesLogout verifies that all provided Tailscale clients // assertTailscaleNodesLogout verifies that all provided Tailscale clients

View File

@ -4,6 +4,7 @@ import (
"net/netip" "net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"

View File

@ -10,18 +10,21 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -59,13 +62,15 @@ func TestPingAllByIP(t *testing.T) {
hs, err := scenario.Headscale() hs, err := scenario.Headscale()
require.NoError(t, err) require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Extract node IDs for validation
all, err := hs.GetAllMapReponses() expectedNodes := make([]types.NodeID, 0, len(allClients))
assert.NoError(ct, err) for _, client := range allClients {
status := client.MustStatus()
onlineMap := buildExpectedOnlineMap(all) nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap) require.NoError(t, err, "failed to parse node ID")
}, 30*time.Second, 2*time.Second) expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second)
// assertClientsState(t, allClients) // assertClientsState(t, allClients)
@ -73,6 +78,14 @@ func TestPingAllByIP(t *testing.T) {
return x.String() return x.String()
}) })
// Get headscale instance for batcher debug check
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test our DebugBatcher functionality
t.Logf("Testing DebugBatcher functionality...")
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to the batcher", 30*time.Second)
success := pingAllHelper(t, allClients, allAddrs) success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
} }
@ -962,9 +975,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
) )
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
hs, err := scenario.Headscale()
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
@ -980,14 +990,31 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
return x.String() return x.String()
}) })
// Get headscale instance for batcher debug checks
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Initial check: all nodes should be connected to batcher
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)
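The loop that turns each client's own status into a types.NodeID now appears in several of these tests before calling requireAllClientsOnline. A hypothetical helper (not part of this commit) capturing that pattern, assuming the integration package's TailscaleClient interface.

package integration

import (
	"strconv"
	"testing"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/stretchr/testify/require"
)

// collectExpectedNodeIDs gathers the node IDs each client reports for itself,
// mirroring the loops above that feed requireAllClientsOnline.
func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID {
	t.Helper()

	ids := make([]types.NodeID, 0, len(clients))
	for _, client := range clients {
		status := client.MustStatus()
		id, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
		require.NoError(t, err, "failed to parse node ID")
		ids = append(ids, types.NodeID(id))
	}

	return ids
}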
success := pingAllHelper(t, allClients, allAddrs) success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
    wg, _ := errgroup.WithContext(context.Background())

    for run := range 3 {
        t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999"))

        // Create fresh errgroup with timeout for each run
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        wg, _ := errgroup.WithContext(ctx)

        for _, client := range allClients {
            c := client
            wg.Go(func() error {

@ -1001,6 +1028,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
        }

        t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

        // After taking down all nodes, verify all systems show nodes offline
        requireAllClientsOnline(t, headscale, expectedNodes, false, fmt.Sprintf("Run %d: all nodes should be offline after Down()", run+1), 120*time.Second)

        for _, client := range allClients {
            c := client
            wg.Go(func() error {

@ -1014,22 +1044,22 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
        }

        t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

        // After bringing up all nodes, verify batcher shows all reconnected
        requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all nodes should be reconnected after Up()", run+1), 120*time.Second)

        // Wait for sync and successful pings after nodes come back up
        err = scenario.WaitForTailscaleSync()
        assert.NoError(t, err)
        t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

        assert.EventuallyWithT(t, func(ct *assert.CollectT) {
            all, err := hs.GetAllMapReponses()
            assert.NoError(ct, err)
            onlineMap := buildExpectedOnlineMap(all)
            assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
        }, 60*time.Second, 2*time.Second)
        requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all systems should show nodes online after reconnection", run+1), 60*time.Second)

        success := pingAllHelper(t, allClients, allAddrs)
        assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))

        // Clean up context for this run
        cancel()
    }
}
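The per-run pattern above recreates the context and errgroup on every iteration so a stalled run cannot leak goroutines or timers into the next one. The following is a minimal, self-contained sketch of that pattern only; the worker bodies, counts, and timings are invented for illustration and are not headscale code.

package main

import (
    "context"
    "fmt"
    "time"

    "golang.org/x/sync/errgroup"
)

func main() {
    for run := range 3 {
        // Fresh context and errgroup per run; the timeout and the cancel
        // only cover this iteration's workers.
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        wg, gctx := errgroup.WithContext(ctx)

        for range 5 {
            wg.Go(func() error {
                select {
                case <-time.After(100 * time.Millisecond):
                    return nil // simulated work finished in time
                case <-gctx.Done():
                    return gctx.Err() // timeout or another worker failed
                }
            })
        }

        if err := wg.Wait(); err != nil {
            fmt.Printf("run %d failed: %v\n", run, err)
        }
        cancel()
    }
}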
@ -1141,51 +1171,158 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
    assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
}
func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
    res := make(map[types.NodeID]map[types.NodeID]bool)
    for nid, mrs := range all {
        res[nid] = make(map[types.NodeID]bool)
        for _, mr := range mrs {
            for _, peer := range mr.Peers {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.ID)] = *peer.Online
                }
            }
            for _, peer := range mr.PeersChanged {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.ID)] = *peer.Online
                }
            }
            for _, peer := range mr.PeersChangedPatch {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.NodeID)] = *peer.Online
                }
            }
        }
    }

    return res
}

func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) {
    for nid, peers := range onlineMap {
        onlineCount := 0
        for _, online := range peers {
            if online {
                onlineCount++
            }
        }

        assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid)
        if expectedPeerCount != onlineCount {
            var sb strings.Builder
            sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid))
            for pid, online := range peers {
                sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online))
            }
            sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
            sb.WriteString("expected all peers to be online.")
            t.Errorf("%s", sb.String())
        }
    }
}

// NodeSystemStatus represents the online status of a node across different systems
type NodeSystemStatus struct {
    Batcher          bool
    BatcherConnCount int
    MapResponses     bool
    NodeStore        bool
}

// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
    t.Helper()

    startTime := time.Now()
    t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format("2006-01-02T15:04:05.000"), message)

    var prevReport string
    require.EventuallyWithT(t, func(c *assert.CollectT) {
        // Get batcher state
        debugInfo, err := headscale.DebugBatcher()
        assert.NoError(c, err, "Failed to get batcher debug info")
        if err != nil {
            return
        }

        // Get map responses
        mapResponses, err := headscale.GetAllMapReponses()
        assert.NoError(c, err, "Failed to get map responses")
        if err != nil {
            return
        }

        // Get nodestore state
        nodeStore, err := headscale.DebugNodeStore()
        assert.NoError(c, err, "Failed to get nodestore debug info")
        if err != nil {
            return
        }

        // Validate node counts first
        expectedCount := len(expectedNodes)
        assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
        assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")

        // Check that we have map responses for expected nodes
        mapResponseCount := len(mapResponses)
        assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")

        // Build status map for each node
        nodeStatus := make(map[types.NodeID]NodeSystemStatus)

        // Initialize all expected nodes
        for _, nodeID := range expectedNodes {
            nodeStatus[nodeID] = NodeSystemStatus{}
        }

        // Check batcher state
        for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
            nodeID := types.MustParseNodeID(nodeIDStr)
            if status, exists := nodeStatus[nodeID]; exists {
                status.Batcher = nodeInfo.Connected
                status.BatcherConnCount = nodeInfo.ActiveConnections
                nodeStatus[nodeID] = status
            }
        }

        // Check map responses using buildExpectedOnlineMap
        onlineFromMaps := make(map[types.NodeID]bool)
        onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
        for nodeID := range nodeStatus {
        NODE_STATUS:
            for id, peerMap := range onlineMap {
                if id == nodeID {
                    continue
                }

                online := peerMap[nodeID]
                // If the node is offline in any map response, we consider it offline
                if !online {
                    onlineFromMaps[nodeID] = false
                    continue NODE_STATUS
                }
                onlineFromMaps[nodeID] = true
            }
        }
        assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")

        // Update status with map response data
        for nodeID, online := range onlineFromMaps {
            if status, exists := nodeStatus[nodeID]; exists {
                status.MapResponses = online
                nodeStatus[nodeID] = status
            }
        }

        // Check nodestore state
        for nodeID, node := range nodeStore {
            if status, exists := nodeStatus[nodeID]; exists {
                // Check if node is online in nodestore
                status.NodeStore = node.IsOnline != nil && *node.IsOnline
                nodeStatus[nodeID] = status
            }
        }

        // Verify all systems show nodes in expected state and report failures
        allMatch := true
        var failureReport strings.Builder

        ids := types.NodeIDs(maps.Keys(nodeStatus))
        slices.Sort(ids)
        for _, nodeID := range ids {
            status := nodeStatus[nodeID]
            systemsMatch := (status.Batcher == expectedOnline) &&
                (status.MapResponses == expectedOnline) &&
                (status.NodeStore == expectedOnline)

            if !systemsMatch {
                allMatch = false
                stateStr := "offline"
                if expectedOnline {
                    stateStr = "online"
                }
                failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
                failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
                failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
                failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
                failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
            }
        }

        if !allMatch {
            if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
                t.Log("Diff between reports:")
                t.Logf("Prev report: \n%s\n", prevReport)
                t.Logf("New report: \n%s\n", failureReport.String())
                t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
                prevReport = failureReport.String()
            }

            failureReport.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
            assert.Fail(c, failureReport.String())
        }

        stateStr := "offline"
        if expectedOnline {
            stateStr = "online"
        }
        assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
    }, timeout, 2*time.Second, message)

    endTime := time.Now()
    duration := endTime.Sub(startTime)
    t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format("2006-01-02T15:04:05.000"), duration, message)
}
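The helper above is essentially a "poll until every subsystem agrees" loop built on testify's EventuallyWithT. The sketch below shows that pattern in isolation, under the assumption that it is useful to see it without the headscale plumbing; fakeSystem, the node IDs, and the timings are invented for this example and are not part of the commit.

package consistency

import (
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

// fakeSystem stands in for one subsystem view (batcher, map responses,
// nodestore); it is purely illustrative.
type fakeSystem struct {
    online map[uint64]bool
}

func (f *fakeSystem) isOnline(id uint64) bool { return f.online[id] }

func TestAllSystemsAgree(t *testing.T) {
    systems := map[string]*fakeSystem{
        "batcher":      {online: map[uint64]bool{1: true, 2: true}},
        "mapresponses": {online: map[uint64]bool{1: true, 2: true}},
        "nodestore":    {online: map[uint64]bool{1: true, 2: true}},
    }
    expected := []uint64{1, 2}

    // Keep polling until every system reports every expected node online,
    // or the timeout expires; only the final attempt's failures are reported.
    require.EventuallyWithT(t, func(c *assert.CollectT) {
        for name, sys := range systems {
            for _, id := range expected {
                assert.Truef(c, sys.isOnline(id), "node %d not online in %s", id, name)
            }
        }
    }, 5*time.Second, 100*time.Millisecond, "all systems should report every node online")
}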

View File

@ -22,6 +22,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"

View File

@ -19,6 +19,7 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
"tailscale.com/tailcfg"
) )
// PeerSyncTimeout returns the timeout for peer synchronization based on environment: // PeerSyncTimeout returns the timeout for peer synchronization based on environment:
@ -199,3 +200,30 @@ func CreateCertificate(hostname string) ([]byte, []byte, error) {
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
} }
func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
    res := make(map[types.NodeID]map[types.NodeID]bool)
    for nid, mrs := range all {
        res[nid] = make(map[types.NodeID]bool)
        for _, mr := range mrs {
            for _, peer := range mr.Peers {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.ID)] = *peer.Online
                }
            }
            for _, peer := range mr.PeersChanged {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.ID)] = *peer.Online
                }
            }
            for _, peer := range mr.PeersChangedPatch {
                if peer.Online != nil {
                    res[nid][types.NodeID(peer.NodeID)] = *peer.Online
                }
            }
        }
    }

    return res
}
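As a reading aid, here is a small hypothetical example of what the function computes: later map responses overwrite earlier ones, so a PeersChangedPatch that flips a peer offline wins over an earlier full Peers entry. The import path for the integrationutil package is assumed to be github.com/juanfont/headscale/integration/integrationutil.

package main

import (
    "fmt"

    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/juanfont/headscale/integration/integrationutil"
    "tailscale.com/tailcfg"
)

func main() {
    online := true
    offline := false

    // Node 1's stream first announces peer 2 as online, then a later
    // patch marks it offline; the later response overwrites the earlier one.
    all := map[types.NodeID][]tailcfg.MapResponse{
        1: {
            {Peers: []*tailcfg.Node{{ID: 2, Online: &online}}},
            {PeersChangedPatch: []*tailcfg.PeerChange{{NodeID: 2, Online: &offline}}},
        },
    }

    onlineMap := integrationutil.BuildExpectedOnlineMap(all)
    fmt.Println(onlineMap[1][2]) // false
}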

View File

@ -5,7 +5,9 @@ package main
import (
    "encoding/json"
    "fmt"
    "go/format"
    "io"
    "log"
    "net/http"
    "os"
    "regexp"
@ -61,14 +63,14 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
        rawURL := fmt.Sprintf(rawFileURL, version)
        resp, err := http.Get(rawURL)
        if err != nil {
            fmt.Printf("Error fetching raw file for version %s: %v\n", version, err)
            log.Printf("Error fetching raw file for version %s: %v\n", version, err)
            continue
        }
        defer resp.Body.Close()

        body, err := io.ReadAll(resp.Body)
        if err != nil {
            fmt.Printf("Error reading raw file for version %s: %v\n", version, err)
            log.Printf("Error reading raw file for version %s: %v\n", version, err)
            continue
        }

@ -79,7 +81,7 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
            capabilityVersion, _ := strconv.Atoi(capabilityVersionStr)
            versions[version] = tailcfg.CapabilityVersion(capabilityVersion)
        } else {
            fmt.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
            log.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
        }
    }
@ -87,29 +89,23 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
}

func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion) error {
    // Open the output file
    file, err := os.Create(outputFile)
    if err != nil {
        return fmt.Errorf("error creating file: %w", err)
    }
    defer file.Close()

    // Write the package declaration and variable
    file.WriteString("package capver\n\n")
    file.WriteString("//Generated DO NOT EDIT\n\n")
    file.WriteString(`import "tailscale.com/tailcfg"`)
    file.WriteString("\n\n")
    file.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")

    // Generate the Go code as a string
    var content strings.Builder
    content.WriteString("package capver\n\n")
    content.WriteString("// Generated DO NOT EDIT\n\n")
    content.WriteString(`import "tailscale.com/tailcfg"`)
    content.WriteString("\n\n")
    content.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")

    sortedVersions := xmaps.Keys(versions)
    sort.Strings(sortedVersions)
    for _, version := range sortedVersions {
        fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
        fmt.Fprintf(&content, "\t\"%s\": %d,\n", version, versions[version])
    }
    file.WriteString("}\n")
    content.WriteString("}\n")
    file.WriteString("\n\n")
    content.WriteString("\n\n")
    file.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")
    content.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")

    capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
    for _, v := range sortedVersions {

@ -129,9 +125,21 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
        return capsSorted[i] < capsSorted[j]
    })

    for _, capVer := range capsSorted {
        fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
        fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
    }
    content.WriteString("}\n")

    // Format the generated code
    formatted, err := format.Source([]byte(content.String()))
    if err != nil {
        return fmt.Errorf("error formatting Go code: %w", err)
    }

    // Write to file
    err = os.WriteFile(outputFile, formatted, 0644)
    if err != nil {
        return fmt.Errorf("error writing file: %w", err)
    }
    file.WriteString("}\n")

    return nil
}
@ -139,15 +147,15 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
func main() {
    versions, err := getCapabilityVersions()
    if err != nil {
        fmt.Println("Error:", err)
        log.Println("Error:", err)
        return
    }

    err = writeCapabilityVersionsToFile(versions)
    if err != nil {
        fmt.Println("Error writing to file:", err)
        log.Println("Error writing to file:", err)
        return
    }

    fmt.Println("Capability versions written to", outputFile)
    log.Println("Capability versions written to", outputFile)
}
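The switch from writing straight to the output file to building the source in a strings.Builder and passing it through go/format means the generated file is gofmt-clean and syntactically valid before it touches disk. A minimal standalone sketch of just that formatting step; the snippet being formatted is invented:

package main

import (
    "fmt"
    "go/format"
    "log"
)

func main() {
    // Deliberately badly formatted "generated" code.
    src := []byte("package capver\n\nvar x=map[string]int{\"a\":1,\n\"b\":2}\n")

    formatted, err := format.Source(src)
    if err != nil {
        // format.Source also rejects code that does not parse,
        // so malformed generation fails loudly instead of producing a broken file.
        log.Fatalf("generated code is invalid: %v", err)
    }

    fmt.Print(string(formatted))
}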