1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-16 17:50:44 +02:00

lint and leftover

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-09-05 16:32:46 +02:00 committed by Kristoffer Dalby
parent 39443184d6
commit 233dffc186
34 changed files with 1429 additions and 506 deletions

View File

@ -16,15 +16,13 @@ body:
- type: textarea
attributes:
label: Description
description:
A clear and precise description of what new or changed feature you want.
description: A clear and precise description of what new or changed feature you want.
validations:
required: true
- type: checkboxes
attributes:
label: Contribution
description:
Are you willing to contribute to the implementation of this feature?
description: Are you willing to contribute to the implementation of this feature?
options:
- label: I can write the design doc for this feature
required: false
@ -33,7 +31,6 @@ body:
- type: textarea
attributes:
label: How can it be implemented?
description:
Free text for your ideas on how this feature could be implemented.
description: Free text for your ideas on how this feature could be implemented.
validations:
required: false

View File

@ -146,12 +146,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
policyChanged, err := app.state.DeleteNode(node)
if err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
return
}
app.Change(policyChanged)
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
})
app.ephemeralGC = ephemeralGC
@ -384,53 +384,49 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
log.Trace().
Caller().
Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`)
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
Msg("HTTP authentication invoked")
authHeader := req.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`)
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
return err
}
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Str("client_address", req.RemoteAddr).
Msg("failed to validate token")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Unauthorized"))
return err
}
return
}
if !valid {
log.Info().
Str("client_address", req.RemoteAddr).
Msg("invalid token")
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
return err
}
return nil
}(); err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", req.RemoteAddr).
Msg("failed to validate token")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
if !valid {
log.Info().
Str("client_address", req.RemoteAddr).
Msg("invalid token")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
Msg("Failed to write HTTP response")
return
}

View File

@ -260,7 +260,7 @@ func NewHeadscaleDatabase(
log.Error().Err(err).Msg("Error creating route")
} else {
log.Info().
Uint64("node_id", route.NodeID).
Uint64("node.id", route.NodeID).
Str("prefix", prefix.String()).
Msg("Route migrated")
}
@ -1131,7 +1131,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
}
for _, migrationID := range migrationIDs {
log.Trace().Str("migration_id", migrationID).Msg("Running migration")
log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
if needsFKDisabled {

View File

@ -275,7 +275,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
return errors.New("backfilling IPs: ip allocator was nil")
}
log.Trace().Msgf("starting to backfill IPs")
log.Trace().Caller().Msgf("starting to backfill IPs")
nodes, err := ListNodes(tx)
if err != nil {
@ -283,7 +283,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
}
for _, node := range nodes {
log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
changed := false
// IPv4 prefix is set, but node ip is missing, alloc

View File

@ -34,9 +34,6 @@ var (
"node not found in registration cache",
)
ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
ErrDifferentRegisteredUser = errors.New(
"node was previously registered with a different user",
)
)
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.

View File

@ -7,6 +7,7 @@ import (
"strings"
"github.com/arl/statsviz"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/client_golang/prometheus/promhttp"
"tailscale.com/tsweb"
@ -239,6 +240,34 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(resJSON)
}))
// Batcher endpoint
debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check Accept header to determine response format
acceptHeader := r.Header.Get("Accept")
wantsJSON := strings.Contains(acceptHeader, "application/json")
if wantsJSON {
batcherInfo := h.debugBatcherJSON()
batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(batcherJSON)
} else {
// Default to text/plain for backward compatibility
batcherInfo := h.debugBatcher()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(batcherInfo))
}
}))
err := statsviz.Register(debugMux)
if err == nil {
debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)")
@ -256,3 +285,124 @@ func (h *Headscale) debugHTTPServer() *http.Server {
return debugHTTPServer
}
// debugBatcher returns debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcher() string {
var sb strings.Builder
sb.WriteString("=== Batcher Connected Nodes ===\n\n")
totalNodes := 0
connectedCount := 0
// Collect nodes and sort them by ID
type nodeStatus struct {
id types.NodeID
connected bool
activeConnections int
}
var nodes []nodeStatus
// Try to get detailed debug info if we have a LockFreeBatcher
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
debugInfo := batcher.Debug()
for nodeID, info := range debugInfo {
nodes = append(nodes, nodeStatus{
id: nodeID,
connected: info.Connected,
activeConnections: info.ActiveConnections,
})
totalNodes++
if info.Connected {
connectedCount++
}
}
} else {
// Fallback to basic connection info
connectedMap := h.mapBatcher.ConnectedMap()
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
nodes = append(nodes, nodeStatus{
id: nodeID,
connected: connected,
activeConnections: 0,
})
totalNodes++
if connected {
connectedCount++
}
return true
})
}
// Sort by node ID
for i := 0; i < len(nodes); i++ {
for j := i + 1; j < len(nodes); j++ {
if nodes[i].id > nodes[j].id {
nodes[i], nodes[j] = nodes[j], nodes[i]
}
}
}
// Output sorted nodes
for _, node := range nodes {
status := "disconnected"
if node.connected {
status = "connected"
}
if node.activeConnections > 0 {
sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections))
} else {
sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status))
}
}
sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes))
return sb.String()
}
// DebugBatcherInfo represents batcher connection information in a structured format.
type DebugBatcherInfo struct {
ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"` // NodeID -> node connection info
TotalNodes int `json:"total_nodes"`
}
// DebugBatcherNodeInfo represents connection information for a single node.
type DebugBatcherNodeInfo struct {
Connected bool `json:"connected"`
ActiveConnections int `json:"active_connections"`
}
// debugBatcherJSON returns structured debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
info := DebugBatcherInfo{
ConnectedNodes: make(map[string]DebugBatcherNodeInfo),
TotalNodes: 0,
}
// Try to get detailed debug info if we have a LockFreeBatcher
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
debugInfo := batcher.Debug()
for nodeID, debugData := range debugInfo {
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
Connected: debugData.Connected,
ActiveConnections: debugData.ActiveConnections,
}
info.TotalNodes++
}
} else {
// Fallback to basic connection info
connectedMap := h.mapBatcher.ConnectedMap()
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
Connected: connected,
ActiveConnections: 0,
}
info.TotalNodes++
return true
})
}
return info
}

View File

@ -161,7 +161,7 @@ func (d *DERPServer) DERPHandler(
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
return
@ -199,7 +199,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
return
@ -229,7 +229,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
return
@ -245,7 +245,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
return
@ -284,7 +284,7 @@ func DERPProbeHandler(
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
}
}
@ -330,7 +330,7 @@ func DERPBootstrapDNSHandler(
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
}
}

View File

