lint and leftover
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
39443184d6
commit
233dffc186
.github/ISSUE_TEMPLATE/feature_request.yaml (vendored, 9 changed lines)
@@ -16,15 +16,13 @@ body:
  - type: textarea
    attributes:
      label: Description
      description:
        A clear and precise description of what new or changed feature you want.
      description: A clear and precise description of what new or changed feature you want.
    validations:
      required: true
  - type: checkboxes
    attributes:
      label: Contribution
      description:
        Are you willing to contribute to the implementation of this feature?
      description: Are you willing to contribute to the implementation of this feature?
      options:
        - label: I can write the design doc for this feature
          required: false
@@ -33,7 +31,6 @@ body:
  - type: textarea
    attributes:
      label: How can it be implemented?
      description:
        Free text for your ideas on how this feature could be implemented.
      description: Free text for your ideas on how this feature could be implemented.
    validations:
      required: false
@@ -146,12 +146,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {

        policyChanged, err := app.state.DeleteNode(node)
        if err != nil {
            log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
            log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
            return
        }

        app.Change(policyChanged)
        log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
        log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
    })
    app.ephemeralGC = ephemeralGC
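Note: the ephemeral garbage collector driving the callback above arms a timer per node and invokes a delete function once the timeout elapses; this hunk only reworks the callback's logging. A rough sketch of that schedule-then-delete pattern (EphemeralGC, Schedule and the field names are illustrative, not headscale's actual API):

package gcsketch

import (
    "sync"
    "time"
)

// EphemeralGC arms one timer per node ID and calls deleteFn when it fires.
type EphemeralGC struct {
    mu       sync.Mutex
    timers   map[uint64]*time.Timer
    deleteFn func(id uint64)
}

func NewEphemeralGC(deleteFn func(uint64)) *EphemeralGC {
    return &EphemeralGC{timers: map[uint64]*time.Timer{}, deleteFn: deleteFn}
}

// Schedule (re)arms the deletion timer; a reconnecting node simply reschedules.
func (gc *EphemeralGC) Schedule(id uint64, after time.Duration) {
    gc.mu.Lock()
    defer gc.mu.Unlock()
    if t, ok := gc.timers[id]; ok {
        t.Stop()
    }
    gc.timers[id] = time.AfterFunc(after, func() { gc.deleteFn(id) })
}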
@@ -384,53 +384,49 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
        log.Trace().
            Caller().
            Str("client_address", req.RemoteAddr).
            Msg(`missing "Bearer " prefix in "Authorization" header`)
        writer.WriteHeader(http.StatusUnauthorized)
        _, err := writer.Write([]byte("Unauthorized"))
            Msg("HTTP authentication invoked")

        authHeader := req.Header.Get("Authorization")

        if !strings.HasPrefix(authHeader, AuthPrefix) {
            log.Error().
                Caller().
                Str("client_address", req.RemoteAddr).
                Msg(`missing "Bearer " prefix in "Authorization" header`)
            writer.WriteHeader(http.StatusUnauthorized)
            _, err := writer.Write([]byte("Unauthorized"))
            return err
        }

        valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
        if err != nil {
            log.Error().
                Caller().
                Err(err).
                Msg("Failed to write response")
                Str("client_address", req.RemoteAddr).
                Msg("failed to validate token")

            writer.WriteHeader(http.StatusInternalServerError)
            _, err := writer.Write([]byte("Unauthorized"))
            return err
        }

        return
    }
    if !valid {
        log.Info().
            Str("client_address", req.RemoteAddr).
            Msg("invalid token")

        valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
        if err != nil {
            writer.WriteHeader(http.StatusUnauthorized)
            _, err := writer.Write([]byte("Unauthorized"))
            return err
        }

        return nil
    }(); err != nil {
        log.Error().
            Caller().
            Err(err).
            Str("client_address", req.RemoteAddr).
            Msg("failed to validate token")

        writer.WriteHeader(http.StatusInternalServerError)
        _, err := writer.Write([]byte("Unauthorized"))
        if err != nil {
            log.Error().
                Caller().
                Err(err).
                Msg("Failed to write response")
        }

        return
    }

    if !valid {
        log.Info().
            Str("client_address", req.RemoteAddr).
            Msg("invalid token")

        writer.WriteHeader(http.StatusUnauthorized)
        _, err := writer.Write([]byte("Unauthorized"))
        if err != nil {
            log.Error().
                Caller().
                Err(err).
                Msg("Failed to write response")
        }

            Msg("Failed to write HTTP response")
        return
    }
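For orientation, the middleware above gates every request on a "Bearer <api key>" Authorization header, writing 401/500 responses and logging when the prefix is missing, the key cannot be validated, or the key is invalid. A condensed, stand-alone sketch of that shape (names and helpers here are assumptions, not the actual headscale implementation):

package authsketch

import (
    "net/http"
    "strings"
)

const authPrefix = "Bearer "

// bearerAuth rejects requests whose Authorization header does not carry a valid API key.
func bearerAuth(validate func(key string) (bool, error), next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        header := r.Header.Get("Authorization")
        if !strings.HasPrefix(header, authPrefix) {
            http.Error(w, "Unauthorized", http.StatusUnauthorized)
            return
        }
        valid, err := validate(strings.TrimPrefix(header, authPrefix))
        if err != nil {
            http.Error(w, "Unauthorized", http.StatusInternalServerError)
            return
        }
        if !valid {
            http.Error(w, "Unauthorized", http.StatusUnauthorized)
            return
        }
        next.ServeHTTP(w, r)
    })
}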
@@ -260,7 +260,7 @@ func NewHeadscaleDatabase(
                log.Error().Err(err).Msg("Error creating route")
            } else {
                log.Info().
                    Uint64("node_id", route.NodeID).
                    Uint64("node.id", route.NodeID).
                    Str("prefix", prefix.String()).
                    Msg("Route migrated")
            }
@@ -1131,7 +1131,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
    }

    for _, migrationID := range migrationIDs {
        log.Trace().Str("migration_id", migrationID).Msg("Running migration")
        log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
        needsFKDisabled := migrationsRequiringFKDisabled[migrationID]

        if needsFKDisabled {
@@ -275,7 +275,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
            return errors.New("backfilling IPs: ip allocator was nil")
        }

        log.Trace().Msgf("starting to backfill IPs")
        log.Trace().Caller().Msgf("starting to backfill IPs")

        nodes, err := ListNodes(tx)
        if err != nil {
@@ -283,7 +283,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
        }

        for _, node := range nodes {
            log.Trace().Uint64("node.id", node.ID.Uint64()).Msg("checking if need backfill")
            log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")

            changed := false
            // IPv4 prefix is set, but node ip is missing, alloc
@@ -34,9 +34,6 @@ var (
        "node not found in registration cache",
    )
    ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
    ErrDifferentRegisteredUser = errors.New(
        "node was previously registered with a different user",
    )
)

// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
@@ -7,6 +7,7 @@ import (
    "strings"

    "github.com/arl/statsviz"
    "github.com/juanfont/headscale/hscontrol/mapper"
    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "tailscale.com/tsweb"
@@ -239,6 +240,34 @@ func (h *Headscale) debugHTTPServer() *http.Server {
        w.Write(resJSON)
    }))

    // Batcher endpoint
    debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Check Accept header to determine response format
        acceptHeader := r.Header.Get("Accept")
        wantsJSON := strings.Contains(acceptHeader, "application/json")

        if wantsJSON {
            batcherInfo := h.debugBatcherJSON()

            batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ")
            if err != nil {
                httpError(w, err)
                return
            }

            w.Header().Set("Content-Type", "application/json")
            w.WriteHeader(http.StatusOK)
            w.Write(batcherJSON)
        } else {
            // Default to text/plain for backward compatibility
            batcherInfo := h.debugBatcher()

            w.Header().Set("Content-Type", "text/plain")
            w.WriteHeader(http.StatusOK)
            w.Write([]byte(batcherInfo))
        }
    }))

    err := statsviz.Register(debugMux)
    if err == nil {
        debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)")
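The handler registered above serves two representations from one URL, switching on the Accept header. A small httptest sketch of how such a handler behaves (the stand-in handler and the /debug/batcher path are assumptions for illustration, not the actual headscale handler):

package debugsketch

import (
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
)

func TestAcceptNegotiation(t *testing.T) {
    h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.Header.Get("Accept"), "application/json") {
            w.Header().Set("Content-Type", "application/json")
            w.Write([]byte(`{"total_nodes":0,"connected_nodes":{}}`))
            return
        }
        w.Header().Set("Content-Type", "text/plain")
        w.Write([]byte("=== Batcher Connected Nodes ===\n"))
    })

    req := httptest.NewRequest(http.MethodGet, "/debug/batcher", nil)
    req.Header.Set("Accept", "application/json")
    rec := httptest.NewRecorder()
    h.ServeHTTP(rec, req)

    if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
        t.Fatalf("expected JSON, got %q", ct)
    }
}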
@@ -256,3 +285,124 @@ func (h *Headscale) debugHTTPServer() *http.Server {

    return debugHTTPServer
}

// debugBatcher returns debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcher() string {
    var sb strings.Builder
    sb.WriteString("=== Batcher Connected Nodes ===\n\n")

    totalNodes := 0
    connectedCount := 0

    // Collect nodes and sort them by ID
    type nodeStatus struct {
        id                types.NodeID
        connected         bool
        activeConnections int
    }

    var nodes []nodeStatus

    // Try to get detailed debug info if we have a LockFreeBatcher
    if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
        debugInfo := batcher.Debug()
        for nodeID, info := range debugInfo {
            nodes = append(nodes, nodeStatus{
                id:                nodeID,
                connected:         info.Connected,
                activeConnections: info.ActiveConnections,
            })
            totalNodes++
            if info.Connected {
                connectedCount++
            }
        }
    } else {
        // Fallback to basic connection info
        connectedMap := h.mapBatcher.ConnectedMap()
        connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
            nodes = append(nodes, nodeStatus{
                id:                nodeID,
                connected:         connected,
                activeConnections: 0,
            })
            totalNodes++
            if connected {
                connectedCount++
            }
            return true
        })
    }

    // Sort by node ID
    for i := 0; i < len(nodes); i++ {
        for j := i + 1; j < len(nodes); j++ {
            if nodes[i].id > nodes[j].id {
                nodes[i], nodes[j] = nodes[j], nodes[i]
            }
        }
    }

    // Output sorted nodes
    for _, node := range nodes {
        status := "disconnected"
        if node.connected {
            status = "connected"
        }

        if node.activeConnections > 0 {
            sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections))
        } else {
            sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status))
        }
    }

    sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes))

    return sb.String()
}
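The nested loops above are a simple exchange sort over the collected statuses; the standard library expresses the same ordering in one call. A sketch (nodeStatus mirrors the local struct, with uint64 standing in for types.NodeID):

package debugsketch

import "sort"

type nodeStatus struct {
    id                uint64
    connected         bool
    activeConnections int
}

// sortByID orders statuses ascending by node ID, matching the manual swap loops.
func sortByID(nodes []nodeStatus) {
    sort.Slice(nodes, func(i, j int) bool { return nodes[i].id < nodes[j].id })
}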

// DebugBatcherInfo represents batcher connection information in a structured format.
type DebugBatcherInfo struct {
    ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"` // NodeID -> node connection info
    TotalNodes     int                             `json:"total_nodes"`
}

// DebugBatcherNodeInfo represents connection information for a single node.
type DebugBatcherNodeInfo struct {
    Connected         bool `json:"connected"`
    ActiveConnections int  `json:"active_connections"`
}

// debugBatcherJSON returns structured debug information about the batcher's connected nodes.
func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
    info := DebugBatcherInfo{
        ConnectedNodes: make(map[string]DebugBatcherNodeInfo),
        TotalNodes:     0,
    }

    // Try to get detailed debug info if we have a LockFreeBatcher
    if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
        debugInfo := batcher.Debug()
        for nodeID, debugData := range debugInfo {
            info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
                Connected:         debugData.Connected,
                ActiveConnections: debugData.ActiveConnections,
            }
            info.TotalNodes++
        }
    } else {
        // Fallback to basic connection info
        connectedMap := h.mapBatcher.ConnectedMap()
        connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
            info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
                Connected:         connected,
                ActiveConnections: 0,
            }
            info.TotalNodes++
            return true
        })
    }

    return info
}
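A possible client for the JSON form of this endpoint; the /debug/batcher path and the base URL are assumptions, and the local struct simply mirrors DebugBatcherInfo above:

package debugsketch

import (
    "encoding/json"
    "net/http"
)

type batcherInfo struct {
    ConnectedNodes map[string]struct {
        Connected         bool `json:"connected"`
        ActiveConnections int  `json:"active_connections"`
    } `json:"connected_nodes"`
    TotalNodes int `json:"total_nodes"`
}

func fetchBatcherDebug(baseURL string) (*batcherInfo, error) {
    req, err := http.NewRequest(http.MethodGet, baseURL+"/debug/batcher", nil)
    if err != nil {
        return nil, err
    }
    req.Header.Set("Accept", "application/json") // ask for the structured form
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    var info batcherInfo
    if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
        return nil, err
    }
    return &info, nil
}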
@@ -161,7 +161,7 @@ func (d *DERPServer) DERPHandler(
        log.Error().
            Caller().
            Err(err).
            Msg("Failed to write response")
            Msg("Failed to write HTTP response")
    }

    return
@@ -199,7 +199,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
        log.Error().
            Caller().
            Err(err).
            Msg("Failed to write response")
            Msg("Failed to write HTTP response")
    }

    return
@@ -229,7 +229,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
        log.Error().
            Caller().
            Err(err).
            Msg("Failed to write response")
            Msg("Failed to write HTTP response")
    }

    return
@@ -245,7 +245,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
        log.Error().
            Caller().
            Err(err).
            Msg("Failed to write response")
            Msg("Failed to write HTTP response")
    }

    return
@@ -284,7 +284,7 @@ func DERPProbeHandler(
            log.Error().
                Caller().
                Err(err).
                Msg("Failed to write response")
                Msg("Failed to write HTTP response")
        }
    }
}
@@ -330,7 +330,7 @@ func DERPBootstrapDNSHandler(
            log.Error().
                Caller().
                Err(err).
                Msg("Failed to write response")
                Msg("Failed to write HTTP response")
        }
    }
}
@@ -237,6 +237,7 @@ func (api headscaleV1APIServer) RegisterNode(
    request *v1.RegisterNodeRequest,
) (*v1.RegisterNodeResponse, error) {
    log.Trace().
        Caller().
        Str("user", request.GetUser()).
        Str("registration_id", request.GetKey()).
        Msg("Registering node")
@@ -525,7 +526,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
    ctx context.Context,
    request *v1.BackfillNodeIPsRequest,
) (*v1.BackfillNodeIPsResponse, error) {
    log.Trace().Msg("Backfill called")
    log.Trace().Caller().Msg("Backfill called")

    if !request.Confirmed {
        return nil, errors.New("not confirmed, aborting")
@@ -709,6 +710,10 @@ func (api headscaleV1APIServer) SetPolicy(
        UpdatedAt: timestamppb.New(updated.UpdatedAt),
    }

    log.Debug().
        Caller().
        Msg("gRPC SetPolicy completed successfully because response prepared")

    return response, nil
}

@@ -731,7 +736,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
        Caller().
        Interface("route-prefix", routes).
        Interface("route-str", request.GetRoutes()).
        Msg("")
        Msg("Creating routes for node")

    hostinfo := tailcfg.Hostinfo{
        RoutableIPs: routes,
@@ -760,6 +765,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
    }

    log.Debug().
        Caller().
        Str("registration_id", registrationId.String()).
        Msg("adding debug machine via CLI, appending to registration cache")
@@ -197,7 +197,7 @@ func (h *Headscale) RobotsHandler(
        log.Error().
            Caller().
            Err(err).
            Msg("Failed to write response")
            Msg("Failed to write HTTP response")
    }
}
@@ -9,6 +9,7 @@ import (
    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/juanfont/headscale/hscontrol/types/change"
    "github.com/puzpuzpuz/xsync/v4"
    "github.com/rs/zerolog/log"
    "tailscale.com/tailcfg"
    "tailscale.com/types/ptr"
)
@@ -23,7 +24,7 @@ type Batcher interface {
    RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
    IsConnected(id types.NodeID) bool
    ConnectedMap() *xsync.Map[types.NodeID, bool]
    AddWork(c change.ChangeSet)
    AddWork(c ...change.ChangeSet)
    MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
    DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}
@@ -36,7 +37,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB

        // The size of this channel is arbitrary chosen, the sizing should be revisited.
        workCh: make(chan work, workers*200),
        nodes: xsync.NewMap[types.NodeID, *nodeConn](),
        nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
        connected: xsync.NewMap[types.NodeID, *time.Time](),
        pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
    }
@@ -47,6 +48,7 @@ func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
    m := newMapper(cfg, state)
    b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
    m.batcher = b

    return b
}
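The interface change above widens AddWork from a single ChangeSet to a variadic parameter, so callers can submit one change (batcher.AddWork(change.DERPSet)) or several in one call. A minimal sketch of the forwarding involved (ChangeSet is reduced to an empty struct here; not the batcher's real types):

package batchsketch

type ChangeSet struct{}

type batcher struct {
    work chan []ChangeSet
}

// AddWork accepts one or many changes and forwards them as a single slice.
func (b *batcher) AddWork(c ...ChangeSet) {
    b.addWork(c...)
}

func (b *batcher) addWork(c ...ChangeSet) {
    b.work <- c // inside a variadic function, c is already a []ChangeSet
}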
@@ -72,8 +74,10 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
        return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
    }

    var mapResp *tailcfg.MapResponse
    var err error
    var (
        mapResp *tailcfg.MapResponse
        err     error
    )

    switch c.Change {
    case change.DERP:
@@ -84,10 +88,21 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
            // TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
            mapResp, err = mapper.fullMapResponse(nodeID, version)
        } else {
            // CRITICAL FIX: Read actual online status from NodeStore when available,
            // fall back to deriving from change type for unit tests or when NodeStore is empty
            var onlineStatus bool
            if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
                // Use actual NodeStore status when available (production case)
                onlineStatus = node.IsOnline().Get()
            } else {
                // Fall back to deriving from change type (unit test case or initial setup)
                onlineStatus = c.Change == change.NodeCameOnline
            }

            mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
                {
                    NodeID: c.NodeID.NodeID(),
                    Online: ptr.To(c.Change == change.NodeCameOnline),
                    Online: ptr.To(onlineStatus),
                },
            })
        }
@@ -125,7 +140,12 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
    }

    nodeID := nc.nodeID()
    data, err := generateMapResponse(nodeID, nc.version(), mapper, c)

    log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")

    var data *tailcfg.MapResponse
    var err error
    data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
    if err != nil {
        return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
    }
@@ -136,7 +156,8 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
    }

    // Send the map response
    if err := nc.send(data); err != nil {
    err = nc.send(data)
    if err != nil {
        return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
    }
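The added branch above prefers the NodeStore's view of whether a peer is online and only falls back to inferring it from the change type (NodeCameOnline) when the store has no valid answer, as in unit tests. Read as a small helper (hypothetical; the real code inlines this):

package batchsketch

// onlineStatus picks the Online value for a peer-changed patch.
// storeOnline is the NodeStore's answer when known; cameOnline is derived from the change type.
func onlineStatus(storeOnline *bool, cameOnline bool) bool {
    if storeOnline != nil { // authoritative answer available
        return *storeOnline
    }
    return cameOnline // fall back to what the change itself implies
}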
@ -2,6 +2,7 @@ package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -57,16 +58,21 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
version: version,
|
||||
created: now,
|
||||
}
|
||||
// Initialize last used timestamp
|
||||
newEntry.lastUsed.Store(now.Unix())
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
|
||||
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
|
||||
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
// Add connection to the list (lock-free)
|
||||
nodeConn.addConnection(newEntry)
|
||||
|
||||
// Use the worker pool for controlled concurrency instead of direct generation
|
||||
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
@ -87,6 +93,16 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
// Node will automatically receive updates through the normal flow
|
||||
// The initial full map already contains all current state
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection established in batcher because AddNode completed successfully")
|
||||
|
||||
return nil
|
||||
}
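AddNode above uses LoadOrStore to either reuse an existing per-node connection record (so rapid reconnects keep their entry) or create a fresh one, bumping the node counter only when the record is new. The same get-or-create shape with the standard library (a sketch, not the batcher's actual types):

package batchsketch

import (
    "sync"
    "sync/atomic"
)

type nodeRecord struct{ /* connections, counters, ... */ }

type registry struct {
    nodes      sync.Map // node ID -> *nodeRecord
    totalNodes atomic.Int64
}

func (r *registry) getOrCreate(id uint64) *nodeRecord {
    actual, loaded := r.nodes.LoadOrStore(id, &nodeRecord{})
    if !loaded {
        r.totalNodes.Add(1) // first sighting; reconnects reuse the stored record
    }
    return actual.(*nodeRecord)
}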
|
||||
|
||||
@ -101,10 +117,11 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return false
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
// Remove specific connection
|
||||
removed := nodeConn.removeConnectionByChannel(c)
|
||||
if !removed {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if node has any remaining active connections
|
||||
@ -115,18 +132,17 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
// No active connections - keep the node entry alive for rapid reconnections
|
||||
// The node will get a fresh full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
|
||||
b.addWork(c...)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
@ -137,23 +153,36 @@ func (b *LockFreeBatcher) Start() {
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel = nil // Prevent multiple calls
|
||||
}
|
||||
|
||||
// Only close workCh once
|
||||
select {
|
||||
case <-b.workCh:
|
||||
// Channel is already closed
|
||||
default:
|
||||
close(b.workCh)
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
// Create a cleanup ticker for removing truly disconnected nodes
|
||||
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-cleanupTicker.C:
|
||||
// Clean up nodes that have been offline for too long
|
||||
b.cleanupOfflineNodes()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
@ -161,8 +190,6 @@ func (b *LockFreeBatcher) doWork() {
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -171,7 +198,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
@ -181,7 +207,9 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
var err error
|
||||
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
@ -192,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
@ -260,19 +289,22 @@ func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
all, self := change.SplitAllAndSelf(c)
|
||||
|
||||
for _, changeSet := range self {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, changeSet)
|
||||
b.pendingChanges.Store(changeSet.NodeID, changes)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
rel := change.RemoveUpdatesForSelf(nodeID, all)
|
||||
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
changes = append(changes, rel...)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
@ -303,7 +335,44 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
||||
func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
cleanupThreshold := 15 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
var nodesToCleanup []types.NodeID
|
||||
|
||||
// Find nodes that have been offline for too long
|
||||
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
||||
// Double-check the node doesn't have active connections
|
||||
if nodeConn, exists := b.nodes.Load(nodeID); exists {
|
||||
if !nodeConn.hasActiveConnections() {
|
||||
nodesToCleanup = append(nodesToCleanup, nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Msg("Cleaning up node that has been offline for too long")
|
||||
|
||||
b.nodes.Delete(nodeID)
|
||||
b.connected.Delete(nodeID)
|
||||
b.totalNodes.Add(-1)
|
||||
}
|
||||
|
||||
if len(nodesToCleanup) > 0 {
|
||||
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
|
||||
Msg("Completed cleanup of long-offline nodes")
|
||||
}
|
||||
}
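cleanupOfflineNodes relies on the convention used throughout this file: the connected map holds nil while a node is connected and a pointer to the disconnect time once it drops, so a node is reclaimable only when that timestamp is old enough and no live channels remain. As a small predicate (illustrative only):

package batchsketch

import "time"

// shouldCleanup reports whether a node can be dropped from the batcher.
// disconnectedAt is nil while the node is still considered connected.
func shouldCleanup(disconnectedAt *time.Time, activeConns int, threshold time.Duration, now time.Time) bool {
    if disconnectedAt == nil || activeConns > 0 {
        return false // still connected, or still has live channels
    }
    return now.Sub(*disconnectedAt) > threshold
}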
|
||||
|
||||
// IsConnected is lock-free read that checks if a node has any active connections.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists {
|
||||
@ -373,89 +442,234 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
|
||||
}
|
||||
}
|
||||
|
||||
// connectionData holds the channel and connection parameters.
|
||||
type connectionData struct {
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
closed atomic.Bool // Track if this connection has been closed
|
||||
// connectionEntry represents a single connection to a node.
|
||||
type connectionEntry struct {
|
||||
id string // unique connection ID
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
created time.Time
|
||||
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
||||
}
|
||||
|
||||
// nodeConn described the node connection and its associated data.
|
||||
type nodeConn struct {
|
||||
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
||||
type multiChannelNodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
|
||||
// Atomic pointer to connection data - allows lock-free updates
|
||||
connData atomic.Pointer[connectionData]
|
||||
mutex sync.RWMutex
|
||||
connections []*connectionEntry
|
||||
|
||||
updateCount atomic.Int64
|
||||
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
// generateConnectionID generates a unique connection identifier.
|
||||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
||||
return &multiChannelNodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Dur("mutex_wait_time", mutexWaitDur).
|
||||
Msg("Successfully added connection after mutex wait")
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches current connection.
|
||||
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return false
|
||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for i, entry := range mc.connections {
|
||||
if entry.c == c {
|
||||
// Remove this connection
|
||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Compare channel pointers directly
|
||||
return data.c == c
|
||||
return false
|
||||
}
|
||||
|
||||
// compressAndVersion atomically reads connection settings.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
// hasActiveConnections checks if the node has any active connections.
|
||||
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections) > 0
|
||||
}
|
||||
|
||||
// getActiveConnectionCount returns the number of active connections.
|
||||
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections)
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
// During rapid reconnection, nodes may temporarily have no active connections
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
successCount := 0
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
for i, conn := range mc.connections {
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
lastErr = err
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove failed connections (in reverse order to maintain indices)
|
||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||
idx := failedConnections[i]
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Str("conn.id", mc.connections[idx].id).
|
||||
Msg("send: removing failed connection")
|
||||
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
log.Info().Uint64("node.id", mc.id.Uint64()).
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failedConnections)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("send: completed broadcast")
|
||||
|
||||
// Success if at least one send succeeded
|
||||
if successCount > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
|
||||
}
|
||||
|
||||
// send sends data to a single connection entry with timeout-based stale connection detection.
|
||||
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
||||
// This is critical for detecting Docker containers that are forcefully terminated
|
||||
// but still have channels that appear open.
|
||||
select {
|
||||
case entry.c <- data:
|
||||
// Update last used timestamp on successful send
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
}
|
||||
}
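The select above bounds each channel send with time.After so a client that has stopped reading is treated as stale rather than blocking the broadcast. An equivalent shape that stops its timer explicitly on the fast path (a sketch, not the code in this commit):

package batchsketch

import (
    "errors"
    "time"
)

// trySend delivers v on ch or gives up after timeout.
func trySend[T any](ch chan<- T, v T, timeout time.Duration) error {
    t := time.NewTimer(timeout)
    defer t.Stop()
    select {
    case ch <- v:
        return nil
    case <-t.C:
        return errors.New("timeout sending to channel (likely stale connection)")
    }
}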
|
||||
|
||||
// nodeID returns the node ID.
|
||||
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
||||
return mc.id
|
||||
}
|
||||
|
||||
// version returns the capability version from the first active connection.
|
||||
// All connections for a node should have the same version in practice.
|
||||
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
return mc.connections[0].version
|
||||
}
|
||||
|
||||
func (nc *nodeConn) nodeID() types.NodeID {
|
||||
return nc.id
|
||||
// change applies a change to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(mc, mc.mapper, c)
|
||||
}
|
||||
|
||||
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(nc, nc.mapper, c)
|
||||
// DebugNodeInfo contains debug information about a node's connections.
|
||||
type DebugNodeInfo struct {
|
||||
Connected bool `json:"connected"`
|
||||
ActiveConnections int `json:"active_connections"`
|
||||
}
|
||||
|
||||
// send sends data to the node's channel.
|
||||
// The node will pick it up and send it to the HTTP handler.
|
||||
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
connData := nc.connData.Load()
|
||||
if connData == nil {
|
||||
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||
}
|
||||
// Debug returns a pre-baked map of node debug information for the debug interface.
|
||||
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
result := make(map[types.NodeID]DebugNodeInfo)
|
||||
|
||||
// Check if connection has been closed
|
||||
if connData.closed.Load() {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
// Get all nodes with their connection status using immediate connection logic
|
||||
// (no grace period) for debug purposes
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
nodeConn.mutex.RLock()
|
||||
activeConnCount := len(nodeConn.connections)
|
||||
nodeConn.mutex.RUnlock()
|
||||
|
||||
// Use immediate connection status: if active connections exist, node is connected
|
||||
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
|
||||
connected := false
|
||||
if activeConnCount > 0 {
|
||||
connected = true
|
||||
} else {
|
||||
// Check connected map for immediate status
|
||||
if val, ok := b.connected.Load(id); ok && val == nil {
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
|
@ -209,6 +209,7 @@ func setupBatcherWithTestData(
|
||||
|
||||
// Create test users and nodes in the database
|
||||
users := database.CreateUsersForTest(userCount, "testuser")
|
||||
|
||||
allNodes := make([]node, 0, userCount*nodesPerUser)
|
||||
for _, user := range users {
|
||||
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
|
||||
@ -353,6 +354,7 @@ func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected b
|
||||
if len(resp.PeersChangedPatch) > 0 {
|
||||
require.Len(t, resp.PeersChangedPatch, 1)
|
||||
assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -412,6 +414,7 @@ func (n *node) start() {
|
||||
n.maxPeersCount = info.PeerCount
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsPatch {
|
||||
atomic.AddInt64(&n.patchCount, 1)
|
||||
// For patches, we track how many patch items
|
||||
@ -550,6 +553,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Test cases: different node counts to stress test the all-to-all connectivity
|
||||
@ -618,6 +622,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
// Join all nodes as fast as possible
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
@ -693,6 +698,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
if stats.MaxPeersSeen > maxPeersGlobal {
|
||||
maxPeersGlobal = stats.MaxPeersSeen
|
||||
}
|
||||
|
||||
if stats.MaxPeersSeen < minPeersSeen {
|
||||
minPeersSeen = stats.MaxPeersSeen
|
||||
}
|
||||
@ -730,9 +736,11 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show sample of node details
|
||||
if len(nodeDetails) > 0 {
|
||||
t.Logf(" Node sample:")
|
||||
|
||||
for _, detail := range nodeDetails[:min(5, len(nodeDetails))] {
|
||||
t.Logf(" %s", detail)
|
||||
}
|
||||
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf(" ... (%d more nodes)", len(nodeDetails)-5)
|
||||
}
|
||||
@ -754,6 +762,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Show details of failed nodes for debugging
|
||||
if len(nodeDetails) > 5 {
|
||||
t.Logf("Failed nodes details:")
|
||||
|
||||
for _, detail := range nodeDetails[5:] {
|
||||
if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) {
|
||||
t.Logf(" %s", detail)
|
||||
@ -875,6 +884,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
@ -1026,10 +1036,12 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// Collect updates with timeout
|
||||
updateCount := 0
|
||||
timeout := time.After(200 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
updateCount++
|
||||
|
||||
receivedUpdates = append(receivedUpdates, data)
|
||||
|
||||
// Validate update content
|
||||
@ -1058,6 +1070,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
|
||||
// Validate that all updates have valid content
|
||||
validUpdates := 0
|
||||
|
||||
for _, data := range receivedUpdates {
|
||||
if data != nil {
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
@ -1095,16 +1108,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var channelIssues int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Run rapid connect/disconnect cycles with real updates to test channel closing
|
||||
|
||||
for i := range 100 {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// First connection
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@ -1118,17 +1137,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
// Rapid second connection - should replace ch1
|
||||
ch2 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
@ -1143,7 +1167,9 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
// If no data received, increment issues counter
|
||||
mutex.Lock()
|
||||
|
||||
channelIssues++
|
||||
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
@ -1185,18 +1211,24 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
var panics int
|
||||
var channelErrors int
|
||||
var invalidData int
|
||||
var mutex sync.Mutex
|
||||
|
||||
var (
|
||||
panics int
|
||||
channelErrors int
|
||||
invalidData int
|
||||
mutex sync.Mutex
|
||||
)
|
||||
|
||||
// Test rapid connect/disconnect with work generation
|
||||
|
||||
for i := range 50 {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
panics++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Panic caught: %v", r)
|
||||
}
|
||||
@ -1213,7 +1245,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mutex.Lock()
|
||||
|
||||
channelErrors++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Channel consumer panic: %v", r)
|
||||
}
|
||||
@ -1229,7 +1263,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// Validate the data we received
|
||||
if valid, reason := validateUpdateContent(data); !valid {
|
||||
mutex.Lock()
|
||||
|
||||
invalidData++
|
||||
|
||||
mutex.Unlock()
|
||||
t.Logf("Invalid data received: %s", reason)
|
||||
}
|
||||
@ -1268,9 +1304,11 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
if panics > 0 {
|
||||
t.Errorf("Worker channel safety failed with %d panics", panics)
|
||||
}
|
||||
|
||||
if channelErrors > 0 {
|
||||
t.Errorf("Channel handling failed with %d channel errors", channelErrors)
|
||||
}
|
||||
|
||||
if invalidData > 0 {
|
||||
t.Errorf("Data validation failed with %d invalid data packets", invalidData)
|
||||
}
|
||||
@ -1342,15 +1380,19 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Use remaining nodes for connection churn testing
|
||||
churningNodes := allNodes[len(allNodes)/2:]
|
||||
churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
var churningChannelsMutex sync.Mutex // Protect concurrent map access
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numCycles := 10 // Reduced for simpler test
|
||||
panicCount := 0
|
||||
|
||||
var panicMutex sync.Mutex
|
||||
|
||||
// Track deadlock with timeout
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
@ -1364,16 +1406,22 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning connect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
@ -1400,17 +1448,23 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicMutex.Lock()
|
||||
|
||||
panicCount++
|
||||
|
||||
panicMutex.Unlock()
|
||||
t.Logf("Panic in churning disconnect: %v", r)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
time.Sleep(time.Duration(i%5) * time.Millisecond)
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
@ -1422,10 +1476,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// DERP changes
|
||||
batcher.AddWork(change.DERPSet)
|
||||
}
|
||||
|
||||
if i%5 == 0 {
|
||||
// Full updates using real node data
|
||||
batcher.AddWork(change.FullSet)
|
||||
}
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
@ -1453,7 +1509,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Validate results
|
||||
panicMutex.Lock()
|
||||
|
||||
finalPanicCount := panicCount
|
||||
|
||||
panicMutex.Unlock()
|
||||
|
||||
allStats := tracker.getAllStats()
|
||||
@ -1536,6 +1594,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Reduce verbose application logging for cleaner test output
|
||||
originalLevel := zerolog.GlobalLevel()
|
||||
defer zerolog.SetGlobalLevel(originalLevel)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
|
||||
// Full test matrix for scalability testing
|
||||
@ -1624,6 +1683,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description)
|
||||
t.Logf(
|
||||
" Cycles: %d, Buffer Size: %d, Chaos Type: %s",
|
||||
@ -1660,12 +1720,16 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
// Connect all nodes first so they can see each other as peers
|
||||
connectedNodes := make(map[types.NodeID]bool)
|
||||
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
|
||||
@ -1676,6 +1740,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
t.Logf(
|
||||
@ -1697,14 +1762,17 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// For chaos testing, only disconnect/reconnect a subset of nodes
|
||||
// This ensures some nodes stay connected to continue receiving updates
|
||||
startIdx := cycle % len(testNodes)
|
||||
|
||||
endIdx := startIdx + len(testNodes)/4
|
||||
if endIdx > len(testNodes) {
|
||||
endIdx = len(testNodes)
|
||||
}
|
||||
|
||||
if startIdx >= endIdx {
|
||||
startIdx = 0
|
||||
endIdx = min(len(testNodes)/4, len(testNodes))
|
||||
}
|
||||
|
||||
chaosNodes := testNodes[startIdx:endIdx]
|
||||
if len(chaosNodes) == 0 {
|
||||
chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos
|
||||
@ -1722,17 +1790,22 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
connectedNodesMutex.RLock()
|
||||
|
||||
isConnected := connectedNodes[nodeID]
|
||||
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = false
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
}
|
||||
}(
|
||||
@ -1746,6 +1819,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@ -1757,7 +1831,9 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
|
||||
connectedNodesMutex.Unlock()
|
||||
|
||||
// Add work to create load
|
||||
@ -1776,11 +1852,13 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count
|
||||
for i := range updateCount {
|
||||
wg.Add(1)
|
||||
|
||||
go func(index int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&panicCount, 1)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@ -1823,11 +1901,14 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
deadlockDetected = true
|
||||
// Collect diagnostic information
|
||||
allStats := tracker.getAllStats()
|
||||
|
||||
totalUpdates := 0
|
||||
for _, stats := range allStats {
totalUpdates += stats.TotalUpdates
}

interimPanics := atomic.LoadInt64(&panicCount)

t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT)
t.Logf(
" Progress at timeout: %d total updates, %d panics",
@ -1873,6 +1954,7 @@ func XTestBatcherScalability(t *testing.T) {
stats := node.cleanup()
totalUpdates += stats.TotalUpdates
totalPatches += stats.PatchUpdates

totalFull += stats.FullUpdates
if stats.MaxPeersSeen > maxPeersGlobal {
maxPeersGlobal = stats.MaxPeersSeen
@ -1910,10 +1992,12 @@ func XTestBatcherScalability(t *testing.T) {

// Legacy tracker comparison (optional)
allStats := tracker.getAllStats()

legacyTotalUpdates := 0
for _, stats := range allStats {
legacyTotalUpdates += stats.TotalUpdates
}

if legacyTotalUpdates != int(totalUpdates) {
t.Logf(
"Note: Legacy tracker mismatch - legacy: %d, new: %d",
@ -1926,6 +2010,7 @@ func XTestBatcherScalability(t *testing.T) {

// Validation based on expectation
testPassed := true

if tc.expectBreak {
// For tests expected to break, we're mainly checking that we don't crash
if finalPanicCount > 0 {
@ -1947,14 +2032,19 @@ func XTestBatcherScalability(t *testing.T) {
// For tests expected to pass, validate proper operation
if finalPanicCount > 0 {
t.Errorf("Scalability test failed with %d panics", finalPanicCount)

testPassed = false
}

if deadlockDetected {
t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes))

testPassed = false
}

if totalUpdates == 0 {
t.Error("No updates received - system may be completely stalled")

testPassed = false
}
}
@ -2020,6 +2110,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Read all available updates for each node
for i := range allNodes {
nodeUpdates := 0

t.Logf("Reading updates for node %d:", i)

// Read up to 10 updates per node or until timeout/no more data
@ -2056,6 +2147,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {

if len(data.Peers) > 0 {
t.Logf(" Full peer list with %d peers", len(data.Peers))

for j, peer := range data.Peers[:min(3, len(data.Peers))] {
t.Logf(
" Peer %d: NodeID=%d, Online=%v",
@ -2065,8 +2157,10 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
)
}
}

if len(data.PeersChangedPatch) > 0 {
t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch))

for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] {
t.Logf(
" Patch %d: NodeID=%d, Online=%v",
@ -2080,6 +2174,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
case <-time.After(500 * time.Millisecond):
}
}

t.Logf("Node %d received %d updates", i, nodeUpdates)
}

@ -2095,71 +2190,132 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
}
}

// TestBatcherWorkQueueTracing traces exactly what happens to change.FullSet work items.
func TestBatcherWorkQueueTracing(t *testing.T) {
// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID
// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected.
// This specifically tests the multi-channel batcher implementation issue.
func TestBatcherRapidReconnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10)
defer cleanup()

batcher := testData.Batcher
allNodes := testData.Nodes

t.Logf("=== RAPID RECONNECTION TEST ===")
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))

// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node %d: %v", i, err)
}
}

time.Sleep(100 * time.Millisecond) // Let connections settle

// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
}

// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
}
}

time.Sleep(100 * time.Millisecond) // Let reconnections settle

// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
t.Logf("Phase 4: Checking debug status...")

if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
disconnectedCount := 0

for i, node := range allNodes {
if info, exists := debugInfo[node.n.ID]; exists {
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)

// Check if the debug info shows the node as connected
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}

// Also check IsConnected method
if !batcher.IsConnected(node.n.ID) {
t.Logf("Node %d IsConnected() returns false", i)
}
}

if disconnectedCount > 0 {
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
// This is expected behavior for multi-channel batcher according to user
// "it has never worked with the multi"
} else {
t.Logf("All nodes show as connected - working correctly")
}
} else {
t.Logf("Batcher does not implement Debug() method")
}

// Phase 5: Test if "disconnected" nodes can actually receive updates
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")

// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())

receivedCount := 0
timeout := time.After(500 * time.Millisecond)

for i := 0; i < len(allNodes); i++ {
select {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
t.Logf("Node %d timed out waiting for update", i)
goto done
}
}

done:
t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes))

if receivedCount < len(allNodes) {
t.Logf("Some nodes failed to receive updates - confirming the issue")
}
})
}
}

func TestBatcherMultiConnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10)
defer cleanup()

batcher := testData.Batcher
nodes := testData.Nodes

t.Logf("=== WORK QUEUE TRACING TEST ===")

time.Sleep(100 * time.Millisecond) // Let connections settle

// Wait for initial NodeCameOnline to be processed
time.Sleep(200 * time.Millisecond)

// Drain any initial updates
drainedCount := 0
for {
select {
case <-nodes[0].ch:
drainedCount++
case <-time.After(100 * time.Millisecond):
goto drained
}
}
drained:
t.Logf("Drained %d initial updates", drainedCount)

// Now send a single FullSet update and trace it closely
t.Logf("Sending change.FullSet work item...")
batcher.AddWork(change.FullSet)

// Give short time for processing
time.Sleep(100 * time.Millisecond)

// Check if any update was received
select {
case data := <-nodes[0].ch:
t.Logf("SUCCESS: Received update after FullSet!")

if data != nil {
// Detailed analysis of the response - data is already a MapResponse
t.Logf("Response details:")
t.Logf(" Peers: %d", len(data.Peers))
t.Logf(" PeersChangedPatch: %d", len(data.PeersChangedPatch))
t.Logf(" PeersChanged: %d", len(data.PeersChanged))
t.Logf(" PeersRemoved: %d", len(data.PeersRemoved))
t.Logf(" DERPMap: %v", data.DERPMap != nil)
t.Logf(" KeepAlive: %v", data.KeepAlive)
t.Logf(" Node: %v", data.Node != nil)

if len(data.Peers) > 0 {
t.Logf("SUCCESS: Full peer list received with %d peers", len(data.Peers))
} else if len(data.PeersChangedPatch) > 0 {
t.Errorf("ERROR: Received patch update instead of full update!")
} else if data.DERPMap != nil {
t.Logf("Received DERP map update")
} else if data.Node != nil {
t.Logf("Received self node update")
} else {
t.Errorf("ERROR: Received unknown update type!")
}

batcher := testData.Batcher
node1 := testData.Nodes[0]
node2 := testData.Nodes[1]
@ -2328,12 +2484,53 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
}
}
}
} else {
t.Errorf("Response data is nil")
}
case <-time.After(2 * time.Second):
t.Errorf("CRITICAL: No update received after FullSet within 2 seconds!")
t.Errorf("This indicates FullSet work items are not being processed at all")
}

// Send another update and verify remaining connections still work
clearChannel(node1.ch)
clearChannel(thirdChannel)

testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}

batcher.AddWork(testChangeSet2)
time.Sleep(100 * time.Millisecond)

// Verify remaining connections still receive updates
remaining1Received := false
remaining3Received := false

select {
case mapResp := <-node1.ch:
remaining1Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 1 did not receive update after removal")
}

select {
case mapResp := <-thirdChannel:
remaining3Received = (mapResp != nil)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 3 did not receive update after removal")
}

if remaining1Received && remaining3Received {
t.Logf("SUCCESS: Remaining connections still receive updates after removal")
} else {
t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t",
remaining1Received, remaining3Received)
}

// Verify second channel no longer receives updates (should be closed/removed)
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")
case <-time.After(100 * time.Millisecond):
t.Logf("SUCCESS: Removed connection correctly no longer receives updates")
}
})
}

@ -20,6 +20,8 @@ type MapResponseBuilder struct {
nodeID types.NodeID
capVer tailcfg.CapabilityVersion
errs []error

debugType debugType
}

type debugType string

@ -139,11 +139,11 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID)

return m.NewMapResponseBuilder(nodeID).
WithDebugType(fullResponseDebug).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
@ -162,6 +162,7 @@ func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap().
Build()
}
@ -173,6 +174,7 @@ func (m *mapper) peerChangedPatchResponse(
changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed).
Build()
}
@ -186,6 +188,7 @@ func (m *mapper) peerChangeResponse(
peers := m.state.ListPeers(nodeID, changedNodeID)

return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
@ -199,6 +202,7 @@ func (m *mapper) peerRemovedResponse(
removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID).
Build()
}
@ -214,7 +218,7 @@ func writeDebugMapResponse(
}

perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@ -224,7 +228,7 @@ func writeDebugMapResponse(

mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s.json", now),
fmt.Sprintf("%s-%s.json", now, t),
)

log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -244,7 +248,11 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
return nil, nil
}

nodes, err := os.ReadDir(debugDumpMapResponsePath)
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
}

func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) {
nodes, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
@ -263,7 +271,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er

nodeID := types.NodeID(nodeIDu)

files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
files, err := os.ReadDir(path.Join(dir, node.Name()))
if err != nil {
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
continue
@ -278,7 +286,7 @@ func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, er
continue
}

body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
if err != nil {
log.Error().Err(err).Msgf("Reading file %s", file.Name())
continue

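Note (not part of the commit): with the refactor above, dumped MapResponses can be read back outside the mapper. A minimal sketch, assuming the same layout writeDebugMapResponse produces (one sub-directory per node ID, one JSON file per response); the helper name and directory are hypothetical:

// debugDumpSummary is a hypothetical helper that reports how many MapResponses
// were dumped for each node by reading the dump directory back.
func debugDumpSummary(dir string) {
	responses, err := ReadMapResponsesFromDirectory(dir)
	if err != nil {
		log.Error().Err(err).Str("dir", dir).Msg("reading dumped map responses")
		return
	}
	for nodeID, msgs := range responses {
		log.Info().Uint64("node.id", nodeID.Uint64()).Int("count", len(msgs)).Msg("map responses on disk")
	}
}
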
@ -158,7 +158,6 @@ func TestTailNode(t *testing.T) {

Tags: []string{},

LastSeen: &lastSeen,
MachineAuthorized: true,

CapMap: tailcfg.NodeCapMap{

@ -175,8 +175,8 @@ func rejectUnsupported(
Int("client_cap_ver", int(version)).
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
Str("client_version", capver.TailscaleVersion(version)).
Str("node_key", nkey.ShortString()).
Str("machine_key", mkey.ShortString()).
Str("node.key", nkey.ShortString()).
Str("machine.key", mkey.ShortString()).
Msg("unsupported client connected")
http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest)

@ -282,7 +282,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
writer.WriteHeader(http.StatusOK)

if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return
}

@ -181,7 +181,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
a.registrationCache.Set(state, registrationInfo)

authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)

http.Redirect(writer, req, authURL, http.StatusFound)
}
@ -311,7 +311,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
Msg("Failed to write HTTP response")
}

return
@ -349,7 +349,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write response")
util.LogErr(err, "Failed to write HTTP response")
}

return

@ -34,7 +34,7 @@ func (pol *Policy) compileFilterRules(

srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
log.Trace().Caller().Err(err).Msgf("resolving source ips")
}

if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
@ -52,11 +52,11 @@ func (pol *Policy) compileFilterRules(
for _, dest := range acl.Destinations {
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
}

if ips == nil {
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest)
continue
}

@ -106,7 +106,7 @@ func (pol *Policy) compileSSHPolicy(
return nil, nil
}

log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())

var rules []*tailcfg.SSHRule

@ -115,7 +115,7 @@ func (pol *Policy) compileSSHPolicy(
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
}
dest.AddSet(ips)
}
@ -142,7 +142,7 @@ func (pol *Policy) compileSSHPolicy(
var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
continue // Skip this rule if we can't resolve sources
}

@ -10,6 +10,7 @@ import (

"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -79,6 +80,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {

filterHash := deephash.Hash(&filter)
filterChanged := filterHash != pm.filterHash
if filterChanged {
log.Debug().
Str("filter.hash.old", pm.filterHash.String()[:8]).
Str("filter.hash.new", filterHash.String()[:8]).
Int("filter.rules", len(pm.filter)).
Int("filter.rules.new", len(filter)).
Msg("Policy filter hash changed")
}
pm.filter = filter
pm.filterHash = filterHash
if filterChanged {
@ -95,6 +104,14 @@ func (pm *PolicyManager) updateLocked() (bool, error) {

tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
if tagOwnerChanged {
log.Debug().
Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]).
Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]).
Int("tagOwners.old", len(pm.tagOwnerMap)).
Int("tagOwners.new", len(tagMap)).
Msg("Tag owner hash changed")
}
pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash

@ -105,19 +122,42 @@ func (pm *PolicyManager) updateLocked() (bool, error) {

autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
if autoApproveChanged {
log.Debug().
Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]).
Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]).
Int("autoApprovers.old", len(pm.autoApproveMap)).
Int("autoApprovers.new", len(autoMap)).
Msg("Auto-approvers hash changed")
}
pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash

exitSetHash := deephash.Hash(&autoMap)
exitSetHash := deephash.Hash(&exitSet)
exitSetChanged := exitSetHash != pm.exitSetHash
if exitSetChanged {
log.Debug().
Str("exitSet.hash.old", pm.exitSetHash.String()[:8]).
Str("exitSet.hash.new", exitSetHash.String()[:8]).
Msg("Exit node set hash changed")
}
pm.exitSet = exitSet
pm.exitSetHash = exitSetHash

// If neither of the calculated values changed, no need to update nodes
if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged {
log.Trace().
Msg("Policy evaluation detected no changes - all hashes match")
return false, nil
}

log.Debug().
Bool("filter.changed", filterChanged).
Bool("tagOwners.changed", tagOwnerChanged).
Bool("autoApprovers.changed", autoApproveChanged).
Bool("exitNodes.changed", exitSetChanged).
Msg("Policy changes require node updates")

return true, nil
}

@ -151,6 +191,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()

// Log policy metadata for debugging
log.Debug().
Int("policy.bytes", len(polB)).
Int("acls.count", len(pol.ACLs)).
Int("groups.count", len(pol.Groups)).
Int("hosts.count", len(pol.Hosts)).
Int("tagOwners.count", len(pol.TagOwners)).
Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)).
Msg("Policy parsed successfully")

pm.pol = pol

return pm.updateLocked()

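Note (not part of the commit): every block in updateLocked above follows the same detect-and-log pattern: hash the freshly computed value with deephash, compare against the stored hash, log only when it differs, then store both. A generic sketch of that pattern, assuming tailscale.com/util/deephash; the helper is hypothetical:

// hashChanged is a hypothetical helper expressing the pattern used above:
// it rehashes the new value and reports whether it differs from the previous sum.
func hashChanged[T any](next T, prev deephash.Sum) (deephash.Sum, bool) {
	sum := deephash.Hash(&next)
	return sum, sum != prev
}

// Usage sketch: sum, filterChanged := hashChanged(filter, pm.filterHash)
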
@ -216,6 +216,21 @@ func (m *mapSession) serveLongPoll() {

m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)

// TODO(kradalby): Redo the comments here
// Add node to batcher so it can receive updates,
// adding this before connecting it to the state ensure that
// it does not miss any updates that might be sent in the split
// time between the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")

m.h.Change(mapReqChange)
m.h.Change(connectChanges...)

// Loop through updates and continuously send them to the
// client.
for {
@ -227,7 +242,7 @@ func (m *mapSession) serveLongPoll() {
return

case <-ctx.Done():
m.tracef("poll context done")
m.tracef("poll context done chan:%p", m.ch)
mapResponseEnded.WithLabelValues("done").Inc()
return

@ -295,7 +310,15 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
}
}

log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
log.Trace().
Caller().
Str("node.name", m.node.Hostname).
Uint64("node.id", m.node.ID.Uint64()).
Str("chan", fmt.Sprintf("%p", m.ch)).
TimeDiff("timeSpent", time.Now(), startWrite).
Str("machine.key", m.node.MachineKey.String()).
Bool("keepalive", msg.KeepAlive).
Msgf("finished writing mapresp to node chan(%p)", m.ch)

return nil
}
@ -305,14 +328,14 @@ var keepAlive = tailcfg.MapResponse{
}

func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)

if peerChange.Key != nil {
trace = trace.Str("node_key", peerChange.Key.ShortString())
trace = trace.Str("node.key", peerChange.Key.ShortString())
}

if peerChange.DiscoKey != nil {
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString())
}

if peerChange.Online != nil {
@ -349,7 +372,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
@ -358,7 +381,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
@ -367,7 +390,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
@ -376,7 +399,7 @@ func logPollFunc(
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node", node.Hostname).
Str("node.name", node.Hostname).
Err(err).
Msgf(msg, a...)
}

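Note (not part of the commit): the logging changes above standardize structured keys on dotted names (node.id, node.name, node.key, machine.key, disco.key) instead of the earlier mixed node/node_key/mkey spellings. A minimal sketch of a log call following that convention; the message text is illustrative only:

log.Trace().
	Caller().
	Uint64("node.id", node.ID.Uint64()).
	Str("node.name", node.Hostname).
	Str("machine.key", node.MachineKey.String()).
	Msg("illustrative event using the standardized field names")
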
@ -1430,7 +1430,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
}

log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")

changed, err := s.polMan.SetUsers(users)
if err != nil {

@ -97,6 +97,35 @@ func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy
}

func HasFull(cs []ChangeSet) bool {
for _, c := range cs {
if c.IsFull() {
return true
}
}
return false
}

func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
for _, c := range cs {
if c.SelfUpdateOnly {
self = append(self, c)
} else {
all = append(all, c)
}
}
return all, self
}

func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
for _, c := range cs {
if c.NodeID != id || c.Change.AlsoSelf() {
ret = append(ret, c)
}
}
return ret
}

func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes.

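Note (not part of the commit): the new ChangeSet helpers above let a caller split a batch into changes for everyone versus self-only changes, detect when a full map send is needed anyway, and drop a node's own non-self-relevant updates. A small usage sketch; sendFullMapToAll, sendChangeTo, changes and nodeID are hypothetical stand-ins:

all, self := change.SplitAllAndSelf(changes)
if change.HasFull(all) {
	// A full update supersedes the individual change sets for every node.
	sendFullMapToAll()
} else {
	for _, c := range change.RemoveUpdatesForSelf(nodeID, all) {
		sendChangeTo(nodeID, c)
	}
}
// Self-only change sets are delivered only to the node that caused them.
_ = self
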
@ -489,6 +489,7 @@ func derpConfig() DERPConfig {
urlAddr, err := url.Parse(urlStr)
if err != nil {
log.Error().
Caller().
Str("url", urlStr).
Err(err).
Msg("Failed to parse url, ignoring...")
@ -561,6 +562,7 @@ func logConfig() LogConfig {
logFormat = TextLogFormat
default:
log.Error().
Caller().
Str("func", "GetLogConfig").
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
}

@ -54,6 +54,20 @@ func (id NodeID) String() string {
return strconv.FormatUint(id.Uint64(), util.Base10)
}

func ParseNodeID(s string) (NodeID, error) {
id, err := strconv.ParseUint(s, util.Base10, 64)
return NodeID(id), err
}

func MustParseNodeID(s string) NodeID {
id, err := ParseNodeID(s)
if err != nil {
panic(err)
}

return id
}

// Node is a Headscale client.
type Node struct {
ID NodeID `gorm:"primary_key"`

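Note (not part of the commit): ParseNodeID and MustParseNodeID turn the string node-ID keys found in debug and JSON output back into typed IDs; MustParseNodeID panics on malformed input, so it is meant for trusted or test data. A short usage sketch:

id, err := types.ParseNodeID("42")
if err != nil {
	// handle malformed input
}
_ = id

// In tests or other trusted paths:
nodeID := types.MustParseNodeID("42")
_ = nodeID
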
@ -61,6 +61,7 @@ func (pak *PreAuthKey) Validate() error {
}

log.Debug().
Caller().
Str("key", pak.Key).
Bool("hasExpiration", pak.Expiration != nil).
Time("expiration", func() time.Time {

@ -321,7 +321,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
if err == nil {
u.Name = claims.Username
} else {
log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username)
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
}

if claims.EmailVerified {

@ -1160,57 +1160,61 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
err = headscale.SetPolicy(&p)
require.NoError(t, err)

// Get the current policy and check
// if it is the same as the one we set.
var output *policyv2.Policy
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"policy",
"get",
"--output",
"json",
},
&output,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Get the current policy and check
// if it is the same as the one we set.
var output *policyv2.Policy
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"policy",
"get",
"--output",
"json",
},
&output,
)
assert.NoError(ct, err)

assert.Len(t, output.ACLs, 1)
assert.Len(t, output.ACLs, 1)

if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" {
t.Errorf("unexpected policy(-want +got):\n%s", diff)
}

// Test that user1 can visit all user2
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
require.NoError(t, err)

url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)

result, err := client.Curl(url)
assert.Len(t, result, 13)
require.NoError(t, err)
if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" {
ct.Errorf("unexpected policy(-want +got):\n%s", diff)
}
}
}, 30*time.Second, 1*time.Second, "verifying that the new policy took place")

// Test that user2 _cannot_ visit user1
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Test that user1 can visit all user2
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
assert.NoError(ct, err)

url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)

result, err := client.Curl(url)
assert.Empty(t, result)
require.Error(t, err)
result, err := client.Curl(url)
assert.Len(ct, result, 13)
assert.NoError(ct, err)
}
}
}

// Test that user2 _cannot_ visit user1
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
assert.NoError(ct, err)

url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s to %s", client.Hostname(), url)

result, err := client.Curl(url)
assert.Empty(ct, result)
assert.Error(ct, err)
}
}
}, 30*time.Second, 1*time.Second, "new policy did not get propagated to nodes")
}

func TestACLAutogroupMember(t *testing.T) {

@ -9,6 +9,7 @@ import (
"time"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/samber/lo"
@ -53,6 +54,18 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)

headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)

expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second)

// assertClientsState(t, allClients)

clientIPs := make(map[TailscaleClient][]netip.Addr)
@ -64,9 +77,6 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
clientIPs[client] = ips
}

headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)

listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
@ -86,6 +96,9 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)

// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second)

t.Logf("all clients logged out")

assert.EventuallyWithT(t, func(ct *assert.CollectT) {

@ -481,10 +481,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
headscale, err := scenario.Headscale()
assertNoErr(t, err)

listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Empty(t, listUsers)

ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)

@ -494,26 +490,28 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)

listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}

sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})

if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after first login")

listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
@ -525,19 +523,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)

// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)

// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)

// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
}, 30*time.Second, 1*time.Second)

u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
@ -545,56 +543,56 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)

listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers = []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}

sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})

if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after new user login")

listNodesAfterNewUserLogin, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodesAfterNewUserLogin, 2)
var listNodesAfterNewUserLogin []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterNewUserLogin, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterNewUserLogin, 2)

// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "listing nodes after new user login")

// Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)

// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
t.Logf("Logged out take one")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")

// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -602,65 +600,92 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)

t.Logf("Logged out take two")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")

// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 30*time.Second, 1*time.Second)

// We do not actually "change" the user here, it is done by logging in again
// as the OIDC mock server is kind of like a stack, and the next user is
// prepared and ready to go.
u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)

_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)

listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
wantUsers = []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "Running", status.BackendState)
}, 30*time.Second, 1*time.Second)

sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
t.Logf("Logged back in")
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")

if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err)
assert.Len(ct, listUsers, 2)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
{
Id: 2,
Name: "user2",
Email: "user2@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user2",
},
}

listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodesAfterLoggingBackIn, 2)
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})

// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")

// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterLoggingBackIn, 2)

// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())

// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())

// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
}

// assertTailscaleNodesLogout verifies that all provided Tailscale clients

@ -4,6 +4,7 @@ import (
"net/netip"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"

@ -10,18 +10,21 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)

@ -59,13 +62,15 @@ func TestPingAllByIP(t *testing.T) {
hs, err := scenario.Headscale()
require.NoError(t, err)

assert.EventuallyWithT(t, func(ct *assert.CollectT) {
all, err := hs.GetAllMapReponses()
assert.NoError(ct, err)

onlineMap := buildExpectedOnlineMap(all)
assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
}, 30*time.Second, 2*time.Second)
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err, "failed to parse node ID")
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second)

// assertClientsState(t, allClients)

@ -73,6 +78,14 @@ func TestPingAllByIP(t *testing.T) {
return x.String()
})

// Get headscale instance for batcher debug check
headscale, err := scenario.Headscale()
assertNoErr(t, err)

// Test our DebugBatcher functionality
t.Logf("Testing DebugBatcher functionality...")
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to the batcher", 30*time.Second)

success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
@ -962,9 +975,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
)
assertNoErrHeadscaleEnv(t, err)

hs, err := scenario.Headscale()
require.NoError(t, err)

allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)

@ -980,14 +990,31 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
return x.String()
})

// Get headscale instance for batcher debug checks
headscale, err := scenario.Headscale()
assertNoErr(t, err)

// Initial check: all nodes should be connected to batcher
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)

success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))

wg, _ := errgroup.WithContext(context.Background())

for run := range 3 {
t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999"))

// Create fresh errgroup with timeout for each run
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
wg, _ := errgroup.WithContext(ctx)

for _, client := range allClients {
c := client
wg.Go(func() error {
@ -1001,6 +1028,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, fmt.Sprintf("Run %d: all nodes should be offline after Down()", run+1), 120*time.Second)

for _, client := range allClients {
c := client
wg.Go(func() error {
@ -1014,22 +1044,22 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

// After bringing up all nodes, verify batcher shows all reconnected
requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all nodes should be reconnected after Up()", run+1), 120*time.Second)

// Wait for sync and successful pings after nodes come back up
err = scenario.WaitForTailscaleSync()
assert.NoError(t, err)

t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

assert.EventuallyWithT(t, func(ct *assert.CollectT) {
all, err := hs.GetAllMapReponses()
assert.NoError(ct, err)

onlineMap := buildExpectedOnlineMap(all)
assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
}, 60*time.Second, 2*time.Second)
requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all systems should show nodes online after reconnection", run+1), 60*time.Second)

success := pingAllHelper(t, allClients, allAddrs)
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))

// Clean up context for this run
cancel()
}
}

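Note (not part of the commit): each down/up run above now gets a fresh errgroup bound to a 30-second context, so a stalled client cannot wedge later runs. A minimal sketch of that per-iteration pattern; the clients slice and Down call stand in for the test's real helpers:

for run := range 3 {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	wg, _ := errgroup.WithContext(ctx)
	for _, client := range clients {
		c := client
		wg.Go(func() error { return c.Down() }) // hypothetical per-client task
	}
	if err := wg.Wait(); err != nil {
		t.Fatalf("run %d failed: %v", run+1, err)
	}
	cancel()
}
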
@ -1141,51 +1171,158 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
}

func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
res := make(map[types.NodeID]map[types.NodeID]bool)
for nid, mrs := range all {
res[nid] = make(map[types.NodeID]bool)
for _, mr := range mrs {
for _, peer := range mr.Peers {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}

for _, peer := range mr.PeersChanged {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}

for _, peer := range mr.PeersChangedPatch {
if peer.Online != nil {
res[nid][types.NodeID(peer.NodeID)] = *peer.Online
}
}
}
}
return res
// NodeSystemStatus represents the online status of a node across different systems
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
MapResponses bool
NodeStore bool
}

func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) {
for nid, peers := range onlineMap {
onlineCount := 0
for _, online := range peers {
if online {
onlineCount++
// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()

startTime := time.Now()
t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format("2006-01-02T15:04:05.000"), message)

var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}

// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}

// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}

// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")

// Check that we have map responses for expected nodes
mapResponseCount := len(mapResponses)
assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")

// Build status map for each node
nodeStatus := make(map[types.NodeID]NodeSystemStatus)

// Initialize all expected nodes
for _, nodeID := range expectedNodes {
nodeStatus[nodeID] = NodeSystemStatus{}
}

// Check batcher state
for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
nodeID := types.MustParseNodeID(nodeIDStr)
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = nodeInfo.Connected
status.BatcherConnCount = nodeInfo.ActiveConnections
nodeStatus[nodeID] = status
}
}
assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid)
if expectedPeerCount != onlineCount {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid))
for pid, online := range peers {
sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online))

// Check map responses using buildExpectedOnlineMap
onlineFromMaps := make(map[types.NodeID]bool)
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
for nodeID := range nodeStatus {
NODE_STATUS:
for id, peerMap := range onlineMap {
if id == nodeID {
continue
}

online := peerMap[nodeID]
// If the node is offline in any map response, we consider it offline
if !online {
onlineFromMaps[nodeID] = false
continue NODE_STATUS
}

onlineFromMaps[nodeID] = true
}
sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
sb.WriteString("expected all peers to be online.")
t.Errorf("%s", sb.String())
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")

// Update status with map response data
for nodeID, online := range onlineFromMaps {
if status, exists := nodeStatus[nodeID]; exists {
status.MapResponses = online
nodeStatus[nodeID] = status
}
}

// Check nodestore state
for nodeID, node := range nodeStore {
if status, exists := nodeStatus[nodeID]; exists {
// Check if node is online in nodestore
status.NodeStore = node.IsOnline != nil && *node.IsOnline
nodeStatus[nodeID] = status
}
}

// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder

ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
(status.MapResponses == expectedOnline) &&
(status.NodeStore == expectedOnline)

if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
}
}

if !allMatch {
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
t.Log("Diff between reports:")
t.Logf("Prev report: \n%s\n", prevReport)
t.Logf("New report: \n%s\n", failureReport.String())
t.Log("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
prevReport = failureReport.String()
}

failureReport.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")

assert.Fail(c, failureReport.String())
}

stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
}, timeout, 2*time.Second, message)

endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format("2006-01-02T15:04:05.000"), duration, message)
}
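Editor's illustration (not part of the commit): with the format strings above, the failure report for a node whose batcher connection has not yet come back would render roughly as follows; all values are made up.

node:3 is not fully online:
 - batcher: false
 - conn count: 0
 - mapresponses: true (down with at least one peer)
 - nodestore: true
timestamp: 2025-01-02T15-04-05.123456789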
@ -22,6 +22,7 @@ import (

"github.com/davecgh/go-spew/spew"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"

@ -19,6 +19,7 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"tailscale.com/tailcfg"
)

// PeerSyncTimeout returns the timeout for peer synchronization based on environment:
@ -199,3 +200,30 @@ func CreateCertificate(hostname string) ([]byte, []byte, error) {

return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}

func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
res := make(map[types.NodeID]map[types.NodeID]bool)
for nid, mrs := range all {
res[nid] = make(map[types.NodeID]bool)
for _, mr := range mrs {
for _, peer := range mr.Peers {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}

for _, peer := range mr.PeersChanged {
if peer.Online != nil {
res[nid][types.NodeID(peer.ID)] = *peer.Online
}
}

for _, peer := range mr.PeersChangedPatch {
if peer.Online != nil {
res[nid][types.NodeID(peer.NodeID)] = *peer.Online
}
}
}
}
return res
}
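Usage note (editorial, not part of the diff): a test can feed the per-node map responses into this helper and assert on the resulting online matrix. A minimal sketch, assuming headscale and allClients come from the surrounding scenario setup as in the test above:

assert.EventuallyWithT(t, func(ct *assert.CollectT) {
	all, err := headscale.GetAllMapReponses()
	assert.NoError(ct, err)

	// onlineMap[nid][pid] holds the last Online value node nid saw for peer pid.
	onlineMap := integrationutil.BuildExpectedOnlineMap(all)
	for nid, peers := range onlineMap {
		assert.Lenf(ct, peers, len(allClients)-1, "node:%d peer count", nid)
		for pid, online := range peers {
			assert.Truef(ct, online, "node:%d sees peer node:%d as offline", nid, pid)
		}
	}
}, 60*time.Second, 2*time.Second)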
@ -5,7 +5,9 @@ package main
import (
"encoding/json"
"fmt"
"go/format"
"io"
"log"
"net/http"
"os"
"regexp"
@ -61,14 +63,14 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
rawURL := fmt.Sprintf(rawFileURL, version)
resp, err := http.Get(rawURL)
if err != nil {
fmt.Printf("Error fetching raw file for version %s: %v\n", version, err)
log.Printf("Error fetching raw file for version %s: %v\n", version, err)
continue
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Error reading raw file for version %s: %v\n", version, err)
log.Printf("Error reading raw file for version %s: %v\n", version, err)
continue
}

@ -79,7 +81,7 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
capabilityVersion, _ := strconv.Atoi(capabilityVersionStr)
versions[version] = tailcfg.CapabilityVersion(capabilityVersion)
} else {
fmt.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
log.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
}
}

@ -87,29 +89,23 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
}

func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion) error {
// Open the output file
file, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("error creating file: %w", err)
}
defer file.Close()

// Write the package declaration and variable
file.WriteString("package capver\n\n")
file.WriteString("//Generated DO NOT EDIT\n\n")
file.WriteString(`import "tailscale.com/tailcfg"`)
file.WriteString("\n\n")
file.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")
// Generate the Go code as a string
var content strings.Builder
content.WriteString("package capver\n\n")
content.WriteString("// Generated DO NOT EDIT\n\n")
content.WriteString(`import "tailscale.com/tailcfg"`)
content.WriteString("\n\n")
content.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")

sortedVersions := xmaps.Keys(versions)
sort.Strings(sortedVersions)
for _, version := range sortedVersions {
fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
fmt.Fprintf(&content, "\t\"%s\": %d,\n", version, versions[version])
}
file.WriteString("}\n")
content.WriteString("}\n")

file.WriteString("\n\n")
file.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")
content.WriteString("\n\n")
content.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")

capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
for _, v := range sortedVersions {
@ -129,9 +125,21 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
return capsSorted[i] < capsSorted[j]
})
for _, capVer := range capsSorted {
fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
}
content.WriteString("}\n")

// Format the generated code
formatted, err := format.Source([]byte(content.String()))
if err != nil {
return fmt.Errorf("error formatting Go code: %w", err)
}

// Write to file
err = os.WriteFile(outputFile, formatted, 0644)
if err != nil {
return fmt.Errorf("error writing file: %w", err)
}
file.WriteString("}\n")

return nil
}
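Editor's note on the pattern (a minimal, standalone sketch, not the project's code): building the generated source in memory and passing it through go/format.Source before writing guarantees the output is gofmt-clean and makes the generator fail fast if it emits invalid Go. The file name and map contents below are illustrative only:

package main

import (
	"fmt"
	"go/format"
	"log"
	"os"
	"strings"
)

func main() {
	var src strings.Builder
	src.WriteString("package capver\n\n")
	src.WriteString("// Generated DO NOT EDIT\n\n")
	src.WriteString("var example = map[string]int{\n")
	fmt.Fprintf(&src, "\"v1.80.0\": %d,\n", 113) // indentation is left to gofmt
	src.WriteString("}\n")

	// format.Source both gofmt-formats the code and rejects invalid Go.
	formatted, err := format.Source([]byte(src.String()))
	if err != nil {
		log.Fatalf("formatting generated code: %v", err)
	}

	if err := os.WriteFile("example_generated.go", formatted, 0o644); err != nil {
		log.Fatalf("writing generated code: %v", err)
	}
}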
@ -139,15 +147,15 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
func main() {
versions, err := getCapabilityVersions()
if err != nil {
fmt.Println("Error:", err)
log.Println("Error:", err)
return
}

err = writeCapabilityVersionsToFile(versions)
if err != nil {
fmt.Println("Error writing to file:", err)
log.Println("Error writing to file:", err)
return
}

fmt.Println("Capability versions written to", outputFile)
log.Println("Capability versions written to", outputFile)
}