@ -237,6 +237,7 @@ func (api headscaleV1APIServer) RegisterNode(
request *v1.RegisterNodeRequest,
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Caller().
Str("user", request.GetUser()).
Str("registration_id", request.GetKey()).
Msg("Registering node")
@ -525,7 +526,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
ctx context.Context,
request *v1.BackfillNodeIPsRequest,
) (*v1.BackfillNodeIPsResponse, error) {
log.Trace().Msg("Backfill called")
log.Trace().Caller().Msg("Backfill called")
if !request.Confirmed {
return nil, errors.New("not confirmed, aborting")
@ -709,6 +710,10 @@ func (api headscaleV1APIServer) SetPolicy(
UpdatedAt: timestamppb.New(updated.UpdatedAt),
}
log.Debug().
Caller().
Msg("gRPC SetPolicy completed successfully because response prepared")
return response, nil
}
@ -731,7 +736,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Caller().
Interface("route-prefix", routes).
Interface("route-str", request.GetRoutes()).
Msg("")
Msg("Creating routes for node")
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,
@ -760,6 +765,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
}
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")

View File

@ -197,7 +197,7 @@ func (h *Headscale) RobotsHandler(
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
}

View File

@ -9,6 +9,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
@ -23,7 +24,7 @@ type Batcher interface {
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c change.ChangeSet)
AddWork(c ...change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}
@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
// The size of this channel is arbitrary chosen, the sizing should be revisited.
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
}
@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
m := newMapper(cfg, state)
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
m.batcher = b
return b
}
@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
var mapResp *tailcfg.MapResponse
var err error
var (
mapResp *tailcfg.MapResponse
err error
)
switch c.Change {
case change.DERP:
@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
// CRITICAL FIX: Read actual online status from NodeStore when available,
// fall back to deriving from change type for unit tests or when NodeStore is empty
var onlineStatus bool
if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
// Use actual NodeStore status when available (production case)
onlineStatus = node.IsOnline().Get()
} else {
// Fall back to deriving from change type (unit test case or initial setup)
onlineStatus = c.Change == change.NodeCameOnline
}
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(c.Change == change.NodeCameOnline),
Online: ptr.To(onlineStatus),
},
})
}
@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
}
nodeID := nc.nodeID()
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
var data *tailcfg.MapResponse
var err error
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
}
// Send the map response
if err := nc.send(data); err != nil {
err = nc.send(data)
if err != nil {
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}

View File

@ -2,6 +2,7 @@ package mapper
import (
"context"
"crypto/rand"
"fmt"
"sync"
"sync/atomic"
@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
version: version,
created: now,
}
// Initialize last used timestamp
newEntry.lastUsed.Store(now.Unix())
// Only after validation succeeds, create or update node connection
newConn := newNodeConn(id, c, version, b.mapper)
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
if !loaded {
b.totalNodes.Add(1)
conn = newConn
}
b.connected.Store(id, nil) // nil = connected
// Add connection to the list (lock-free)
nodeConn.addConnection(newEntry)
// Use the worker pool for controlled concurrency instead of direct generation
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
}
// Update connection status
b.connected.Store(id, nil) // nil = connected
// Node will automatically receive updates through the normal flow
// The initial full map already contains all current state
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection established in batcher because AddNode completed successfully")
return nil
}
@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return false
}
// Mark the connection as closed to prevent further sends
if connData := existing.connData.Load(); connData != nil {
connData.closed.Store(true)
}
// Remove specific connection
removed := nodeConn.removeConnectionByChannel(c)
if !removed {
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
return false
}
// Check if node has any remaining active connections
@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return true // Node still has active connections
}
// Remove node and mark disconnected atomically
b.nodes.Delete(id)
// No active connections - keep the node entry alive for rapid reconnections
// The node will get a fresh full map when it reconnects
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)
return false
}
// AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency.
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
b.addWork(c)
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
b.addWork(c...)
}
func (b *LockFreeBatcher) Start() {
@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() {
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
b.cancel = nil // Prevent multiple calls
}
// Only close workCh once
select {
case <-b.workCh:
// Channel is already closed
default:
close(b.workCh)
}
close(b.workCh)
}
func (b *LockFreeBatcher) doWork() {
log.Debug().Msg("batcher doWork loop started")
defer log.Debug().Msg("batcher doWork loop stopped")
for i := range b.workers {
go b.worker(i + 1)
}
// Create a cleanup ticker for removing truly disconnected nodes
cleanupTicker := time.NewTicker(5 * time.Minute)
defer cleanupTicker.Stop()
for {
select {
case <-b.tick.C:
// Process batched changes
b.processBatchedChanges()
case <-cleanupTicker.C:
// Clean up nodes that have been offline for too long
b.cleanupOfflineNodes()
case <-b.ctx.Done():
return
}
@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() {
}
func (b *LockFreeBatcher) worker(workerID int) {
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
for {
select {
@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
return
}
startTime := time.Now()
b.workProcessed.Add(1)
// If the resultCh is set, it means that this is a work request
@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) {
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
var err error
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
})
return
}
}
all, self := change.SplitAllAndSelf(c)
for _, changeSet := range self {
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
changes = append(changes, changeSet)
b.pendingChanges.Store(changeSet.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
rel := change.RemoveUpdatesForSelf(nodeID, all)
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c)
changes = append(changes, rel...)
b.pendingChanges.Store(nodeID, changes)
return true
@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() {
})
}
// IsConnected is lock-free read.
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
func (b *LockFreeBatcher) cleanupOfflineNodes() {
cleanupThreshold := 15 * time.Minute
now := time.Now()
var nodesToCleanup []types.NodeID
// Find nodes that have been offline for too long
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
// Double-check the node doesn't have active connections
if nodeConn, exists := b.nodes.Load(nodeID); exists {
if !nodeConn.hasActiveConnections() {
nodesToCleanup = append(nodesToCleanup, nodeID)
}
}
}
return true
})
// Clean up the identified nodes
for _, nodeID := range nodesToCleanup {
log.Info().Uint64("node.id", nodeID.Uint64()).
Dur("offline_duration", cleanupThreshold).
Msg("Cleaning up node that has been offline for too long")
b.nodes.Delete(nodeID)
b.connected.Delete(nodeID)
b.totalNodes.Add(-1)
}
if len(nodesToCleanup) > 0 {
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
Msg("Completed cleanup of long-offline nodes")
}
}
// IsConnected is lock-free read that checks if a node has any active connections.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
// First check if we have active connections for this node
if nodeConn, exists := b.nodes.Load(id); exists {
@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
}
}
// connectionData holds the channel and connection parameters.
type connectionData struct {
c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion
closed atomic.Bool // Track if this connection has been closed
// connectionEntry represents a single connection to a node.
type connectionEntry struct {
id string // unique connection ID
c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion
created time.Time
lastUsed atomic.Int64 // Unix timestamp of last successful send
}
// nodeConn described the node connection and its associated data.
type nodeConn struct {
// multiChannelNodeConn manages multiple concurrent connections for a single node.
type multiChannelNodeConn struct {
id types.NodeID
mapper *mapper
// Atomic pointer to connection data - allows lock-free updates
connData atomic.Pointer[connectionData]
mutex sync.RWMutex
connections []*connectionEntry
updateCount atomic.Int64
}
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
nc := &nodeConn{
// generateConnectionID generates a unique connection identifier.
func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
}
// newMultiChannelNodeConn creates a new multi-channel node connection.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
return &multiChannelNodeConn{
id: id,
mapper: mapper,
}
// Initialize connection data
data := &connectionData{
c: c,
version: version,
}
nc.connData.Store(data)
return nc
}
// updateConnection atomically updates connection parameters.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
newData := &connectionData{
c: c,
version: version,
}
nc.connData.Store(newData)
// addConnection adds a new connection.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
mutexWaitStart := time.Now()
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
mc.mutex.Lock()
mutexWaitDur := time.Since(mutexWaitStart)
defer mc.mutex.Unlock()
mc.connections = append(mc.connections, entry)
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
Int("total_connections", len(mc.connections)).
Dur("mutex_wait_time", mutexWaitDur).
Msg("Successfully added connection after mutex wait")
}
// matchesChannel checks if the given channel matches current connection.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
data := nc.connData.Load()
if data == nil {
return false
// removeConnectionByChannel removes a connection by matching channel pointer.
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
mc.mutex.Lock()
defer mc.mutex.Unlock()
for i, entry := range mc.connections {
if entry.c == c {
// Remove this connection
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
Int("remaining_connections", len(mc.connections)).
Msg("Successfully removed connection")
return true
}
}
// Compare channel pointers directly
return data.c == c
return false
}
// compressAndVersion atomically reads connection settings.
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
data := nc.connData.Load()
if data == nil {
// hasActiveConnections checks if the node has any active connections.
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
mc.mutex.RLock()
defer mc.mutex.RUnlock()
return len(mc.connections) > 0
}
// getActiveConnectionCount returns the number of active connections.
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
mc.mutex.RLock()
defer mc.mutex.RUnlock()
return len(mc.connections)
}
// send broadcasts data to all active connections for the node.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
mc.mutex.Lock()
defer mc.mutex.Unlock()
if len(mc.connections) == 0 {
// During rapid reconnection, nodes may temporarily have no active connections
// This is not an error - the node will receive a full map when it reconnects
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
}
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Int("total_connections", len(mc.connections)).
Msg("send: broadcasting to all connections")
var lastErr error
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
for i, conn := range mc.connections {
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: attempting to send to connection")
if err := conn.send(data); err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
log.Warn().Err(err).
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: connection send failed")
} else {
successCount++
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: successfully sent to connection")
}
}
// Remove failed connections (in reverse order to maintain indices)
for i := len(failedConnections) - 1; i >= 0; i-- {
idx := failedConnections[i]
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Str("conn.id", mc.connections[idx].id).
Msg("send: removing failed connection")
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
}
mc.updateCount.Add(1)
log.Info().Uint64("node.id", mc.id.Uint64()).
Int("successful_sends", successCount).
Int("failed_connections", len(failedConnections)).
Int("remaining_connections", len(mc.connections)).
Msg("send: completed broadcast")
// Success if at least one send succeeded
if successCount > 0 {
return nil
}
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
}
// send sends data to a single connection entry with timeout-based stale connection detection.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
// Use a short timeout to detect stale connections where the client isn't reading the channel.
// This is critical for detecting Docker containers that are forcefully terminated
// but still have channels that appear open.
select {
case entry.c <- data:
// Update last used timestamp on successful send
entry.lastUsed.Store(time.Now().Unix())
return nil
case <-time.After(50 * time.Millisecond):
// Connection is likely stale - client isn't reading from channel
// This catches the case where Docker containers are killed but channels remain open
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
}
}
// nodeID returns the node ID.
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
return mc.id
}
// version returns the capability version from the first active connection.
// All connections for a node should have the same version in practice.
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
mc.mutex.RLock()
defer mc.mutex.RUnlock()
if len(mc.connections) == 0 {
return 0
}
return data.version
return mc.connections[0].version
}
func (nc *nodeConn) nodeID() types.NodeID {
return nc.id
// change applies a change to all active connections for the node.
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
return handleNodeChange(mc, mc.mapper, c)
}
func (nc *nodeConn) change(c change.ChangeSet) error {
return handleNodeChange(nc, nc.mapper, c)
// DebugNodeInfo contains debug information about a node's connections.
type DebugNodeInfo struct {
Connected bool `json:"connected"`
ActiveConnections int `json:"active_connections"`
}
// send sends data to the node's channel.
// The node will pick it up and send it to the HTTP handler.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
connData := nc.connData.Load()
if connData == nil {
return fmt.Errorf("node %d: no connection data", nc.id)
}
// Debug returns a pre-baked map of node debug information for the debug interface.
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
result := make(map[types.NodeID]DebugNodeInfo)
// Check if connection has been closed
if connData.closed.Load() {
return fmt.Errorf("node %d: connection closed", nc.id)
}
// Get all nodes with their connection status using immediate connection logic
// (no grace period) for debug purposes
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
nodeConn.mutex.RLock()
activeConnCount := len(nodeConn.connections)
nodeConn.mutex.RUnlock()
// Use immediate connection status: if active connections exist, node is connected
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
connected := false
if activeConnCount > 0 {
connected = true
} else {
// Check connected map for immediate status
if val, ok := b.connected.Load(id); ok && val == nil {
connected = true
}
}
result[id] = DebugNodeInfo{
Connected: connected,
ActiveConnections: activeConnCount,
}
return true
})
// Add all entries from the connected map to capture both connected and disconnected nodes
b.connected.Range(func(id types.NodeID, val *time.Time) bool {

View File

@ -209,6 +209,7 @@ func setupBatcherWithTestData(
// Create test users and nodes in the database
users := database.CreateUsersForTest(userCount, "testuser")
allNodes := make([]node, 0, userCount*nodesPerUser)
for _, user := range users {
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b
if len(resp.PeersChangedPatch) > 0 {
require.Len(t, resp.PeersChangedPatch, 1)
assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
return
}
@ -412,6 +414,7 @@ func (n *node) start() {
n.maxPeersCount = info.PeerCount
}
}
if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items
@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Reduce verbose application logging for cleaner test output
originalLevel := zerolog.GlobalLevel()
defer zerolog.SetGlobalLevel(originalLevel)
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
// Test cases: different node counts to stress test the all-to-all connectivity
@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Join all nodes as fast as possible
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes {
node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
if stats.MaxPeersSeen > maxPeersGlobal {
maxPeersGlobal = stats.MaxPeersSeen
}
if stats.MaxPeersSeen < minPeersSeen {
minPeersSeen = stats.MaxPeersSeen
}
@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Show sample of node details
if len(nodeDetails) > 0 {
t.Logf(" Node sample:")
for _, detail := range nodeDetails[:min(5, len(nodeDetails))] {
t.Logf(" %s", detail)
}
if len(nodeDetails) > 5 {
t.Logf(" ... (%d more nodes)", len(nodeDetails)-5)
}
@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Show details of failed nodes for debugging
if len(nodeDetails) > 5 {
t.Logf("Failed nodes details:")
for _, detail := range nodeDetails[5:] {
if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) {
t.Logf(" %s", detail)
@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) {
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
count := 0
timer := time.NewTimer(timeout)
defer timer.Stop()
@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// Collect updates with timeout
updateCount := 0
timeout := time.After(200 * time.Millisecond)
for {
select {
case data := <-ch:
updateCount++
receivedUpdates = append(receivedUpdates, data)
// Validate update content
@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// Validate that all updates have valid content
validUpdates := 0
for _, data := range receivedUpdates {
if data != nil {
if valid, _ := validateUpdateContent(data); valid {
@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
batcher := testData.Batcher
testNode := testData.Nodes[0]
var channelIssues int
var mutex sync.Mutex
var (
channelIssues int
mutex sync.Mutex
)
// Run rapid connect/disconnect cycles with real updates to test channel closing
for i := range 100 {
var wg sync.WaitGroup
// First connection
ch1 := make(chan *tailcfg.MapResponse, 1)
wg.Add(1)
go func() {
defer wg.Done()
@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
// Rapid second connection - should replace ch1
ch2 := make(chan *tailcfg.MapResponse, 1)
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(1 * time.Microsecond)
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
}()
// Remove second connection
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(2 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch2)
}()
@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
case <-time.After(1 * time.Millisecond):
// If no data received, increment issues counter
mutex.Lock()
channelIssues++
mutex.Unlock()
}
@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
batcher := testData.Batcher
testNode := testData.Nodes[0]
var panics int
var channelErrors int
var invalidData int
var mutex sync.Mutex
var (
panics int
channelErrors int
invalidData int
mutex sync.Mutex
)
// Test rapid connect/disconnect with work generation
for i := range 50 {
func() {
defer func() {
if r := recover(); r != nil {
mutex.Lock()
panics++
mutex.Unlock()
t.Logf("Panic caught: %v", r)
}
@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
defer func() {
if r := recover(); r != nil {
mutex.Lock()
channelErrors++
mutex.Unlock()
t.Logf("Channel consumer panic: %v", r)
}
@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Validate the data we received
if valid, reason := validateUpdateContent(data); !valid {
mutex.Lock()
invalidData++
mutex.Unlock()
t.Logf("Invalid data received: %s", reason)
}
@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
if panics > 0 {
t.Errorf("Worker channel safety failed with %d panics", panics)
}
if channelErrors > 0 {
t.Errorf("Channel handling failed with %d channel errors", channelErrors)
}
if invalidData > 0 {
t.Errorf("Data validation failed with %d invalid data packets", invalidData)
}
@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Use remaining nodes for connection churn testing
churningNodes := allNodes[len(allNodes)/2:]
churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
var churningChannelsMutex sync.Mutex // Protect concurrent map access
var wg sync.WaitGroup
numCycles := 10 // Reduced for simpler test
panicCount := 0
var panicMutex sync.Mutex
// Track deadlock with timeout
done := make(chan struct{})
go func() {
defer close(done)
@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) {
defer func() {
if r := recover(); r != nil {
panicMutex.Lock()
panicCount++
panicMutex.Unlock()
t.Logf("Panic in churning connect: %v", r)
}
wg.Done()
}()
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) {
defer func() {
if r := recover(); r != nil {
panicMutex.Lock()
panicCount++
panicMutex.Unlock()
t.Logf("Panic in churning disconnect: %v", r)
}
wg.Done()
}()
time.Sleep(time.Duration(i%5) * time.Millisecond)
churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID]
churningChannelsMutex.Unlock()
if exists {
batcher.RemoveNode(nodeID, ch)
}
@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
// DERP changes
batcher.AddWork(change.DERPSet)
}
if i%5 == 0 {
// Full updates using real node data
batcher.AddWork(change.FullSet)
}
if i%7 == 0 && len(allNodes) > 0 {
// Node-specific changes using real nodes
node := allNodes[i%len(allNodes)]
@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Validate results
panicMutex.Lock()
finalPanicCount := panicCount
panicMutex.Unlock()
allStats := tracker.getAllStats()
@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) {
// Reduce verbose application logging for cleaner test output
originalLevel := zerolog.GlobalLevel()
defer zerolog.SetGlobalLevel(originalLevel)
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
// Full test matrix for scalability testing
@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) {
batcher := testData.Batcher
allNodes := testData.Nodes
t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description)
t.Logf(
" Cycles: %d, Buffer Size: %d, Chaos Type: %s",
@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) {
// Connect all nodes first so they can see each other as peers
connectedNodes := make(map[types.NodeID]bool)
var connectedNodesMutex sync.RWMutex
for i := range testNodes {
node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true
connectedNodesMutex.Unlock()
}
@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) {
go func() {
defer close(done)
var wg sync.WaitGroup
t.Logf(
@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) {
// For chaos testing, only disconnect/reconnect a subset of nodes
// This ensures some nodes stay connected to continue receiving updates
startIdx := cycle % len(testNodes)
endIdx := startIdx + len(testNodes)/4
if endIdx > len(testNodes) {
endIdx = len(testNodes)
}
if startIdx >= endIdx {
startIdx = 0
endIdx = min(len(testNodes)/4, len(testNodes))
}
chaosNodes := testNodes[startIdx:endIdx]
if len(chaosNodes) == 0 {
chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos
@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) {
if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1)
}
wg.Done()
}()
connectedNodesMutex.RLock()
isConnected := connectedNodes[nodeID]
connectedNodesMutex.RUnlock()
if isConnected {
batcher.RemoveNode(nodeID, channel)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = false
connectedNodesMutex.Unlock()
}
}(
@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) {
if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1)
}
wg.Done()
}()
@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) {
tailcfg.CapabilityVersion(100),
)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = true
connectedNodesMutex.Unlock()
// Add work to create load
@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) {
updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count
for i := range updateCount {
wg.Add(1)
go func(index int) {
defer func() {
if r := recover(); r != nil {
atomic.AddInt64(&panicCount, 1)
}
wg.Done()
}()
@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) {
deadlockDetected = true
// Collect diagnostic information
allStats := tracker.getAllStats()
totalUpdates := 0
for _, stats := range allStats {
totalUpdates += stats.TotalUpdates
}
interimPanics := atomic.LoadInt64(&panicCount)
t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT)
t.Logf(
" Progress at timeout: %d total updates, %d panics",
@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) {
stats := node.cleanup()
totalUpdates += stats.TotalUpdates
totalPatches += stats.PatchUpdates
totalFull += stats.FullUpdates
if stats.MaxPeersSeen > maxPeersGlobal {
maxPeersGlobal = stats.MaxPeersSeen
@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) {
// Legacy tracker comparison (optional)
allStats := tracker.getAllStats()
legacyTotalUpdates := 0
for _, stats := range allStats {
legacyTotalUpdates += stats.TotalUpdates
}
if legacyTotalUpdates != int(totalUpdates) {
t.Logf(
"Note: Legacy tracker mismatch - legacy: %d, new: %d",
@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) {
// Validation based on expectation
testPassed := true
if tc.expectBreak {
// For tests expected to break, we're mainly checking that we don't crash
if finalPanicCount > 0 {
@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) {
// For tests expected to pass, validate proper operation
if finalPanicCount > 0 {
t.Errorf("Scalability test failed with %d panics", finalPanicCount)
testPassed = false
}
if deadlockDetected {
t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes))
testPassed = false
}
if totalUpdates == 0 {
t.Error("No updates received - system may be completely stalled")
testPassed = false
}
}
@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Read all available updates for each node
for i := range allNodes {
nodeUpdates := 0
t.Logf("Reading updates for node %d:", i)
// Read up to 10 updates per node or until timeout/no more data
@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
if len(data.Peers) > 0 {
t.Logf(" Full peer list with %d peers", len(data.Peers))
for j, peer := range data.Peers[:min(3, len(data.Peers))] {
t.Logf(
" Peer %d: NodeID=%d, Online=%v",
@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
)
}
}
if len(data.PeersChangedPatch) > 0 {
t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch))
for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] {
t.Logf(
" Patch %d: NodeID=%d, Online=%v",
@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
case <-time.After(500 * time.Millisecond):
}
}
t.Logf("Node %d received %d updates", i, nodeUpdates)
}
@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
}
}
// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items.
func TestBatcherWorkQueueTracing(t *testing.T) {
// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID
// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected.
// This specifically tests the multi-channel batcher implementation issue.
func TestBatcherRapidReconnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10)
defer cleanup()
batcher := testData.Batcher
allNodes := testData.Nodes
t.Logf("=== RAPID RECONNECTION TEST ===")
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node %d: %v", i, err)
}
}
time.Sleep(100 * time.Millisecond) // Let connections settle
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
}
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
}
}
time.Sleep(100 * time.Millisecond) // Let reconnections settle
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
t.Logf("Phase 4: Checking debug status...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
disconnectedCount := 0
for i, node := range allNodes {
if info, exists := debugInfo[node.n.ID]; exists {
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
// Check if the debug info shows the node as connected
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}
// Also check IsConnected method
if !batcher.IsConnected(node.n.ID) {
t.Logf("Node %d IsConnected() returns false", i)
}
}
if disconnectedCount > 0 {
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
// This is expected behavior for multi-channel batcher according to user
// "it has never worked with the multi"
} else {
t.Logf("All nodes show as connected - working correctly")
}
} else {
t.Logf("Batcher does not implement Debug() method")
}
// Phase 5: Test if "disconnected" nodes can actually receive updates
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())
receivedCount := 0
timeout := time.After(500 * time.Millisecond)
for i := 0; i < len(allNodes); i++ {
select {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
t.Logf("Node %d timed out waiting for update", i)
goto done
}
}
done:
t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes))
if receivedCount < len(allNodes) {
t.Logf("Some nodes failed to receive updates - confirming the issue")
}
})
}
}
func TestBatcherMultiConnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10)
defer cleanup()
batcher := testData.Batcher
nodes := testData.Nodes
t.Logf("=== WORK QUEUE TRACING TEST ===")
time.Sleep(100 * time.Millisecond) // Let connections settle
// Wait for initial NodeCameOnline to be processed
time.Sleep(200 * time.Millisecond)
// Drain any initial updates
drainedCount := 0
for {
select {
case <-nodes[0].ch:
drainedCount++
case <-time.After(100 * time.Millisecond):
goto drained
}
}
drained:
t.Logf("Drained %d initial updates", drainedCount)
// Now send a single FullSet update and trace it closely
t.Logf("Sending change.FullSet work item...")
batcher.AddWork(change.FullSet)
// Give short time for processing
time.Sleep(100 * time.Millisecond)
// Check if any update was received
select {
case data := <-nodes[0].ch:
t.Logf("SUCCESS: Received update after FullSet!")
if data != nil {
// Detailed analysis of the response - data is already a MapResponse
t.Logf("Response details:")
t.Logf(" Peers: %d", len(data.Peers))
t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch))
t.Logf(" PeersChanged: %d", len(data.PeersChanged))
t.Logf(" PeersRemoved: %d", len(data.PeersRemoved))
t.Logf(" DERPMap: %v", data.DERPMap != nil)
t.Logf(" KeepAlive: %v", data.KeepAlive)
t.Logf(" Node: %v", data.Node != nil)
if len(data.Peers) > 0 {
t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers))
} else if len(data.PeersChangedPatch) > 0 {
t.Errorf("ERROR: Received patch update instead of full update!")
} else if data.DERPMap != nil {
t.Logf("Received DERP map update")
} else if data.Node != nil {
t.Logf("Received self node update")
} else {
t.Errorf("ERROR: Received unknown update type!")
}
batcher := testData.Batcher
node1 := testData.Nodes[0]
node2 := testData.Nodes[1]
@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
}
}
}
} else {
t.Errorf("Response data is nil")
}
case <-time.After(2 * time.Second):
t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!")
t.Errorf("This indicates FullSet work items are not being processed at all")
}
// Send another update and verify remaining connections still work
clearChannel(node1.ch)
clearChannel(thirdChannel)
testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
batcher.AddWork(testChangeSet2)
time.Sleep(100 * time.Millisecond)
// Verify remaining connections still receive updates
remaining1Received := false
remaining3Received := false
select {
case mapResp := <-node1.ch:
remaining1Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 1 did not receive update after removal")
}
select {
case mapResp := <-thirdChannel:
remaining3Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 3 did not receive update after removal")
}
if remaining1Received && remaining3Received {
t.Logf("SUCCESS: Remaining connections still receive updates after removal")
} else {
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t",
remaining1Received, remaining3Received)
}
// Verify second channel no longer receives updates (should be closed/removed)
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")
case <-time.After(100 * time.Millisecond):
t.Logf("SUCCESS: Removed connection correctly no longer receives updates")
}
})
}

View File

@ -20,6 +20,8 @@ type MapResponseBuilder struct {
nodeID types.NodeID
capVer tailcfg.CapabilityVersion
errs []error
debugType debugType
}
type debugType string

View File

@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(fullResponseDebug).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap().
Build()
}
@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse(
changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed).
Build()
}
@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse(
peers := m.state.ListPeers(nodeID, changedNodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse(
removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID).
Build()
}
@ -214,7 +218,7 @@ func writeDebugMapResponse(
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@ -224,7 +228,7 @@ func writeDebugMapResponse(
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s.json", now),
fmt.Sprintf("%s-%s.json", now, t),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
return nil, nil
}
nodes, err := os.ReadDir(debugDumpMapResponsePath)
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
}
func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) {
nodes, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
nodeID := types.NodeID(nodeIDu)
files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
files, err := os.ReadDir(path.Join(dir, node.Name()))
if err != nil {
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
continue
@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
continue
}
body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
if err != nil {
log.Error().Err(err).Msgf("Reading file %s", file.Name())
continue

View File

@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) {
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{

View File

@ -175,8 +175,8 @@ func rejectUnsupported(
Int("client_cap_ver", int(version)).
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
Str("client_version", capver.TailscaleVersion(version)).
Str("node_key", nkey.ShortString()).
Str("machine_key", mkey.ShortString()).
Str("node.key", nkey.ShortString()).
Str("machine.key", mkey.ShortString()).
Msg("unsupported client connected")
http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest)
@ -282,7 +282,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return
}

View File

@ -181,7 +181,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
a.registrationCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound)
}
@ -311,7 +311,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}
return
@ -349,7 +349,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write response")
util.LogErr(err, "Failed to write HTTP response")
}
return

View File

@ -34,7 +34,7 @@ func (pol *Policy) compileFilterRules(
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
log.Trace().Caller().Err(err).Msgf("resolving source ips")
}
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
@ -52,11 +52,11 @@ func (pol *Policy) compileFilterRules(
for _, dest := range acl.Destinations {
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
}
if ips == nil {
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest)
continue
}
@ -106,7 +106,7 @@ func (pol *Policy) compileSSHPolicy(
return nil, nil
}
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
var rules []*tailcfg.SSHRule
@ -115,7 +115,7 @@ func (pol *Policy) compileSSHPolicy(
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
}
dest.AddSet(ips)
}
@ -142,7 +142,7 @@ func (pol *Policy) compileSSHPolicy(
var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
continue // Skip this rule if we can't resolve sources
}

View File

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -79,6 +80,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
filterHash := deephash.Hash(&filter)
filterChanged := filterHash != pm.filterHash
if filterChanged {
log.Debug().
Str("filter.hash.old", pm.filterHash.String()[:8]).
Str("filter.hash.new", filterHash.String()[:8]).
Int("filter.rules", len(pm.filter)).
Int("filter.rules.new", len(filter)).
Msg("Policy filter hash changed")
}
pm.filter = filter
pm.filterHash = filterHash
if filterChanged {
@ -95,6 +104,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
if tagOwnerChanged {
log.Debug().
Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]).
Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]).
Int("tagOwners.old", len(pm.tagOwnerMap)).
Int("tagOwners.new", len(tagMap)).
Msg("Tag owner hash changed")
}
pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash
@ -105,19 +122,42 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
if autoApproveChanged {
log.Debug().
Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]).
Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]).
Int("autoApprovers.old", len(pm.autoApproveMap)).
Int("autoApprovers.new", len(autoMap)).
Msg("Auto-approvers hash changed")
}
pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash
exitSetHash := deephash.Hash(&autoMap)
exitSetHash := deephash.Hash(&exitSet)
exitSetChanged := exitSetHash != pm.exitSetHash
if exitSetChanged {
log.Debug().
Str("exitSet.hash.old", pm.exitSetHash.String()[:8]).
Str("exitSet.hash.new", exitSetHash.String()[:8]).
Msg("Exit node set hash changed")
}
pm.exitSet = exitSet
pm.exitSetHash = exitSetHash
// If neither of the calculated values changed, no need to update nodes
if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged {
log.Trace().
Msg("Policy evaluation detected no changes - all hashes match")
return false, nil
}
log.Debug().
Bool("filter.changed", filterChanged).
Bool("tagOwners.changed", tagOwnerChanged).
Bool("autoApprovers.changed", autoApproveChanged).
Bool("exitNodes.changed", exitSetChanged).
Msg("Policy changes require node updates")
return true, nil
}
@ -151,6 +191,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
// Log policy metadata for debugging
log.Debug().
Int("policy.bytes", len(polB)).
Int("acls.count", len(pol.ACLs)).
Int("groups.count", len(pol.Groups)).
Int("hosts.count", len(pol.Hosts)).
Int("tagOwners.count", len(pol.TagOwners)).
Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)).
Msg("Policy parsed successfully")
pm.pol = pol
return pm.updateLocked()

View File

@ -216,6 +216,21 @@ func (m *mapSession) serveLongPoll() {
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
// TODO(kradalby): Redo the comments here
// Add node to batcher so it can receive updates,
// adding this before connecting it to the state ensure that
// it does not miss any updates that might be sent in the split
// time between the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
m.h.Change(mapReqChange)
m.h.Change(connectChanges...)
// Loop through updates and continuously send them to the
// client.
for {
@ -227,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
return
case <-ctx.Done():
m.tracef("poll context done")
m.tracef("poll context done chan:%p", m.ch)
mapResponseEnded.WithLabelValues("done").Inc()
return
@ -295,7 +310,15 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
}
}
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
log.Trace().
Caller().
Str("node.name", m.node.Hostname).
Uint64("node.id", m.node.ID.Uint64()).
Str("chan", fmt.Sprintf("%p", m.ch)).
TimeDiff("timeSpent", time.Now(), startWrite).
Str("machine.key", m.node.MachineKey.String()).
Bool("keepalive", msg.KeepAlive).
Msgf("finished writing mapresp to node chan(%p)", m.ch)
return nil
}
@ -305,14 +328,14 @@ var keepAlive = tailcfg.MapResponse{
}
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
if peerChange.Key != nil {
trace = trace.Str("node_key", peerChange.Key.ShortString())
trace = trace.Str("node.key", peerChange.Key.ShortString())
}
if peerChange.DiscoKey != nil {
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString())
}
if peerChange.Online != nil {
@ -349,7 +372,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
@ -358,7 +381,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
@ -367,7 +390,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
@ -376,7 +399,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Err(err).
Msgf(msg, a...)
}

View File

@ -1430,7 +1430,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
}
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")
changed, err := s.polMan.SetUsers(users)
if err != nil {

View File

@ -97,6 +97,35 @@ func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy
}
func HasFull(cs []ChangeSet) bool {
for _, c := range cs {
if c.IsFull() {
return true
}
}
return false
}
func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
for _, c := range cs {
if c.SelfUpdateOnly {
self = append(self, c)
} else {
all = append(all, c)
}
}
return all, self
}
func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
for _, c := range cs {
if c.NodeID != id || c.Change.AlsoSelf() {
ret = append(ret, c)
}
}
return ret
}
func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes.

View File

@ -489,6 +489,7 @@ func derpConfig() DERPConfig {
urlAddr, err := url.Parse(urlStr)
if err != nil {
log.Error().
Caller().
Str("url", urlStr).
Err(err).
Msg("Failed to parse url, ignoring...")
@ -561,6 +562,7 @@ func logConfig() LogConfig {
logFormat = TextLogFormat
default:
log.Error().
Caller().
Str("func", "GetLogConfig").
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
}

View File

@ -54,6 +54,20 @@ func (id NodeID) String() string {
return strconv.FormatUint(id.Uint64(), util.Base10)
}
func ParseNodeID(s string) (NodeID, error) {
id, err := strconv.ParseUint(s, util.Base10, 64)
return NodeID(id), err
}
func MustParseNodeID(s string) NodeID {
id, err := ParseNodeID(s)
if err != nil {
panic(err)
}
return id
}
// Node is a Headscale client.
type Node struct {
ID NodeID `gorm:"primary_key"`

View File

@ -61,6 +61,7 @@ func (pak *PreAuthKey) Validate() error {
}
log.Debug().
Caller().
Str("key", pak.Key).
Bool("hasExpiration", pak.Expiration != nil).
Time("expiration", func() time.Time {

View File

@ -321,7 +321,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
if err == nil {
u.Name = claims.Username
} else {
log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username)
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
}
if claims.EmailVerified {

View File

@ -1160,57 +1160,61 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
err = headscale.SetPolicy(&p)
require.NoError(t, err)
// Get the current policy and check
// if it is the same as the one we set.
var output *policyv2.Policy
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"policy",
"get",
"--output",
"json",
},
&output,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Get the current policy and check
// if it is the same as the one we set.
var output *policyv2.Policy
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"policy",
"get",
"--output",
"json",
},
&output,
)
assert.NoError(ct, err)
assert.Len(t, output.ACLs, 1)
assert.Len(t, output.ACLs, 1)
if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" {
t.Errorf("unexpected policy(-want +got):\n%s", diff)
}
// Test that user1 can visit all user2
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
require.NoError(t, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url)
assert.Len(t, result, 13)
require.NoError(t, err)
if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" {
ct.Errorf("unexpected policy(-want +got):\n%s", diff)
}
}
}, 30*time.Second, 1*time.Second, "verifying that the new policy took place")
// Test that user2 _cannot_ visit user1
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Test that user1 can visit all user2
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
assert.NoError(ct, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url)
assert.Empty(t, result)
require.Error(t, err)
result, err := client.Curl(url)
assert.Len(ct, result, 13)
assert.NoError(ct, err)
}
}
}
// Test that user2 _cannot_ visit user1
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
assert.NoError(ct, err)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)
result, err := client.Curl(url)
assert.Empty(ct, result)
assert.Error(ct, err)
}
}
}, 30*time.Second, 1*time.Second, "new policy did not get propagated to nodes")
}
func TestACLAutogroupMember(t *testing.T) {

View File

@ -9,6 +9,7 @@ import (
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/samber/lo"
@ -53,6 +54,18 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second)
// assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr)
@ -64,9 +77,6 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
clientIPs[client] = ips
}
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
@ -86,6 +96,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second)
t.Logf("all clients logged out")
assert.EventuallyWithT(t, func(ct *assert.CollectT) {

View File

@ -481,10 +481,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
headscale, err := scenario.Headscale()
assertNoErr(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)
@ -494,26 +490,28 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after first login")
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
@ -525,19 +523,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
}, 30*time.Second, 1*time.Second)
u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
@ -545,56 +543,56 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers = []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after new user login")
listNodesAfterNewUserLogin, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodesAfterNewUserLogin, 2)
var listNodesAfterNewUserLogin []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterNewUserLogin, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterNewUserLogin, 2)
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "listing nodes after new user login")
// Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
t.Logf("Logged out take one")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -602,65 +600,92 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)
t.Logf("Logged out take two")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 30*time.Second, 1*time.Second)
// We do not actually "change" the user here, it is done by logging in again
// as the OIDC mock server is kind of like a stack, and the next user is
// prepared and ready to go.
u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers = []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "Running", status.BackendState)
}, 30*time.Second, 1*time.Second)
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
t.Logf("Logged back in")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err)
assert.Len(ct, listUsers, 2)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodesAfterLoggingBackIn, 2)
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterLoggingBackIn, 2)
// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients

View File

@ -4,6 +4,7 @@ import (
"net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"

View File

@ -10,18 +10,21 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -59,13 +62,15 @@ func TestPingAllByIP(t *testing.T) {
hs, err := scenario.Headscale()
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
all, err := hs.GetAllMapReponses()
assert.NoError(ct, err)
onlineMap := buildExpectedOnlineMap(all)
assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
}, 30*time.Second, 2*time.Second)
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err, "failed to parse node ID")
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second)
// assertClientsState(t, allClients)
@ -73,6 +78,14 @@ func TestPingAllByIP(t *testing.T) {
return x.String()
})
// Get headscale instance for batcher debug check
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test our DebugBatcher functionality
t.Logf("Testing DebugBatcher functionality...")
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to the batcher", 30*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
@ -962,9 +975,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
)
assertNoErrHeadscaleEnv(t, err)
hs, err := scenario.Headscale()
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
@ -980,14 +990,31 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
return x.String()
})
// Get headscale instance for batcher debug checks
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Initial check: all nodes should be connected to batcher
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
wg, _ := errgroup.WithContext(context.Background())
for run := range 3 {
t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999"))
// Create fresh errgroup with timeout for each run
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
wg, _ := errgroup.WithContext(ctx)
for _, client := range allClients {
c := client
wg.Go(func() error {
@ -1001,6 +1028,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))
// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, fmt.Sprintf("Run %d: all nodes should be offline after Down()", run+1), 120*time.Second)
for _, client := range allClients {
c := client
wg.Go(func() error {
@ -1014,22 +1044,22 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))
// After bringing up all nodes, verify batcher shows all reconnected
requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all nodes should be reconnected after Up()", run+1), 120*time.Second)
// Wait for sync and successful pings after nodes come back up
err = scenario.WaitForTailscaleSync()
assert.NoError(t, err)
t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999"))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
all, err := hs.GetAllMapReponses()
assert.NoError(ct, err)
onlineMap := buildExpectedOnlineMap(all)
assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
}, 60*time.Second, 2*time.Second)
requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all systems should show nodes online after reconnection", run+1), 60*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
// Clean up context for this run
cancel()
}
}
@ -1141,51 +1171,158 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
}
func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
res := make(map[types.NodeID]map[types.NodeID]bool)
for nid, mrs := range all {
res[nid] = make(map[types.NodeID]bool)
for _, mr := range mrs {
for _, peer := range mr.Peers {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}
for _, peer := range mr.PeersChanged {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}
for _, peer := range mr.PeersChangedPatch {
if peer.Online != nil {
res[nid][types.NodeID(peer.NodeID)] = *peer.Online
}
}
}
}
return res
// NodeSystemStatus represents the online status of a node across different systems
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
MapResponses bool
NodeStore bool
}
func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) {
for nid, peers := range onlineMap {
onlineCount := 0
for _, online := range peers {
if online {
onlineCount++
// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format("2006-01-02T15:04:05.000"), message)
var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")
// Check that we have map responses for expected nodes
mapResponseCount := len(mapResponses)
assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")
// Build status map for each node
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
// Initialize all expected nodes
for _, nodeID := range expectedNodes {
nodeStatus[nodeID] = NodeSystemStatus{}
}
// Check batcher state
for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
nodeID := types.MustParseNodeID(nodeIDStr)
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = nodeInfo.Connected
status.BatcherConnCount = nodeInfo.ActiveConnections
nodeStatus[nodeID] = status
}
}
assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid)
if expectedPeerCount != onlineCount {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid))
for pid, online := range peers {
sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online))
// Check map responses using buildExpectedOnlineMap
onlineFromMaps := make(map[types.NodeID]bool)
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
for nodeID := range nodeStatus {
NODE_STATUS:
for id, peerMap := range onlineMap {
if id == nodeID {
continue
}
online := peerMap[nodeID]
// If the node is offline in any map response, we consider it offline
if !online {
onlineFromMaps[nodeID] = false
continue NODE_STATUS
}
onlineFromMaps[nodeID] = true
}
sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
sb.WriteString("expected all peers to be online.")
t.Errorf("%s", sb.String())
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
// Update status with map response data
for nodeID, online := range onlineFromMaps {
if status, exists := nodeStatus[nodeID]; exists {
status.MapResponses = online
nodeStatus[nodeID] = status
}
}
// Check nodestore state
for nodeID, node := range nodeStore {
if status, exists := nodeStatus[nodeID]; exists {
// Check if node is online in nodestore
status.NodeStore = node.IsOnline != nil && *node.IsOnline
nodeStatus[nodeID] = status
}
}
// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder
ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
(status.MapResponses == expectedOnline) &&
(status.NodeStore == expectedOnline)
if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
}
}
if !allMatch {
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
t.Log("Diff between reports:")
t.Logf("Prev report: \n%s\n", prevReport)
t.Logf("New report: \n%s\n", failureReport.String())
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
prevReport = failureReport.String()
}
failureReport.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
assert.Fail(c, failureReport.String())
}
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
}, timeout, 2*time.Second, message)
endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format("2006-01-02T15:04:05.000"), duration, message)
}

View File

@ -22,6 +22,7 @@ import (
"github.com/davecgh/go-spew/spew"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"

View File

@ -19,6 +19,7 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"tailscale.com/tailcfg"
)
// PeerSyncTimeout returns the timeout for peer synchronization based on environment:
@ -199,3 +200,30 @@ func CreateCertificate(hostname string) ([]byte, []byte, error) {
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}
func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
res := make(map[types.NodeID]map[types.NodeID]bool)
for nid, mrs := range all {
res[nid] = make(map[types.NodeID]bool)
for _, mr := range mrs {
for _, peer := range mr.Peers {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}
for _, peer := range mr.PeersChanged {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}
for _, peer := range mr.PeersChangedPatch {
if peer.Online != nil {
res[nid][types.NodeID(peer.NodeID)] = *peer.Online
}
}
}
}
return res
}

View File

@ -5,7 +5,9 @@ package main
import (
"encoding/json"
"fmt"
"go/format"
"io"
"log"
"net/http"
"os"
"regexp"
@ -61,14 +63,14 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
rawURL := fmt.Sprintf(rawFileURL, version)
resp, err := http.Get(rawURL)
if err != nil {
fmt.Printf("Error fetching raw file for version %s: %v\n", version, err)
log.Printf("Error fetching raw file for version %s: %v\n", version, err)
continue
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Error reading raw file for version %s: %v\n", version, err)
log.Printf("Error reading raw file for version %s: %v\n", version, err)
continue
}
@ -79,7 +81,7 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
capabilityVersion, _ := strconv.Atoi(capabilityVersionStr)
versions[version] = tailcfg.CapabilityVersion(capabilityVersion)
} else {
fmt.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
log.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
}
}
@ -87,29 +89,23 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
}
func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion) error {
// Open the output file
file, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("error creating file: %w", err)
}
defer file.Close()
// Write the package declaration and variable
file.WriteString("package capver\n\n")
file.WriteString("//Generated DO NOT EDIT\n\n")
file.WriteString(`import "tailscale.com/tailcfg"`)
file.WriteString("\n\n")
file.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")
// Generate the Go code as a string
var content strings.Builder
content.WriteString("package capver\n\n")
content.WriteString("// Generated DO NOT EDIT\n\n")
content.WriteString(`import "tailscale.com/tailcfg"`)
content.WriteString("\n\n")
content.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")
sortedVersions := xmaps.Keys(versions)
sort.Strings(sortedVersions)
for _, version := range sortedVersions {
fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
fmt.Fprintf(&content, "\t\"%s\": %d,\n", version, versions[version])
}
file.WriteString("}\n")
content.WriteString("}\n")
file.WriteString("\n\n")
file.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")
content.WriteString("\n\n")
content.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")
capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
for _, v := range sortedVersions {
@ -129,9 +125,21 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
return capsSorted[i] < capsSorted[j]
})
for _, capVer := range capsSorted {
fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
}
content.WriteString("}\n")
// Format the generated code
formatted, err := format.Source([]byte(content.String()))
if err != nil {
return fmt.Errorf("error formatting Go code: %w", err)
}
// Write to file
err = os.WriteFile(outputFile, formatted, 0644)
if err != nil {
return fmt.Errorf("error writing file: %w", err)
}
file.WriteString("}\n")
return nil
}
@ -139,15 +147,15 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
func main() {
versions, err := getCapabilityVersions()
if err != nil {
fmt.Println("Error:", err)
log.Println("Error:", err)
return
}
err = writeCapabilityVersionsToFile(versions)
if err != nil {
fmt.Println("Error writing to file:", err)
log.Println("Error writing to file:", err)
return
}
fmt.Println("Capability versions written to", outputFile)
log.Println("Capability versions written to", outputFile)
}