1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

state/nodestore: in memory representation of nodes

Initial work on a nodestore which stores all of the nodes
and their relations in memory with relationship for peers
precalculated.

It is a copy-on-write structure, replacing the "snapshot"
when a change to the structure occurs. It is optimised for reads,
and while batches are not fast, they are grouped together
to do less of the expensive peer calculation if there are many
changes rapidly.

Writes will block until commited, while reads are never
blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-07-05 23:30:47 +02:00
parent b155f30ef6
commit 2e20652fdf
No known key found for this signature in database
35 changed files with 3960 additions and 1317 deletions

View File

@ -551,13 +551,12 @@ be assigned to nodes.`,
} }
} }
if confirm || force { if confirm || force {
ctx, client, conn, cancel := newHeadscaleCLIWithConfig() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force }) changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force})
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,

View File

@ -137,9 +137,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
// Initialize ephemeral garbage collector // Initialize ephemeral garbage collector
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
node, err := app.state.GetNodeByID(ni) node, ok := app.state.GetNodeByID(ni)
if err != nil { if !ok {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion") log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
return return
} }
@ -379,15 +380,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
log.Trace(). if err := func() error {
Caller(). log.Trace().
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")
authHeader := req.Header.Get("authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller(). Caller().
Str("client_address", req.RemoteAddr). Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
@ -501,11 +495,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches the HTTP and gRPC server service Headscale and the API. // Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp() capver.CanOldCodeBeCleanedUp()
if profilingEnabled { if profilingEnabled {
if profilingPath != "" { if profilingPath != "" {
err := os.MkdirAll(profilingPath, os.ModePerm) err = os.MkdirAll(profilingPath, os.ModePerm)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory") log.Fatal().Err(err).Msg("failed to create profiling directory")
} }
@ -559,12 +554,9 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will // around between restarts, they will reconnect and the GC will
// be cancelled. // be cancelled.
go h.ephemeralGC.Start() go h.ephemeralGC.Start()
ephmNodes, err := h.state.ListEphemeralNodes() ephmNodes := h.state.ListEphemeralNodes()
if err != nil { for _, node := range ephmNodes.All() {
return fmt.Errorf("failed to list ephemeral nodes: %w", err) h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
}
for _, node := range ephmNodes {
h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
} }
if h.cfg.DNSConfig.ExtraRecordsPath != "" { if h.cfg.DNSConfig.ExtraRecordsPath != "" {
@ -794,23 +786,14 @@ func (h *Headscale) Serve() error {
continue continue
} }
changed, err := h.state.ReloadPolicy() changes, err := h.state.ReloadPolicy()
if err != nil { if err != nil {
log.Error().Err(err).Msgf("reloading policy") log.Error().Err(err).Msgf("reloading policy")
continue continue
} }
if changed { h.Change(changes...)
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
err = h.state.AutoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
h.Change(change.PolicySet)
}
default: default:
info := func(msg string) { log.Info().Msg(msg) } info := func(msg string) { log.Info().Msg(msg) }
log.Info(). log.Info().
@ -1020,6 +1003,6 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
// Change is used to send changes to nodes. // Change is used to send changes to nodes.
// All change should be enqueued here and empty will be automatically // All change should be enqueued here and empty will be automatically
// ignored. // ignored.
func (h *Headscale) Change(c change.ChangeSet) { func (h *Headscale) Change(cs ...change.ChangeSet) {
h.mapBatcher.AddWork(c) h.mapBatcher.AddWork(cs...)
} }

View File

@ -12,7 +12,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/types/change"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -29,28 +28,10 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
node, err := h.state.GetNodeByNodeKey(regReq.NodeKey) node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("looking up node in database: %w", err)
}
if node != nil { if ok {
// If an existing node is trying to register with an auth key, resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
}
return resp, nil
}
resp, err := h.handleExistingNode(node, regReq, machineKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err) return nil, fmt.Errorf("handling existing node: %w", err)
} }
@ -70,6 +51,7 @@ func (h *Headscale) handleRegister(
if errors.As(err, &httpErr) { if errors.As(err, &httpErr) {
return nil, httpErr return nil, httpErr
} }
return nil, fmt.Errorf("handling register with auth key: %w", err) return nil, fmt.Errorf("handling register with auth key: %w", err)
} }
@ -89,13 +71,22 @@ func (h *Headscale) handleExistingNode(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey { if node.MachineKey != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
} }
expired := node.IsExpired() expired := node.IsExpired()
// If the node is expired and this is not a re-authentication attempt,
// force the client to re-authenticate
if expired && regReq.Auth == nil {
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
AuthURL: "", // Client will need to re-authenticate
}, nil
}
if !expired && !regReq.Expiry.IsZero() { if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry requestExpiry := regReq.Expiry
@ -107,7 +98,7 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout. // If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) { if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() { if node.IsEphemeral() {
c, err := h.state.DeleteNode(node) c, err := h.state.DeleteNode(node.View())
if err != nil { if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err) return nil, fmt.Errorf("deleting ephemeral node: %w", err)
} }
@ -118,15 +109,19 @@ func (h *Headscale) handleExistingNode(
} }
} }
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil { if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err) return nil, fmt.Errorf("setting node expiry: %w", err)
} }
h.Change(c) h.Change(c)
}
return nodeToRegisterResponse(node), nil // CRITICAL: Use the updated node view for the response
// The original node object has stale expiry information
node = updatedNode.AsStruct()
}
return nodeToRegisterResponse(node), nil
} }
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@ -177,7 +172,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey( node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq, regReq,
machineKey, machineKey,
) )
@ -193,8 +188,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err return nil, err
} }
// If node is nil, it means an ephemeral node was deleted during logout // If node is not valid, it means an ephemeral node was deleted during logout
if node == nil { if !node.Valid() {
h.Change(changed) h.Change(changed)
return nil, nil return nil, nil
} }
@ -213,26 +208,30 @@ func (h *Headscale) handleRegisterWithAuthKey(
// TODO(kradalby): This needs to be ran as part of the batcher maybe? // TODO(kradalby): This needs to be ran as part of the batcher maybe?
// now since we dont update the node/pol here anymore // now since we dont update the node/pol here anymore
routeChange := h.state.AutoApproveRoutes(node) routeChange := h.state.AutoApproveRoutes(node)
if _, _, err := h.state.SaveNode(node); err != nil { if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err) return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
} }
if routeChange && changed.Empty() { if routeChange && changed.Empty() {
changed = change.NodeAdded(node.ID) changed = change.NodeAdded(node.ID())
} }
h.Change(changed) h.Change(changed)
// If policy changed due to node registration, send a separate policy change // TODO(kradalby): I think this is covered above, but we need to validate that.
if policyChanged { // // If policy changed due to node registration, send a separate policy change
policyChange := change.PolicyChange() // if policyChanged {
h.Change(policyChange) // policyChange := change.PolicyChange()
} // h.Change(policyChange)
// }
user := node.User()
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{
MachineAuthorized: true, MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(), NodeKeyExpired: node.IsExpired(),
User: *node.User.TailscaleUser(), User: *user.TailscaleUser(),
Login: *node.User.TailscaleLogin(), Login: *user.TailscaleLogin(),
}, nil }, nil
} }
@ -266,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive(
) )
log.Info().Msgf("Starting node registration using key: %s", registrationId) log.Info().Msgf("Starting node registration using key: %s", registrationId)
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(registrationId), AuthURL: h.authProvider.AuthURL(registrationId),
}, nil }, nil

View File

@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
} }
// RenameNode takes a Node struct and a new GivenName for the nodes // RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it. If the name is not unique, it will return an error. // and renames it. Validation should be done in the state layer before calling this function.
func RenameNode(tx *gorm.DB, func RenameNode(tx *gorm.DB,
nodeID types.NodeID, newName string, nodeID types.NodeID, newName string,
) error { ) error {
err := util.CheckForFQDNRules( // Check if the new name is unique
newName, var count int64
) if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
if err != nil { return fmt.Errorf("failed to check name uniqueness: %w", err)
return fmt.Errorf("renaming node: %w", err)
} }
uniq, err := isUniqueName(tx, newName) if count > 0 {
if err != nil { return errors.New("name is not unique")
return fmt.Errorf("checking if name is unique: %w", err)
}
if !uniq {
return fmt.Errorf("name is not unique: %s", newName)
} }
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@ -333,108 +327,19 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
}) })
} }
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path // RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
// with a registrationID to register or reauthenticate a node. // Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
// If the node found in the registration cache is not already registered, func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
// it will be registered with the user and the node will be removed from the cache. if !testing.Testing() {
// If the node is already registered, the expiry will be updated. panic("RegisterNodeForTest can only be called during tests")
// The node, and a boolean indicating if it was a new node or not, will be returned. }
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, change.ChangeSet, error) {
var nodeChange change.ChangeSet
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
err,
)
}
log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
)
if err == nil {
hsdb.regCache.Delete(registrationID)
}
// Signal to waiting clients that the machine has been registered.
select {
case reg.Registered <- node:
default:
}
close(reg.Registered)
nodeChange = change.NodeAdded(node.ID)
return node, err
} else {
// If the node is already registered, this is a refresh.
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
if err != nil {
return nil, err
}
nodeChange = change.KeyExpiry(node.ID)
return node, nil
}
}
return nil, ErrNodeNotFoundRegistrationCache
})
return node, nodeChange, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
return RegisterNode(tx, node, ipv4, ipv6)
})
}
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
log.Debug(). log.Debug().
Str("node", node.Hostname). Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()). Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()). Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Username()). Str("user", node.User.Username()).
Msg("Registering node") Msg("Registering test node")
// If the a new node is registered with the same machine key, to the same user, // If the a new node is registered with the same machine key, to the same user,
// update the existing node. // update the existing node.
@ -445,8 +350,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.ID = oldNode.ID node.ID = oldNode.ID
node.GivenName = oldNode.GivenName node.GivenName = oldNode.GivenName
node.ApprovedRoutes = oldNode.ApprovedRoutes node.ApprovedRoutes = oldNode.ApprovedRoutes
ipv4 = oldNode.IPv4 // Don't overwrite the provided IPs with old ones when they exist
ipv6 = oldNode.IPv6 if ipv4 == nil {
ipv4 = oldNode.IPv4
}
if ipv6 == nil {
ipv6 = oldNode.IPv6
}
} }
// If the node exists and it already has IP(s), we just save it // If the node exists and it already has IP(s), we just save it
@ -463,7 +373,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
Str("machine_key", node.MachineKey.ShortString()). Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()). Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Username()). Str("user", node.User.Username()).
Msg("Node authorized again") Msg("Test node authorized again")
return &node, nil return &node, nil
} }
@ -472,7 +382,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.IPv6 = ipv6 node.IPv6 = ipv6
if node.GivenName == "" { if node.GivenName == "" {
givenName, err := ensureUniqueGivenName(tx, node.Hostname) givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to ensure unique given name: %w", err) return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
} }
@ -487,7 +397,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
log.Trace(). log.Trace().
Caller(). Caller().
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Node registered with the database") Msg("Test node registered with the database")
return &node, nil return &node, nil
} }
@ -560,7 +470,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) {
return len(nodes) == 0, nil return len(nodes) == 0, nil
} }
func ensureUniqueGivenName( // EnsureUniqueGivenName generates a unique given name for a node based on its hostname.
func EnsureUniqueGivenName(
tx *gorm.DB, tx *gorm.DB,
name string, name string,
) (string, error) { ) (string, error) {
@ -781,19 +692,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
node := hsdb.CreateNodeForTest(user, hostname...) node := hsdb.CreateNodeForTest(user, hostname...)
err := hsdb.DB.Transaction(func(tx *gorm.DB) error { // Allocate IPs for the test node using the database's IP allocator
_, err := RegisterNode(tx, *node, nil, nil) // This is a simplified allocation for testing - in production this would use State.ipAlloc
ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
}
var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
return err return err
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to register test node: %v", err)) panic(fmt.Sprintf("failed to register test node: %v", err))
} }
registeredNode, err := hsdb.GetNodeByID(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to get registered test node: %v", err))
}
return registeredNode return registeredNode
} }
@ -842,3 +757,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
return nodes return nodes
} }
// allocateTestIPs allocates sequential test IPs for nodes during testing.
func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
if !testing.Testing() {
panic("allocateTestIPs can only be called during tests")
}
// Use simple sequential allocation for tests
// IPv4: 100.64.0.x (where x is nodeID)
// IPv6: fd7a:115c:a1e0::x (where x is nodeID)
if nodeID > 254 {
return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
}
ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})
return &ipv4, &ipv6, nil
}

View File

@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {
func TestAutoApproveRoutes(t *testing.T) { func TestAutoApproveRoutes(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
acl string acl string
routes []netip.Prefix routes []netip.Prefix
want []netip.Prefix want []netip.Prefix
want2 []netip.Prefix want2 []netip.Prefix
expectChange bool // whether to expect route changes
}{ }{
{
name: "no-auto-approvers-empty-policy",
acl: `
{
"groups": {
"group:admins": ["test@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admins"],
"dst": ["group:admins:*"]
}
]
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
want: []netip.Prefix{}, // Should be empty - no auto-approvers
want2: []netip.Prefix{}, // Should be empty - no auto-approvers
expectChange: false, // No changes expected
},
{
name: "no-auto-approvers-explicit-empty",
acl: `
{
"groups": {
"group:admins": ["test@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admins"],
"dst": ["group:admins:*"]
}
],
"autoApprovers": {
"routes": {},
"exitNode": []
}
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
expectChange: false, // No changes expected
},
{ {
name: "2068-approve-issue-sub-kube", name: "2068-approve-issue-sub-kube",
acl: ` acl: `
@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
} }
} }
}`, }`,
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
expectChange: true, // Routes should be approved
}, },
{ {
name: "2068-approve-issue-sub-exit-tag", name: "2068-approve-issue-sub-exit-tag",
@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
tsaddr.AllIPv4(), tsaddr.AllIPv4(),
tsaddr.AllIPv6(), tsaddr.AllIPv6(),
}, },
expectChange: true, // Routes should be approved
}, },
} }
@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, pm) require.NotNil(t, pm)
changed1 := policy.AutoApproveRoutes(pm, &node) newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
assert.True(t, changed1) assert.Equal(t, tt.expectChange, changed1)
err = adb.DB.Save(&node).Error if changed1 {
require.NoError(t, err) err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
require.NoError(t, err)
}
_ = policy.AutoApproveRoutes(pm, &nodeTagged) newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
if changed2 {
err = adb.DB.Save(&nodeTagged).Error err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
require.NoError(t, err) require.NoError(t, err)
}
node1ByID, err := adb.GetNodeByID(1) node1ByID, err := adb.GetNodeByID(1)
require.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { // For empty auto-approvers tests, handle nil vs empty slice comparison
expectedRoutes1 := tt.want
if len(expectedRoutes1) == 0 {
expectedRoutes1 = nil
}
if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
} }
node2ByID, err := adb.GetNodeByID(2) node2ByID, err := adb.GetNodeByID(2)
require.NoError(t, err) require.NoError(t, err)
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { expectedRoutes2 := tt.want2
if len(expectedRoutes2) == 0 {
expectedRoutes2 = nil
}
if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
} }
}) })
@ -620,11 +679,11 @@ func TestRenameNode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error { err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node, nil, nil) _, err := RegisterNodeForTest(tx, node, nil, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = RegisterNode(tx, node2, nil, nil) _, err = RegisterNodeForTest(tx, node2, nil, nil)
return err return err
}) })
@ -721,11 +780,11 @@ func TestListPeers(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error { err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node1, nil, nil) _, err := RegisterNodeForTest(tx, node1, nil, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = RegisterNode(tx, node2, nil, nil) _, err = RegisterNodeForTest(tx, node2, nil, nil)
return err return err
}) })
@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
// No parameter means no filter, should return all peers // No parameter means no filter, should return all peers
nodes, err = db.ListPeers(1) nodes, err = db.ListPeers(1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(nodes)) assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
// Empty node list should return all peers // Empty node list should return all peers
nodes, err = db.ListPeers(1, types.NodeIDs{}...) nodes, err = db.ListPeers(1, types.NodeIDs{}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(nodes)) assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
// No match in IDs should return empty list and no error // No match in IDs should return empty list and no error
@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
// Partial match in IDs // Partial match in IDs
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(nodes)) assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs, but node ID is still filtered out // Several matched IDs, but node ID is still filtered out
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(nodes)) assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
} }
@ -806,11 +865,11 @@ func TestListNodes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error { err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node1, nil, nil) _, err := RegisterNodeForTest(tx, node1, nil, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = RegisterNode(tx, node2, nil, nil) _, err = RegisterNodeForTest(tx, node2, nil, nil)
return err return err
}) })
@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
// No parameter means no filter, should return all nodes // No parameter means no filter, should return all nodes
nodes, err = db.ListNodes() nodes, err = db.ListNodes()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, len(nodes)) assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname) assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return all nodes // Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}...) nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, len(nodes)) assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname) assert.Equal(t, "test2", nodes[1].Hostname)
@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
// Partial match in IDs // Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(nodes)) assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs // Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, len(nodes)) assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname) assert.Equal(t, "test2", nodes[1].Hostname)
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@ -47,8 +48,9 @@ func CreatePreAuthKey(
return nil, err return nil, err
} }
// Remove duplicates // Remove duplicates and sort for consistency
aclTags = set.SetOf(aclTags).Slice() aclTags = set.SetOf(aclTags).Slice()
slices.Sort(aclTags)
// TODO(kradalby): factor out and create a reusable tag validation, // TODO(kradalby): factor out and create a reusable tag validation,
// check if there is one in Tailscale's lib. // check if there is one in Tailscale's lib.

View File

@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
} }
// AssignNodeToUser assigns a Node to a user. // AssignNodeToUser assigns a Node to a user.
// Note: Validation should be done in the state layer before calling this function.
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error { func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
node, err := GetNodeByID(tx, nodeID) // Check if the user exists
if err != nil { var userExists bool
return err if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
} }
user, err := GetUserByID(tx, uid)
if err != nil { if !userExists {
return err return ErrUserNotFound
} }
node.User = *user
node.UserID = user.ID if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
if result := tx.Save(&node); result.Error != nil { return fmt.Errorf("failed to assign node to user: %w", err)
return result.Error
} }
return nil return nil

View File

@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
} }
sshPol := make(map[string]*tailcfg.SSHPolicy) sshPol := make(map[string]*tailcfg.SSHPolicy)
for _, node := range nodes { for _, node := range nodes.All() {
pol, err := h.state.SSHPolicy(node.View()) pol, err := h.state.SSHPolicy(node)
if err != nil { if err != nil {
httpError(w, err) httpError(w, err)
return return
} }
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
} }
sshJSON, err := json.MarshalIndent(sshPol, "", " ") sshJSON, err := json.MarshalIndent(sshPol, "", " ")

View File

@ -15,7 +15,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -25,6 +24,7 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/views"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/state"
@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
} }
c := change.UserAdded(types.UserID(user.ID)) c := change.UserAdded(types.UserID(user.ID))
if policyChanged {
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
if !policyChanged.Empty() {
c.Change = change.Policy c.Change = change.Policy
} }
@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err return nil, err
} }
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) _, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Send policy update notifications if needed // Send policy update notifications if needed
if policyChanged { api.h.Change(c)
api.h.Change(change.PolicyChange())
}
newUser, err := api.h.state.GetUserByName(request.GetNewName()) newUser, err := api.h.state.GetUserByName(request.GetNewName())
if err != nil { if err != nil {
@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context, ctx context.Context,
request *v1.GetNodeRequest, request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) { ) (*v1.GetNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if !ok {
return nil, err return nil, status.Errorf(codes.NotFound, "node not found")
} }
resp := node.Proto() resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil return &v1.GetNodeResponse{Node: resp}, nil
} }
@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags(
api.h.Change(nodeChange) api.h.Change(nodeChange)
log.Trace(). log.Trace().
Str("node", node.Hostname). Caller().
Str("node", node.Hostname()).
Strs("tags", request.GetTags()). Strs("tags", request.GetTags()).
Msg("Changing tags of node") Msg("Changing tags of node")
@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
ctx context.Context, ctx context.Context,
request *v1.SetApprovedRoutesRequest, request *v1.SetApprovedRoutesRequest,
) (*v1.SetApprovedRoutesResponse, error) { ) (*v1.SetApprovedRoutesResponse, error) {
var routes []netip.Prefix log.Debug().
Caller().
Uint64("node.id", request.GetNodeId()).
Strs("requestedRoutes", request.GetRoutes()).
Msg("gRPC SetApprovedRoutes called")
var newApproved []netip.Prefix
for _, route := range request.GetRoutes() { for _, route := range request.GetRoutes() {
prefix, err := netip.ParsePrefix(route) prefix, err := netip.ParsePrefix(route)
if err != nil { if err != nil {
@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
// If the prefix is an exit route, add both. The client expect both // If the prefix is an exit route, add both. The client expect both
// to annotate the node as an exit node. // to annotate the node as an exit node.
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() { if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6()) newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
} else { } else {
routes = append(routes, prefix) newApproved = append(newApproved, prefix)
} }
} }
tsaddr.SortPrefixes(routes) tsaddr.SortPrefixes(newApproved)
routes = slices.Compact(routes) newApproved = slices.Compact(newApproved)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
// Always propagate node changes from SetApprovedRoutes // Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange) api.h.Change(nodeChange)
// If routes changed, propagate those changes too
if !routeChange.Empty() {
api.h.Change(routeChange)
}
proto := node.Proto() proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID)) // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
// routes that are actively served from the node (per architectural requirement in types/node.go)
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
log.Debug().
Caller().
Uint64("node.id", node.ID().Uint64()).
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
Strs("finalSubnetRoutes", proto.SubnetRoutes).
Msg("gRPC SetApprovedRoutes completed")
return &v1.SetApprovedRoutesResponse{Node: proto}, nil return &v1.SetApprovedRoutesResponse{Node: proto}, nil
} }
@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context, ctx context.Context,
request *v1.DeleteNodeRequest, request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) { ) (*v1.DeleteNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if !ok {
return nil, err return nil, status.Errorf(codes.NotFound, "node not found")
} }
nodeChange, err := api.h.state.DeleteNode(node) nodeChange, err := api.h.state.DeleteNode(node)
@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode(
api.h.Change(nodeChange) api.h.Change(nodeChange)
log.Trace(). log.Trace().
Str("node", node.Hostname). Caller().
Time("expiry", *node.Expiry). Str("node", node.Hostname()).
Time("expiry", *node.AsStruct().Expiry).
Msg("node expired") Msg("node expired")
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode(
api.h.Change(nodeChange) api.h.Change(nodeChange)
log.Trace(). log.Trace().
Str("node", node.Hostname). Caller().
Str("node", node.Hostname()).
Str("new_name", request.GetNewName()). Str("new_name", request.GetNewName()).
Msg("node renamed") Msg("node renamed")
@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes(
// the filtering of nodes by user, vs nodes as a whole can // the filtering of nodes by user, vs nodes as a whole can
// probably be done once. // probably be done once.
// TODO(kradalby): This should be done in one tx. // TODO(kradalby): This should be done in one tx.
IsConnected := api.h.mapBatcher.ConnectedMap()
if request.GetUser() != "" { if request.GetUser() != "" {
user, err := api.h.state.GetUserByName(request.GetUser()) user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil { if err != nil {
return nil, err return nil, err
} }
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID)) nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
if err != nil {
return nil, err
}
response := nodesToProto(api.h.state, IsConnected, nodes) response := nodesToProto(api.h.state, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil return &v1.ListNodesResponse{Nodes: response}, nil
} }
nodes, err := api.h.state.ListNodes() nodes := api.h.state.ListNodes()
if err != nil {
return nil, err
}
sort.Slice(nodes, func(i, j int) bool { response := nodesToProto(api.h.state, nodes)
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil return &v1.ListNodesResponse{Nodes: response}, nil
} }
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node { func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node {
response := make([]*v1.Node, len(nodes)) response := make([]*v1.Node, nodes.Len())
for index, node := range nodes { for index, node := range nodes.All() {
resp := node.Proto() resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
if val, ok := IsConnected.Load(node.ID); ok && val {
resp.Online = true
}
var tags []string var tags []string
for _, tag := range node.RequestTags() { for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node.View(), tag) { if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag) tags = append(tags, tag)
} }
} }
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...)) resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp response[index] = resp
} }
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
return response return response
} }
@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy(
// a scenario where they might be allowed if the server has no nodes // a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading // yet, but it should help for the general case and for hot reloading
// configurations. // configurations.
nodes, err := api.h.state.ListNodes() nodes := api.h.state.ListNodes()
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) _, err := api.h.state.SetPolicy([]byte(p))
}
changed, err := api.h.state.SetPolicy([]byte(p))
if err != nil { if err != nil {
return nil, fmt.Errorf("setting policy: %w", err) return nil, fmt.Errorf("setting policy: %w", err)
} }
if len(nodes) > 0 { if nodes.Len() > 0 {
_, err = api.h.state.SSHPolicy(nodes[0].View()) _, err = api.h.state.SSHPolicy(nodes.At(0))
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err) return nil, fmt.Errorf("verifying SSH rules: %w", err)
} }
@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err return nil, err
} }
// Only send update if the packet filter has changed. // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
if changed { // This ensures that routes are re-evaluated for auto-approval in cases where routes
err = api.h.state.AutoApproveNodes() // were manually disabled but could now be auto-approved with the current policy.
if err != nil { cs, err := api.h.state.ReloadPolicy()
return nil, err if err != nil {
} return nil, fmt.Errorf("reloading policy: %w", err)
}
api.h.Change(change.PolicyChange()) if len(cs) > 0 {
api.h.Change(cs...)
} else {
log.Debug().
Caller().
Msg("No policy changes to distribute because ReloadPolicy returned empty changeset")
} }
response := &v1.SetPolicyResponse{ response := &v1.SetPolicyResponse{

View File

@ -94,13 +94,19 @@ func (h *Headscale) handleVerifyRequest(
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
} }
nodes, err := h.state.ListNodes() nodes := h.state.ListNodes()
if err != nil {
return fmt.Errorf("cannot list nodes: %w", err) // Check if any node has the requested NodeKey
var nodeKeyFound bool
for _, node := range nodes.All() {
if node.NodeKey() == derpAdmitClientRequest.NodePublic {
nodeKeyFound = true
break
}
} }
resp := &tailcfg.DERPAdmitClientResponse{ resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), Allow: nodeKeyFound,
} }
return json.NewEncoder(writer).Encode(resp) return json.NewEncoder(writer).Encode(resp)

View File

@ -1,6 +1,7 @@
package mapper package mapper
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
type Batcher interface { type Batcher interface {
Start() Start()
Close() Close()
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool] ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c change.ChangeSet) AddWork(c change.ChangeSet)
@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
if nc == nil { if nc == nil {
return fmt.Errorf("nodeConnection is nil") return errors.New("nodeConnection is nil")
} }
nodeID := nc.nodeID() nodeID := nc.nodeID()

View File

@ -21,8 +21,7 @@ type LockFreeBatcher struct {
mapper *mapper mapper *mapper
workers int workers int
// Lock-free concurrent maps nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
nodes *xsync.Map[types.NodeID, *nodeConn]
connected *xsync.Map[types.NodeID, *time.Time] connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel // Work queue channel
@ -32,7 +31,6 @@ type LockFreeBatcher struct {
// Batching state // Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
batchMutex sync.RWMutex
// Metrics // Metrics
totalNodes atomic.Int64 totalNodes atomic.Int64
@ -45,65 +43,63 @@ type LockFreeBatcher struct {
// AddNode registers a new node connection with the batcher and sends an initial map response. // AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation, // It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online. // and notifies other nodes that this node has come online.
// TODO(kradalby): See if we can move the isRouter argument somewhere else. func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error { addNodeStart := time.Now()
// First validate that we can generate initial map before doing anything else
fullSelfChange := change.FullSelf(id)
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange. // Generate connection ID
// This currently means that the goroutine for the node connection will do the processing connID := generateConnectionID()
// which means that we might have uncontrolled concurrency.
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing // Create new connection entry
// it to be processed in a more controlled manner. now := time.Now()
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange) newEntry := &connectionEntry{
if err != nil { id: connID,
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) c: c,
version: version,
created: now,
} }
// Only after validation succeeds, create or update node connection // Only after validation succeeds, create or update node connection
newConn := newNodeConn(id, c, version, b.mapper) newConn := newNodeConn(id, c, version, b.mapper)
var conn *nodeConn if !loaded {
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
// Update existing connection
existing.updateConnection(c, version)
conn = existing
} else {
b.totalNodes.Add(1) b.totalNodes.Add(1)
conn = newConn conn = newConn
} }
// Mark as connected only after validation succeeds
b.connected.Store(id, nil) // nil = connected b.connected.Store(id, nil) // nil = connected
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher") if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
// Send the validated initial map // Use a blocking send with timeout for initial map since the channel should be ready
if initialMap != nil { // and we want to avoid the race condition where the receiver isn't ready yet
if err := conn.send(initialMap); err != nil { select {
// Clean up the connection state on send failure case c <- initialMap:
b.nodes.Delete(id) // Success
b.connected.Delete(id) case <-time.After(5 * time.Second):
return fmt.Errorf("failed to send initial map to node %d: %w", id, err) log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
} log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
Msg("Initial map send timed out because channel was blocked or receiver not ready")
// Notify other nodes that this node came online nodeConn.removeConnectionByChannel(c)
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter}) return fmt.Errorf("failed to send initial map to node %d: timeout", id)
} }
return nil return nil
} }
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection, // It validates the connection channel matches one of the current connections, closes that specific connection,
// and notifies other nodes that this node has gone offline. // and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) { // Reports if the node still has active connections after removal.
// Check if this is the current connection and mark it as closed func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
if existing, ok := b.nodes.Load(id); ok { nodeConn, exists := b.nodes.Load(id)
if !existing.matchesChannel(c) { if !exists {
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring") log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
return // Not the current connection, not an error return false
} }
// Mark the connection as closed to prevent further sends // Mark the connection as closed to prevent further sends
if connData := existing.connData.Load(); connData != nil { if connData := existing.connData.Load(); connData != nil {
@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
} }
} }
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline") // Check if node has any remaining active connections
if nodeConn.hasActiveConnections() {
log.Debug().Caller().Uint64("node.id", id.Uint64()).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection removed but keeping online because other connections remain")
return true // Node still has active connections
}
// Remove node and mark disconnected atomically // Remove node and mark disconnected atomically
b.nodes.Delete(id) b.nodes.Delete(id)
b.connected.Store(id, ptr.To(time.Now())) b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1) b.totalNodes.Add(-1)
// Notify other nodes that this node went offline return false
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
} }
// AddWork queues a change to be processed by the batcher. // AddWork queues a change to be processed by the batcher.
@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
return return
} }
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow synchronous work processing")
}
continue continue
} }
@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
// that should be processed and sent to the node instead of // that should be processed and sent to the node instead of
// returned to the caller. // returned to the caller.
if nc, exists := b.nodes.Load(w.nodeID); exists { if nc, exists := b.nodes.Load(w.nodeID); exists {
// Check if this connection is still active before processing // Apply change to node - this will handle offline nodes gracefully
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() { // and queue work for when they reconnect
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("skipping work for closed connection")
continue
}
err := nc.change(w.c) err := nc.change(w.c)
if err != nil { if err != nil {
b.workErrors.Add(1) b.workErrors.Add(1)
@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) {
Str("change", w.c.Change.String()). Str("change", w.c.Change.String()).
Msg("failed to apply change") Msg("failed to apply change")
} }
} else {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("node not found for asynchronous work - node may have disconnected")
} }
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow asynchronous work processing")
}
case <-b.ctx.Done(): case <-b.ctx.Done():
return return
} }
} }
} }
func (b *LockFreeBatcher) addWork(c change.ChangeSet) { func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
// For critical changes that need immediate processing, send directly b.addToBatch(c...)
if b.shouldProcessImmediately(c) {
if c.SelfUpdateOnly {
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
return true
})
return
}
// For non-critical changes, add to batch
b.addToBatch(c)
} }
// queueWork safely queues work // queueWork safely queues work.
func (b *LockFreeBatcher) queueWork(w work) { func (b *LockFreeBatcher) queueWork(w work) {
b.workQueuedCount.Add(1) b.workQueuedCount.Add(1)
@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) {
} }
} }
// shouldProcessImmediately determines if a change should bypass batching // addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
// Process these changes immediately to avoid delaying critical functionality // Short circuit if any of the changes is a full update, which
switch c.Change { // means we can skip sending individual changes.
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy: if change.HasFull(c) {
return true b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
default: b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
return false
return true
})
return
} }
} }
// addToBatch adds a change to the pending batch
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if c.SelfUpdateOnly {
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(c.NodeID, changes)
return return
} }
@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c) changes = append(changes, c)
b.pendingChanges.Store(nodeID, changes) b.pendingChanges.Store(nodeID, changes)
return true return true
}) })
} }
// processBatchedChanges processes all pending batched changes // processBatchedChanges processes all pending batched changes.
func (b *LockFreeBatcher) processBatchedChanges() { func (b *LockFreeBatcher) processBatchedChanges() {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if b.pendingChanges == nil { if b.pendingChanges == nil {
return return
} }
@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() {
// Clear the pending changes for this node // Clear the pending changes for this node
b.pendingChanges.Delete(nodeID) b.pendingChanges.Delete(nodeID)
return true return true
}) })
} }
// IsConnected is lock-free read. // IsConnected is lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if val, ok := b.connected.Load(id); ok { // First check if we have active connections for this node
// nil means connected if nodeConn, exists := b.nodes.Load(id); exists {
return val == nil if nodeConn.hasActiveConnections() {
return true
}
} }
// Check disconnected timestamp with grace period
val, ok := b.connected.Load(id)
if !ok {
return false
}
// nil means connected
if val == nil {
return true
}
return false return false
} }
@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret := xsync.NewMap[types.NodeID, bool]() ret := xsync.NewMap[types.NodeID, bool]()
// First, add all nodes with active connections
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
if nodeConn.hasActiveConnections() {
ret.Store(id, true)
}
return true
})
// Then add all entries from the connected map
b.connected.Range(func(id types.NodeID, val *time.Time) bool { b.connected.Range(func(id types.NodeID, val *time.Time) bool {
// nil means connected // Only add if not already added as connected above
ret.Store(id, val == nil) if _, exists := ret.Load(id); !exists {
if val == nil {
// nil means connected
ret.Store(id, true)
} else {
// timestamp means disconnected
ret.Store(id, false)
}
}
return true return true
}) })
@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
return fmt.Errorf("node %d: connection closed", nc.id) return fmt.Errorf("node %d: connection closed", nc.id)
} }
// TODO(kradalby): We might need some sort of timeout here if the client is not reading // Add all entries from the connected map to capture both connected and disconnected nodes
// the channel. That might mean that we are sending to a node that has gone offline, but b.connected.Range(func(id types.NodeID, val *time.Time) bool {
// the channel is still open. // Only add if not already processed above
connData.c <- data if _, exists := result[id]; !exists {
nc.updateCount.Add(1) // Use immediate connection status for debug (no grace period)
return nil connected := (val == nil) // nil means connected, timestamp means disconnected
result[id] = DebugNodeInfo{
Connected: connected,
ActiveConnections: 0,
}
}
return true
})
return result
} }
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {

View File

@ -27,6 +27,60 @@ type batcherTestCase struct {
fn batcherFunc fn batcherFunc
} }
// testBatcherWrapper wraps a real batcher to add online/offline notifications
// that would normally be sent by poll.go in production.
type testBatcherWrapper struct {
Batcher
state *state.State
}
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
// Mark node as online in state before AddNode to match production behavior
// This ensures the NodeStore has correct online status for change processing
if t.state != nil {
// Use Connect to properly mark node online in NodeStore but don't send its changes
_ = t.state.Connect(id)
}
// First add the node to the real batcher
err := t.Batcher.AddNode(id, c, version)
if err != nil {
return err
}
// Send the online notification that poll.go would normally send
// This ensures other nodes get notified about this node coming online
t.AddWork(change.NodeOnline(id))
return nil
}
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
// Mark node as offline in state BEFORE removing from batcher
// This ensures the NodeStore has correct offline status when the change is processed
if t.state != nil {
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
_, _ = t.state.Disconnect(id)
}
// Send the offline notification that poll.go would normally send
// Do this BEFORE removing from batcher so the change can be processed
t.AddWork(change.NodeOffline(id))
// Finally remove from the real batcher
removed := t.Batcher.RemoveNode(id, c)
if !removed {
return false
}
return true
}
// wrapBatcherForTest wraps a batcher with test-specific behavior.
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
return &testBatcherWrapper{Batcher: b, state: state}
}
// allBatcherFunctions contains all batcher implementations to test. // allBatcherFunctions contains all batcher implementations to test.
var allBatcherFunctions = []batcherTestCase{ var allBatcherFunctions = []batcherTestCase{
{"LockFree", NewBatcherAndMapper}, {"LockFree", NewBatcherAndMapper},
@ -183,8 +237,8 @@ func setupBatcherWithTestData(
"acls": [ "acls": [
{ {
"action": "accept", "action": "accept",
"users": ["*"], "src": ["*"],
"ports": ["*:*"] "dst": ["*:*"]
} }
] ]
}` }`
@ -194,8 +248,8 @@ func setupBatcherWithTestData(
t.Fatalf("Failed to set allow-all policy: %v", err) t.Fatalf("Failed to set allow-all policy: %v", err)
} }
// Create batcher with the state // Create batcher with the state and wrap it for testing
batcher := bf(cfg, state) batcher := wrapBatcherForTest(bf(cfg, state), state)
batcher.Start() batcher.Start()
testData := &TestData{ testData := &TestData{
@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start() testNode.start()
// Connect the node to the batcher // Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
time.Sleep(100 * time.Millisecond) // Let connection settle time.Sleep(100 * time.Millisecond) // Let connection settle
// Generate some work // Generate some work
@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity // Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet) batcher.AddWork(change.FullSet)
@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Disconnect all nodes // Disconnect all nodes
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false) batcher.RemoveNode(node.n.ID, node.ch)
} }
// Give time for final updates to process // Give time for final updates to process
@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) {
tn2 := testData.Nodes[1] tn2 := testData.Nodes[1]
// Test AddNode with real node ID // Test AddNode with real node ID
batcher.AddNode(tn.n.ID, tn.ch, false, 100) batcher.AddNode(tn.n.ID, tn.ch, 100)
if !batcher.IsConnected(tn.n.ID) { if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode") t.Error("Node should be connected after AddNode")
} }
@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) {
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
// Add the second node and verify update message // Add the second node and verify update message
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100) batcher.AddNode(tn2.n.ID, tn2.ch, 100)
assert.True(t, batcher.IsConnected(tn2.n.ID)) assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected. // First node should get an update that second node has connected.
select { select {
case data := <-tn.ch: case data := <-tn.ch:
assertOnlineMapResponse(t, data, true) assertOnlineMapResponse(t, data, true)
case <-time.After(200 * time.Millisecond): case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update") t.Error("Did not receive expected Online response update")
} }
@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) {
} }
// Disconnect the second node // Disconnect the second node
batcher.RemoveNode(tn2.n.ID, tn2.ch, false) batcher.RemoveNode(tn2.n.ID, tn2.ch)
assert.False(t, batcher.IsConnected(tn2.n.ID)) // Note: IsConnected may return true during grace period for DNS resolution
// First node should get update that second has disconnected. // First node should get update that second has disconnected.
select { select {
case data := <-tn.ch: case data := <-tn.ch:
assertOnlineMapResponse(t, data, false) assertOnlineMapResponse(t, data, false)
case <-time.After(200 * time.Millisecond): case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update") t.Error("Did not receive expected Online response update")
} }
@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) {
// } // }
// Test RemoveNode // Test RemoveNode
batcher.RemoveNode(tn.n.ID, tn.ch, false) batcher.RemoveNode(tn.n.ID, tn.ch)
if batcher.IsConnected(tn.n.ID) { // Note: IsConnected may return true during grace period for DNS resolution
t.Error("Node should be disconnected after RemoveNode") // The node is actually removed from active connections but grace period allows DNS lookups
}
}) })
} }
} }
@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10) ch := make(chan *tailcfg.MapResponse, 10)
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
// Track update content for validation // Track update content for validation
var receivedUpdates []*tailcfg.MapResponse var receivedUpdates []*tailcfg.MapResponse
@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
}() }()
// Add real work during connection chaos // Add real work during connection chaos
@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
time.Sleep(1 * time.Microsecond) time.Sleep(1 * time.Microsecond)
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
}() }()
// Remove second connection // Remove second connection
@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
time.Sleep(2 * time.Microsecond) time.Sleep(2 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch2, false) batcher.RemoveNode(testNode.n.ID, ch2)
}() }()
wg.Wait() wg.Wait()
@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5) ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work // Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet) batcher.AddWork(change.DERPSet)
// Consumer goroutine to validate data and detect channel issues // Consumer goroutine to validate data and detect channel issues
@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Rapid removal creates race between worker and removal // Rapid removal creates race between worker and removal
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond) time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch, false) batcher.RemoveNode(testNode.n.ID, ch)
// Give workers time to process and close channels // Give workers time to process and close channels
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for _, node := range stableNodes { for _, node := range stableNodes {
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch stableChannels[node.n.ID] = ch
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
// Monitor updates for each stable client // Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
churningChannelsMutex.Lock() churningChannelsMutex.Lock()
churningChannels[nodeID] = ch churningChannels[nodeID] = ch
churningChannelsMutex.Unlock() churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking // Consume updates to prevent blocking
go func() { go func() {
@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch, exists := churningChannels[nodeID] ch, exists := churningChannels[nodeID]
churningChannelsMutex.Unlock() churningChannelsMutex.Unlock()
if exists { if exists {
batcher.RemoveNode(nodeID, ch, false) batcher.RemoveNode(nodeID, ch)
} }
}(node.n.ID) }(node.n.ID)
} }
@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
var connectedNodesMutex sync.RWMutex var connectedNodesMutex sync.RWMutex
for i := range testNodes { for i := range testNodes {
node := &testNodes[i] node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock() connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true connectedNodes[node.n.ID] = true
connectedNodesMutex.Unlock() connectedNodesMutex.Unlock()
@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) {
connectedNodesMutex.RUnlock() connectedNodesMutex.RUnlock()
if isConnected { if isConnected {
batcher.RemoveNode(nodeID, channel, false) batcher.RemoveNode(nodeID, channel)
connectedNodesMutex.Lock() connectedNodesMutex.Lock()
connectedNodes[nodeID] = false connectedNodes[nodeID] = false
connectedNodesMutex.Unlock() connectedNodesMutex.Unlock()
@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) {
// Now disconnect all nodes from batcher to stop new updates // Now disconnect all nodes from batcher to stop new updates
for i := range testNodes { for i := range testNodes {
node := &testNodes[i] node := &testNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false) batcher.RemoveNode(node.n.ID, node.ch)
} }
// Give time for enhanced tracking goroutines to process any remaining data in channels // Give time for enhanced tracking goroutines to process any remaining data in channels
@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Connect nodes one at a time to avoid overwhelming the work queue // Connect nodes one at a time to avoid overwhelming the work queue
for i, node := range allNodes { for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID) t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Small delay between connections to allow NodeCameOnline processing // Small delay between connections to allow NodeCameOnline processing
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Check how many peers each node should see // Check how many peers each node should see
for i, node := range allNodes { for i, node := range allNodes {
peers, err := testData.State.ListPeers(node.n.ID) peers := testData.State.ListPeers(node.n.ID)
if err != nil { t.Logf("Node %d should see %d peers from state", i, peers.Len())
t.Errorf("Error listing peers for node %d: %v", i, err)
} else {
t.Logf("Node %d should see %d peers from state", i, len(peers))
}
} }
// Send a full update - this should generate full peer lists // Send a full update - this should generate full peer lists
@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
foundFullUpdate := false foundFullUpdate := false
// Read all available updates for each node // Read all available updates for each node
for i := range len(allNodes) { for i := range allNodes {
nodeUpdates := 0 nodeUpdates := 0
t.Logf("Reading updates for node %d:", i) t.Logf("Reading updates for node %d:", i)
@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
t.Logf("=== WORK QUEUE TRACING TEST ===") t.Logf("=== WORK QUEUE TRACING TEST ===")
// Connect first node time.Sleep(100 * time.Millisecond) // Let connections settle
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d", nodes[0].n.ID)
// Wait for initial NodeCameOnline to be processed // Wait for initial NodeCameOnline to be processed
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
t.Errorf("ERROR: Received unknown update type!") t.Errorf("ERROR: Received unknown update type!")
} }
// Check if there should be peers available batcher := testData.Batcher
peers, err := testData.State.ListPeers(nodes[0].n.ID) node1 := testData.Nodes[0]
if err != nil { node2 := testData.Nodes[1]
t.Errorf("Error getting peers from state: %v", err)
} else { t.Logf("=== MULTI-CONNECTION TEST ===")
t.Logf("State shows %d peers available for this node", len(peers))
if len(peers) > 0 && len(data.Peers) == 0 { // Phase 1: Connect first node with initial connection
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers)) t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
}
// Connect second node for comparison
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node2: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
if info, exists := debugInfo[node1.n.ID]; exists {
t.Logf("Node1 debug info: %+v", info)
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 3 {
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
} else {
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
}
}
if connected, ok := infoMap["connected"].(bool); ok && !connected {
t.Errorf("Node1 should show as connected with 3 active connections")
}
}
}
if info, exists := debugInfo[node2.n.ID]; exists {
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 1 {
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
}
}
}
}
}
// Phase 5: Send update and verify ALL connections receive it
t.Logf("Phase 5: Testing update distribution to all connections...")
// Clear any existing updates from all channels
clearChannel := func(ch chan *tailcfg.MapResponse) {
for {
select {
case <-ch:
// drain
default:
return
}
}
}
clearChannel(node1.ch)
clearChannel(secondChannel)
clearChannel(thirdChannel)
clearChannel(node2.ch)
// Send a change notification from node2 (so node1 should receive it on all connections)
testChangeSet := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
batcher.AddWork(testChangeSet)
time.Sleep(100 * time.Millisecond) // Let updates propagate
// Verify all three connections for node1 receive the update
connection1Received := false
connection2Received := false
connection3Received := false
select {
case mapResp := <-node1.ch:
connection1Received = (mapResp != nil)
t.Logf("Node1 connection 1 received update: %t", connection1Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 1 did not receive update")
}
select {
case mapResp := <-secondChannel:
connection2Received = (mapResp != nil)
t.Logf("Node1 connection 2 received update: %t", connection2Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 2 did not receive update")
}
select {
case mapResp := <-thirdChannel:
connection3Received = (mapResp != nil)
t.Logf("Node1 connection 3 received update: %t", connection3Received)
case <-time.After(500 * time.Millisecond):
t.Errorf("Node1 connection 3 did not receive update")
}
if connection1Received && connection2Received && connection3Received {
t.Logf("SUCCESS: All three connections for node1 received the update")
} else {
t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t",
connection1Received, connection2Received, connection3Received)
}
// Phase 6: Test connection removal and verify remaining connections still work
t.Logf("Phase 6: Testing connection removal...")
// Remove the second connection
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
if !removed {
t.Errorf("Failed to remove second connection for node1")
}
time.Sleep(50 * time.Millisecond)
// Verify debug status shows 2 connections now
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
debugInfo := debugBatcher.Debug()
if info, exists := debugInfo[node1.n.ID]; exists {
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 2 {
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
} else {
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
}
} }
} }
} else { } else {

View File

@ -1,6 +1,7 @@
package mapper package mapper
import ( import (
"errors"
"net/netip" "net/netip"
"sort" "sort"
"time" "time"
@ -12,7 +13,7 @@ import (
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
) )
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse // MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
type MapResponseBuilder struct { type MapResponseBuilder struct {
resp *tailcfg.MapResponse resp *tailcfg.MapResponse
mapper *mapper mapper *mapper
@ -21,7 +22,17 @@ type MapResponseBuilder struct {
errs []error errs []error
} }
// NewMapResponseBuilder creates a new builder with basic fields set type debugType string
const (
fullResponseDebug debugType = "full"
patchResponseDebug debugType = "patch"
removeResponseDebug debugType = "remove"
changeResponseDebug debugType = "change"
derpResponseDebug debugType = "derp"
)
// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now() now := time.Now()
return &MapResponseBuilder{ return &MapResponseBuilder{
@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
} }
} }
// addError adds an error to the builder's error list // addError adds an error to the builder's error list.
func (b *MapResponseBuilder) addError(err error) { func (b *MapResponseBuilder) addError(err error) {
if err != nil { if err != nil {
b.errs = append(b.errs, err) b.errs = append(b.errs, err)
} }
} }
// hasErrors returns true if the builder has accumulated any errors // hasErrors returns true if the builder has accumulated any errors.
func (b *MapResponseBuilder) hasErrors() bool { func (b *MapResponseBuilder) hasErrors() bool {
return len(b.errs) > 0 return len(b.errs) > 0
} }
// WithCapabilityVersion sets the capability version for the response // WithCapabilityVersion sets the capability version for the response.
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
b.capVer = capVer b.capVer = capVer
return b return b
} }
// WithSelfNode adds the requesting node to the response // WithSelfNode adds the requesting node to the response.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID) nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
b.addError(err) b.addError(errors.New("node not found"))
return b return b
} }
// Always use batcher's view of online status for self node
// The batcher respects grace periods for logout scenarios
node := nodeView.AsStruct()
// if b.mapper.batcher != nil {
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
// }
_, matchers := b.mapper.state.Filter() _, matchers := b.mapper.state.Filter()
tailnode, err := tailNode( tailnode, err := tailNode(
node.View(), b.capVer, b.mapper.state, node.View(), b.capVer, b.mapper.state,
@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
} }
b.resp.Node = tailnode b.resp.Node = tailnode
return b return b
} }
// WithDERPMap adds the DERP map to the response func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder {
if debugDumpMapResponsePath != "" {
b.debugType = t
}
return b
}
// WithDERPMap adds the DERP map to the response.
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct() b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct()
return b return b
} }
// WithDomain adds the domain configuration // WithDomain adds the domain configuration.
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
b.resp.Domain = b.mapper.cfg.Domain() b.resp.Domain = b.mapper.cfg.Domain()
return b return b
} }
// WithCollectServicesDisabled sets the collect services flag to false // WithCollectServicesDisabled sets the collect services flag to false.
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
b.resp.CollectServices.Set(false) b.resp.CollectServices.Set(false)
return b return b
} }
// WithDebugConfig adds debug configuration // WithDebugConfig adds debug configuration
// It disables log tailing if the mapper's LogTail is not enabled // It disables log tailing if the mapper's LogTail is not enabled.
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{ b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled, DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
return b return b
} }
// WithSSHPolicy adds SSH policy configuration for the requesting node // WithSSHPolicy adds SSH policy configuration for the requesting node.
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID) node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
b.addError(err) b.addError(errors.New("node not found"))
return b return b
} }
sshPolicy, err := b.mapper.state.SSHPolicy(node.View()) sshPolicy, err := b.mapper.state.SSHPolicy(node)
if err != nil { if err != nil {
b.addError(err) b.addError(err)
return b return b
} }
b.resp.SSHPolicy = sshPolicy b.resp.SSHPolicy = sshPolicy
return b return b
} }
// WithDNSConfig adds DNS configuration for the requesting node // WithDNSConfig adds DNS configuration for the requesting node.
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID) node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
b.addError(err) b.addError(errors.New("node not found"))
return b return b
} }
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
return b return b
} }
// WithUserProfiles adds user profiles for the requesting node and given peers // WithUserProfiles adds user profiles for the requesting node and given peers.
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder { func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID) node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
b.addError(err) b.addError(errors.New("node not found"))
return b return b
} }
b.resp.UserProfiles = generateUserProfiles(node, peers) b.resp.UserProfiles = generateUserProfiles(node, peers)
return b return b
} }
// WithPacketFilters adds packet filter rules based on policy // WithPacketFilters adds packet filter rules based on policy.
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID) node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
b.addError(err) b.addError(errors.New("node not found"))
return b return b
} }
@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
// new PacketFilters field and "base" allows us to send a full update when we // new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block. // have to send an empty list, avoiding the hack in the else block.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node.View(), filter), "base": policy.ReduceFilterRules(node, filter),
} }
return b return b
} }
// WithPeers adds full peer list with policy filtering (for full map response) // WithPeers adds full peer list with policy filtering (for full map response).
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers) tailPeers, err := b.buildTailPeers(peers)
if err != nil { if err != nil {
b.addError(err) b.addError(err)
@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
} }
b.resp.Peers = tailPeers b.resp.Peers = tailPeers
return b return b
} }
// WithPeerChanges adds changed peers with policy filtering (for incremental updates) // WithPeerChanges adds changed peers with policy filtering (for incremental updates).
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder { func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers) tailPeers, err := b.buildTailPeers(peers)
if err != nil { if err != nil {
b.addError(err) b.addError(err)
@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
} }
b.resp.PeersChanged = tailPeers b.resp.PeersChanged = tailPeers
return b return b
} }
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting // buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) { func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
node, err := b.mapper.state.GetNodeByID(b.nodeID) node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil { if !ok {
return nil, err return nil, errors.New("node not found")
} }
filter, matchers := b.mapper.state.Filter() filter, matchers := b.mapper.state.Filter()
@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
// access each-other at all and remove them from the peers. // access each-other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView] var changedViews views.Slice[types.NodeView]
if len(filter) > 0 { if len(filter) > 0 {
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers) changedViews = policy.ReduceNodes(node, peers, matchers)
} else { } else {
changedViews = peers.ViewSlice() changedViews = peers
} }
tailPeers, err := tailNodes( tailPeers, err := tailNodes(
changedViews, b.capVer, b.mapper.state, changedViews, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix { func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
}, },
b.mapper.cfg) b.mapper.cfg)
if err != nil { if err != nil {
@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
return tailPeers, nil return tailPeers, nil
} }
// WithPeerChangedPatch adds peer change patches // WithPeerChangedPatch adds peer change patches.
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
b.resp.PeersChangedPatch = changes b.resp.PeersChangedPatch = changes
return b return b
} }
// WithPeersRemoved adds removed peer IDs // WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID var tailscaleIDs []tailcfg.NodeID
for _, id := range removedIDs { for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID()) tailscaleIDs = append(tailscaleIDs, id.NodeID())
} }
b.resp.PeersRemoved = tailscaleIDs b.resp.PeersRemoved = tailscaleIDs
return b return b
} }
@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
return nil, multierr.New(b.errs...) return nil, multierr.New(b.errs...)
} }
if debugDumpMapResponsePath != "" { if debugDumpMapResponsePath != "" {
node, err := b.mapper.state.GetNodeByID(b.nodeID) writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
if err != nil {
return nil, err
}
writeDebugMapResponse(b.resp, node)
} }
return b.resp, nil return b.resp, nil

View File

@ -19,6 +19,7 @@ import (
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
"tailscale.com/types/views"
) )
const ( const (
@ -69,16 +70,18 @@ func newMapper(
} }
func generateUserProfiles( func generateUserProfiles(
node *types.Node, node types.NodeView,
peers types.Nodes, peers views.Slice[types.NodeView],
) []tailcfg.UserProfile { ) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User) userMap := make(map[uint]*types.User)
ids := make([]uint, 0, len(userMap)) ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User user := node.User()
ids = append(ids, node.User.ID) userMap[user.ID] = &user
for _, peer := range peers { ids = append(ids, user.ID)
userMap[peer.User.ID] = &peer.User for _, peer := range peers.All() {
ids = append(ids, peer.User.ID) peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
} }
slices.Sort(ids) slices.Sort(ids)
@ -95,7 +98,7 @@ func generateUserProfiles(
func generateDNSConfig( func generateDNSConfig(
cfg *types.Config, cfg *types.Config,
node *types.Node, node types.NodeView,
) *tailcfg.DNSConfig { ) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil { if cfg.TailcfgDNSConfig == nil {
return nil return nil
@ -115,12 +118,12 @@ func generateDNSConfig(
// //
// This will produce a resolver like: // This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
for _, resolver := range resolvers { for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{ attrs := url.Values{
"device_name": []string{node.Hostname}, "device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo.OS}, "device_model": []string{node.Hostinfo().OS()},
} }
if len(node.IPs()) > 0 { if len(node.IPs()) > 0 {
@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse(
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
messages ...string, messages ...string,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID) peers := m.state.ListPeers(nodeID)
if err != nil {
return nil, err
}
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer). WithCapabilityVersion(capVer).
@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse(
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID, changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID, changedNodeID) peers := m.state.ListPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer). WithCapabilityVersion(capVer).
@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse(
func writeDebugMapResponse( func writeDebugMapResponse(
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
node *types.Node, t debugType,
nodeID types.NodeID,
) { ) {
body, err := json.MarshalIndent(resp, "", " ") body, err := json.MarshalIndent(resp, "", " ")
if err != nil { if err != nil {
@ -236,25 +234,6 @@ func writeDebugMapResponse(
} }
} }
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
// TODO(kradalby): Add back online via batcher. This was removed
// to avoid a circular dependency between the mapper and the notification.
for _, peer := range peers {
online := m.batcher.IsConnected(peer.ID)
peer.IsOnline = &online
}
return peers, nil
}
// routeFilterFunc is a function that takes a node ID and returns a list of // routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes // netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node. // from the primary route manager to the node.

View File

@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{ &types.Config{
TailcfgDNSConfig: &dnsConfigOrig, TailcfgDNSConfig: &dnsConfigOrig,
}, },
nodeInShared1, nodeInShared1.View(),
) )
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {

View File

@ -133,13 +133,12 @@ func tailNode(
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
} }
if !node.IsOnline().Valid() || !node.IsOnline().Get() { // Set LastSeen only for offline nodes to avoid confusing Tailscale clients
// LastSeen is only set when node is // during rapid reconnection cycles. Online nodes should not have LastSeen set
// not connected to the control server. // as this can make clients interpret them as "not online" despite Online=true.
if node.LastSeen().Valid() { if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() {
lastSeen := node.LastSeen().Get() lastSeen := node.LastSeen().Get()
tNode.LastSeen = &lastSeen tNode.LastSeen = &lastSeen
}
} }
return &tNode, nil return &tNode, nil

View File

@ -13,7 +13,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"gorm.io/gorm"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -296,16 +295,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey // getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session. // and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil { if !ok {
if errors.Is(err, gorm.ErrRecordNotFound) { return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
} }
nv := node.View()
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
if ns.machineKey != nv.MachineKey() { if ns.machineKey != nv.MachineKey() {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)

View File

@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
util.LogErr(err, "could not get userinfo; only using claims from id token") util.LogErr(err, "could not get userinfo; only using claims from id token")
} }
// The user claims are now updated from the the userinfo endpoint so we can verify the user a // The user claims are now updated from the userinfo endpoint so we can verify the user
// against allowed emails, email domains, and groups. // against allowed emails, email domains, and groups.
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err) httpError(writer, err)
@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return return
} }
user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims) user, c, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
} }
// Send policy update notifications if needed // Send policy update notifications if needed
if policyChanged { a.h.Change(c)
a.h.Change(change.PolicyChange())
}
// TODO(kradalby): Is this comment right? // TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated, // If the node exists, then the node should be reauthenticated,
@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims, claims *types.OIDCClaims,
) (*types.User, bool, error) { ) (*types.User, change.ChangeSet, error) {
var user *types.User var user *types.User
var err error var err error
var newUser bool var newUser bool
var policyChanged bool var c change.ChangeSet
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) { if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, false, fmt.Errorf("creating or updating user: %w", err) return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
} }
// if the user is still not found, create a new empty user. // if the user is still not found, create a new empty user.
@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user.FromClaim(claims) user.FromClaim(claims)
if newUser { if newUser {
user, policyChanged, err = a.h.state.CreateUser(*user) user, c, err = a.h.state.CreateUser(*user)
if err != nil { if err != nil {
return nil, false, fmt.Errorf("creating user: %w", err) return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
} }
} else { } else {
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { _, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
*u = *user *u = *user
return nil return nil
}) })
if err != nil { if err != nil {
return nil, false, fmt.Errorf("updating user: %w", err) return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
} }
} }
return user, policyChanged, nil return user, c, nil
} }
func (a *AuthProviderOIDC) handleRegistration( func (a *AuthProviderOIDC) handleRegistration(

View File

@ -7,6 +7,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -138,39 +139,74 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
return ret return ret
} }
// AutoApproveRoutes approves any route that can be autoapproved from // ApproveRoutesWithPolicy checks if the node can approve the announced routes
// the nodes perspective according to the given policy. // and returns the new list of approved routes.
// It reports true if any routes were approved. // The approved routes will include:
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes. // 1. ALL previously approved routes (regardless of whether they're still advertised)
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { // 2. New routes from announcedRoutes that can be auto-approved by policy
// This ensures that:
// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
// - New routes can be auto-approved according to policy
// - Routes can only be removed by explicit admin action (not by auto-approval).
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
if pm == nil { if pm == nil {
return false return currentApproved, false
} }
nodeView := node.View()
var newApproved []netip.Prefix // Start with ALL currently approved routes - we never remove approved routes
for _, route := range nodeView.AnnouncedRoutes() { newApproved := make([]netip.Prefix, len(currentApproved))
if pm.NodeCanApproveRoute(nodeView, route) { copy(newApproved, currentApproved)
// Then, check for new routes that can be auto-approved
for _, route := range announcedRoutes {
// Skip if already approved
if slices.Contains(newApproved, route) {
continue
}
// Check if this new route can be auto-approved by policy
canApprove := pm.NodeCanApproveRoute(nv, route)
if canApprove {
newApproved = append(newApproved, route) newApproved = append(newApproved, route)
} }
} }
// Only modify ApprovedRoutes if we have new routes to approve. // Sort and deduplicate
// This prevents clearing existing approved routes when nodes tsaddr.SortPrefixes(newApproved)
// temporarily don't have announced routes during policy changes. newApproved = slices.Compact(newApproved)
if len(newApproved) > 0 { newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
combined := append(newApproved, node.ApprovedRoutes...) return route.IsValid()
tsaddr.SortPrefixes(combined) })
combined = slices.Compact(combined)
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
// Only update if the routes actually changed // Sort the current approved for comparison
if !slices.Equal(node.ApprovedRoutes, combined) { sortedCurrent := make([]netip.Prefix, len(currentApproved))
node.ApprovedRoutes = combined copy(sortedCurrent, currentApproved)
return true tsaddr.SortPrefixes(sortedCurrent)
// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
// Log what changed
var added, kept []netip.Prefix
for _, route := range newApproved {
if !slices.Contains(sortedCurrent, route) {
added = append(added, route)
} else {
kept = append(kept, route)
}
} }
if len(added) > 0 {
log.Debug().
Uint64("node.id", nv.ID().Uint64()).
Str("node.name", nv.Hostname()).
Strs("routes.added", util.PrefixesToString(added)).
Strs("routes.kept", util.PrefixesToString(kept)).
Int("routes.total", len(newApproved)).
Msg("Routes auto-approved by policy")
}
return newApproved, true
} }
return false return newApproved, false
} }

View File

@ -0,0 +1,339 @@
package policy
import (
"fmt"
"net/netip"
"testing"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
user1 := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser@",
}
user2 := types.User{
Model: gorm.Model{ID: 2},
Name: "otheruser@",
}
users := []types.User{user1, user2}
node1 := &types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: user1.ID,
User: user1,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ForcedTags: []string{"tag:test"},
}
node2 := &types.Node{
ID: 2,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: user2.ID,
User: user2,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
// Create a policy that auto-approves specific routes
policyJSON := `{
"groups": {
"group:test": ["testuser@"]
},
"tagOwners": {
"tag:test": ["testuser@"]
},
"acls": [
{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}
],
"autoApprovers": {
"routes": {
"10.0.0.0/8": ["testuser@", "tag:test"],
"10.1.0.0/24": ["testuser@"],
"10.2.0.0/24": ["testuser@"],
"192.168.0.0/24": ["tag:test"]
}
}
}`
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
assert.NoError(t, err)
tests := []struct {
name string
node *types.Node
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
description string
}{
{
name: "previously_approved_route_no_longer_advertised_should_remain",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
},
wantChanged: false,
description: "Previously approved routes should never be removed even when no longer advertised",
},
{
name: "add_new_auto_approved_route_keeps_old_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
},
wantChanged: true,
description: "New auto-approved routes should be added while keeping old approved routes",
},
{
name: "no_announced_routes_keeps_all_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
},
announcedRoutes: []netip.Prefix{}, // No routes announced
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
description: "All approved routes should remain when no routes are announced",
},
{
name: "no_changes_when_announced_equals_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
description: "No changes should occur when announced routes match approved routes",
},
{
name: "auto_approve_multiple_new_routes",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8)
netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved
netip.MustParsePrefix("172.16.0.0/24"), // Original kept
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
},
wantChanged: true,
description: "Multiple new routes should be auto-approved while keeping existing approved routes",
},
{
name: "node_without_permission_no_auto_approval",
node: node2, // Different node without the tag
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
},
wantChanged: false,
description: "Routes should not be auto-approved for nodes without proper permissions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
tsaddr.SortPrefixes(tt.wantApproved)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
// Verify that all previously approved routes are still present
for _, prevRoute := range tt.currentApproved {
assert.Contains(t, gotApproved, prevRoute,
"previously approved route %s was removed - this should never happen", prevRoute)
}
})
}
}
func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
// Create a basic policy for edge case testing
aclPolicy := `
{
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.1.0.0/24": ["test@"],
},
},
}`
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
}{
{
name: "nil_policy_manager",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
},
{
name: "nil_current_approved",
currentApproved: nil,
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantChanged: true,
},
{
name: "nil_announced_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: nil,
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
},
{
name: "duplicate_approved_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.1.0.0/24"),
},
wantChanged: true,
},
{
name: "empty_slices",
currentApproved: []netip.Prefix{},
announcedRoutes: []netip.Prefix{},
wantApproved: []netip.Prefix{},
wantChanged: false,
},
}
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
users := []types.User{user}
// Create test node
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
// Create policy manager or use nil if specified
var pm PolicyManager
var err error
if tt.name != "nil_policy_manager" {
pm, err = pmf(users, nodes.ViewSlice())
assert.NoError(t, err)
} else {
pm = nil
}
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")
// Handle nil vs empty slice comparison
if tt.wantApproved == nil {
assert.Nil(t, gotApproved, "expected nil approved routes")
} else {
tsaddr.SortPrefixes(tt.wantApproved)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
}
})
}
}
}

View File

@ -0,0 +1,361 @@
package policy
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
// Test policy that allows specific routes to be auto-approved
aclPolicy := `
{
"groups": {
"group:admins": ["test@"],
},
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.0.0.0/24": ["test@"],
"192.168.0.0/24": ["group:admins"],
"172.16.0.0/16": ["tag:approved"],
},
},
"tagOwners": {
"tag:approved": ["test@"],
},
}`
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
nodeHostname string
nodeUser string
nodeTags []string
wantApproved []netip.Prefix
wantChanged bool
wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
}{
{
name: "previously_approved_route_no_longer_advertised_remains",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
},
{
name: "add_new_auto_approved_route_keeps_existing",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Still advertised
netip.MustParsePrefix("192.168.0.0/24"), // New route
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
},
wantChanged: true,
},
{
name: "no_announced_routes_keeps_all_approved",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
},
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
},
{
name: "manually_approved_route_not_in_policy_remains",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved
netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
},
wantChanged: true,
},
{
name: "tagged_node_gets_tag_approved_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
},
nodeUser: "test",
nodeTags: []string{"tag:approved"},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
},
wantChanged: true,
},
{
name: "complex_scenario_multiple_changes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised
netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable
netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag)
netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
},
wantChanged: true,
},
}
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: tt.nodeUser,
}
users := []types.User{user}
// Create test node
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
ForcedTags: tt.nodeTags,
}
nodes := types.Nodes{&node}
// Create policy manager
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, pm)
// Test ApproveRoutesWithPolicy
gotApproved, gotChanged := ApproveRoutesWithPolicy(
pm,
node.View(),
tt.currentApproved,
tt.announcedRoutes,
)
// Check change flag
assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")
// Check approved routes match expected
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
t.Logf("Want: %v", tt.wantApproved)
t.Logf("Got: %v", gotApproved)
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
}
// Verify all previously approved routes are still present
for _, prevRoute := range tt.currentApproved {
assert.Contains(t, gotApproved, prevRoute,
"previously approved route %s was removed - this should NEVER happen", prevRoute)
}
// Verify no routes were incorrectly removed
for _, removedRoute := range tt.wantRemovedRoutes {
assert.NotContains(t, gotApproved, removedRoute,
"route %s should have been removed but wasn't", removedRoute)
}
})
}
}
}
func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
aclPolicy := `
{
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.0.0.0/8": ["test@"],
},
},
}`
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
}{
{
name: "nil_current_approved",
currentApproved: nil,
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true,
},
{
name: "empty_current_approved",
currentApproved: []netip.Prefix{},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true,
},
{
name: "duplicate_routes_handled",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true, // Duplicates are removed, so it's a change
},
}
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
users := []types.User{user}
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
gotApproved, gotChanged := ApproveRoutesWithPolicy(
pm,
node.View(),
tt.currentApproved,
tt.announcedRoutes,
)
assert.Equal(t, tt.wantChanged, gotChanged)
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
}
})
}
}
}
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
currentApproved := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
}
announcedRoutes := []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
}
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: currentApproved,
}
// With nil policy manager, should return current approved unchanged
gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)
assert.False(t, gotChanged)
assert.Equal(t, currentApproved, gotApproved)
}

View File

@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
canApprove: false, canApprove: false,
}, },
{
name: "policy-without-autoApprovers-section",
node: normalNode,
route: p("10.33.0.0/16"),
policy: `{
"groups": {
"group:admin": ["user1@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admin"],
"dst": ["group:admin:*"]
},
{
"action": "accept",
"src": ["group:admin"],
"dst": ["10.33.0.0/16:*"]
}
]
}`,
canApprove: false,
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// The fast path is that a node requests to approve a prefix // The fast path is that a node requests to approve a prefix
// where there is an exact entry, e.g. 10.0.0.0/8, then // where there is an exact entry, e.g. 10.0.0.0/8, then
// check and return quickly // check and return quickly
if _, ok := pm.autoApproveMap[route]; ok { if approvers, ok := pm.autoApproveMap[route]; ok {
if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) { canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
if canApprove {
return true return true
} }
} }
@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// Check if prefix is larger (so containing) and then overlaps // Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover // the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) { canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
if canApprove {
return true return true
} }
} }

View File

@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock" "github.com/sasha-s/go-deadlock"
@ -112,6 +111,15 @@ func (m *mapSession) serve() {
// This is the mechanism where the node gives us information about its // This is the mechanism where the node gives us information about its
// current configuration. // current configuration.
// //
// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
httpError(m.w, err)
return
}
m.h.Change(c)
// If OmitPeers is true and Stream is false // If OmitPeers is true and Stream is false
// then the server will let clients update their endpoints without // then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections. // breaking existing long-polling (Stream == true) connections.
@ -122,14 +130,6 @@ func (m *mapSession) serve() {
// the response and just wants a 200. // the response and just wants a 200.
// !req.stream && req.OmitPeers // !req.stream && req.OmitPeers
if m.isEndpointUpdate() { if m.isEndpointUpdate() {
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
if err != nil {
httpError(m.w, err)
return
}
m.h.Change(c)
m.w.WriteHeader(http.StatusOK) m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc() mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
} }
@ -142,6 +142,8 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() { func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll() m.beforeServeLongPoll()
log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected")
// Clean up the session when the client disconnects // Clean up the session when the client disconnects
defer func() { defer func() {
m.cancelChMu.Lock() m.cancelChMu.Lock()
@ -149,18 +151,38 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh) close(m.cancelCh)
m.cancelChMu.Unlock() m.cancelChMu.Unlock()
// TODO(kradalby): This can likely be made more effective, but likely most _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
// nodes has access to the same routes, so it might not be a big deal.
disconnectChange, err := m.h.state.Disconnect(m.node) // When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather).
if err != nil { // Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects.
m.errf(err, "Failed to disconnect node %s", m.node.Hostname) // If it does reconnect, the existing mapSession will be replaced and the node remains online.
// If it doesn't reconnect within the timeout, we mark it as offline.
//
// This avoids flapping nodes in the UI and unnecessary churn in the network.
// This is not my favourite solution, but it kind of works in our eventually consistent world.
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
disconnected := true
// Wait up to 10 seconds for the node to reconnect.
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
for range 10 {
if m.h.mapBatcher.IsConnected(m.node.ID) {
disconnected = false
break
}
<-ticker.C
} }
m.h.Change(disconnectChange)
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter()) if disconnected {
disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}
m.afterServeLongPoll() m.h.Change(disconnectChanges...)
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}
}() }()
// Set up the client stream // Set up the client stream
@ -172,25 +194,25 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive) m.keepAliveTicker = time.NewTicker(m.keepAlive)
// Add node to batcher BEFORE sending Connect change to prevent race condition // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
// where the change is sent before the node is in the batcher's node map // CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil { // synchronized. When nodes reconnect, they send their hostinfo with announced routes
m.errf(err, "failed to add node to batcher") // in the MapRequest. We need this data in NodeStore before Connect() sets up the
// Send empty response to client to fail fast for invalid/non-existent nodes // primary routes, otherwise SubnetRoutes() returns empty and the node is removed
select { // from AvailableRoutes.
case m.ch <- &tailcfg.MapResponse{}: mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
default: if err != nil {
// Channel might be closed m.errf(err, "failed to update node from initial MapRequest")
}
return return
} }
// Now send the Connect change - the batcher handles NodeCameOnline internally // Connect the node after its state has been updated.
// but we still need to update routes and other state-level changes // We send two separate change notifications because these are distinct operations:
connectChange := m.h.state.Connect(m.node) // 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline { // 2. Connect: marks the node online and recalculates primary routes based on the updated state
m.h.Change(connectChange) // While this results in two notifications, it ensures route data is synchronized before
} // primary route selection occurs, which is critical for proper HA subnet router failover.
connectChanges := m.h.state.Connect(m.node.ID)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -235,6 +257,7 @@ func (m *mapSession) serveLongPoll() {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
} }
mapResponseSent.WithLabelValues("ok", "keepalive").Inc() mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
m.resetKeepAlive()
} }
} }
} }

View File

@ -0,0 +1,403 @@
package state
import (
"fmt"
"maps"
"strings"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
const (
batchSize = 10
batchTimeout = 500 * time.Millisecond
)
const (
put = 1
del = 2
update = 3
)
const prometheusNamespace = "headscale"
var (
nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operations_total",
Help: "Total number of NodeStore operations",
}, []string{"operation"})
nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operation_duration_seconds",
Help: "Duration of NodeStore operations",
Buckets: prometheus.DefBuckets,
}, []string{"operation"})
nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_size",
Help: "Size of NodeStore write batches",
Buckets: []float64{1, 2, 5, 10, 20, 50, 100},
})
nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_duration_seconds",
Help: "Duration of NodeStore batch processing",
Buckets: prometheus.DefBuckets,
})
nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_snapshot_build_duration_seconds",
Help: "Duration of NodeStore snapshot building from nodes",
Buckets: prometheus.DefBuckets,
})
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_nodes_total",
Help: "Total number of nodes in the NodeStore",
})
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_peers_calculation_duration_seconds",
Help: "Duration of peers calculation in NodeStore",
Buckets: prometheus.DefBuckets,
})
nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_queue_depth",
Help: "Current depth of NodeStore write queue",
})
)
// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
// and while batches are not fast, they are grouped together
// to do less of the expensive peer calculation if there are many
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked. This means that the caller of a write operation
// is responsible for ensuring an update depending on a write
// is not issued before the write is complete.
type NodeStore struct {
data atomic.Pointer[Snapshot]
peersFunc PeersFunc
writeQueue chan work
}
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
nodes := make(map[types.NodeID]types.Node, len(allNodes))
for _, n := range allNodes {
nodes[n.ID] = *n
}
snap := snapshotFromNodes(nodes, peersFunc)
store := &NodeStore{
peersFunc: peersFunc,
}
store.data.Store(&snap)
// Initialize node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))
return store
}
// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
// nodesByID is the main source of truth for nodes.
nodesByID map[types.NodeID]types.Node
// calculated from nodesByID
nodesByNodeKey map[key.NodePublic]types.NodeView
peersByNode map[types.NodeID][]types.NodeView
nodesByUser map[types.UserID][]types.NodeView
allNodes []types.NodeView
}
// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
// work represents a single operation to be performed on the NodeStore.
type work struct {
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
}
// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
defer timer.ObserveDuration()
work := work{
op: put,
nodeID: n.ID,
node: n,
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
nodeStoreOperations.WithLabelValues("put").Inc()
}
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
type UpdateNodeFunc func(n *types.Node)
// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
// This is analogous to a database "transaction", or, the caller should
// rather collect all data they want to change, and then call this function.
// Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node. Which mean we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
defer timer.ObserveDuration()
work := work{
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
nodeStoreOperations.WithLabelValues("update").Inc()
}
// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
defer timer.ObserveDuration()
work := work{
op: del,
nodeID: id,
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
nodeStoreOperations.WithLabelValues("delete").Inc()
}
// Start initializes the NodeStore and starts processing the write queue.
func (s *NodeStore) Start() {
s.writeQueue = make(chan work)
go s.processWrite()
}
// Stop stops the NodeStore.
func (s *NodeStore) Stop() {
close(s.writeQueue)
}
// processWrite processes the write queue in batches.
func (s *NodeStore) processWrite() {
c := time.NewTicker(batchTimeout)
defer c.Stop()
batch := make([]work, 0, batchSize)
for {
select {
case w, ok := <-s.writeQueue:
if !ok {
// Channel closed, apply any remaining batch and exit
if len(batch) != 0 {
s.applyBatch(batch)
}
return
}
batch = append(batch, w)
if len(batch) >= batchSize {
s.applyBatch(batch)
batch = batch[:0]
c.Reset(batchTimeout)
}
case <-c.C:
if len(batch) != 0 {
s.applyBatch(batch)
batch = batch[:0]
}
c.Reset(batchTimeout)
}
}
}
// applyBatch applies a batch of work to the node store.
// This means that it takes a copy of the current nodes,
// then applies the batch of operations to that copy,
// runs any precomputation needed (like calculating peers),
// and finally replaces the snapshot in the store with the new one.
// The replacement of the snapshot is atomic, ensuring that reads
// are never blocked by writes.
// Each write item is blocked until the batch is applied to ensure
// the caller knows the operation is complete and do not send any
// updates that are dependent on a read that is yet to be written.
func (s *NodeStore) applyBatch(batch []work) {
timer := prometheus.NewTimer(nodeStoreBatchDuration)
defer timer.ObserveDuration()
nodeStoreBatchSize.Observe(float64(len(batch)))
nodes := make(map[types.NodeID]types.Node)
maps.Copy(nodes, s.data.Load().nodesByID)
for _, w := range batch {
switch w.op {
case put:
nodes[w.nodeID] = w.node
case update:
// Update the specific node identified by nodeID
if n, exists := nodes[w.nodeID]; exists {
w.updateFn(&n)
nodes[w.nodeID] = n
}
case del:
delete(nodes, w.nodeID)
}
}
newSnap := snapshotFromNodes(nodes, s.peersFunc)
s.data.Store(&newSnap)
// Update node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))
for _, w := range batch {
close(w.result)
}
}
// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds a lot of "indexes" to make lookups fast for datasets we
// that is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
// This is not a fast operation, it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
defer timer.ObserveDuration()
allNodes := make([]types.NodeView, 0, len(nodes))
for _, n := range nodes {
allNodes = append(allNodes, n.View())
}
newSnap := Snapshot{
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
// peersByNode is most likely the most expensive operation,
// it will use the list of all nodes, combined with the
// current policy to precalculate which nodes are peers and
// can see each other.
peersByNode: func() map[types.NodeID][]types.NodeView {
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
defer peersTimer.ObserveDuration()
return peersFunc(allNodes)
}(),
nodesByUser: make(map[types.UserID][]types.NodeView),
}
// Build nodesByUser and nodesByNodeKey maps
for _, n := range nodes {
nodeView := n.View()
newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
}
return newSnap
}
// GetNode retrieves a node by its ID.
// The bool indicates if the node exists or is available (like "err not found").
// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
// it isn't an invalid node (this is more of a node error or node is broken).
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("get").Inc()
n, exists := s.data.Load().nodesByID[id]
if !exists {
return types.NodeView{}, false
}
return n.View(), true
}
// GetNodeByNodeKey retrieves a node by its NodeKey.
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
return s.data.Load().nodesByNodeKey[nodeKey]
}
// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list").Inc()
return views.SliceOf(s.data.Load().allNodes)
}
// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list_peers").Inc()
return views.SliceOf(s.data.Load().peersByNode[id])
}
// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list_by_user").Inc()
return views.SliceOf(s.data.Load().nodesByUser[uid])
}

View File

@ -0,0 +1,501 @@
package state
import (
"net/netip"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
)
func TestSnapshotFromNodes(t *testing.T) {
tests := []struct {
name string
setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
}{
{
name: "empty nodes",
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
nodes := make(map[types.NodeID]types.Node)
peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
return make(map[types.NodeID][]types.NodeView)
}
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
assert.Empty(t, snapshot.nodesByID)
assert.Empty(t, snapshot.allNodes)
assert.Empty(t, snapshot.peersByNode)
assert.Empty(t, snapshot.nodesByUser)
},
},
{
name: "single node",
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
assert.Len(t, snapshot.nodesByID, 1)
assert.Len(t, snapshot.allNodes, 1)
assert.Len(t, snapshot.peersByNode, 1)
assert.Len(t, snapshot.nodesByUser, 1)
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
assert.Len(t, snapshot.nodesByUser[1], 1)
assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
},
},
{
name: "multiple nodes same user",
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
2: createTestNode(2, 1, "user1", "node2"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
assert.Len(t, snapshot.nodesByID, 2)
assert.Len(t, snapshot.allNodes, 2)
assert.Len(t, snapshot.peersByNode, 2)
assert.Len(t, snapshot.nodesByUser, 1)
// Each node sees the other as peer (but not itself)
assert.Len(t, snapshot.peersByNode[1], 1)
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
assert.Len(t, snapshot.peersByNode[2], 1)
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
assert.Len(t, snapshot.nodesByUser[1], 2)
},
},
{
name: "multiple nodes different users",
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
2: createTestNode(2, 2, "user2", "node2"),
3: createTestNode(3, 1, "user1", "node3"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
assert.Len(t, snapshot.nodesByID, 3)
assert.Len(t, snapshot.allNodes, 3)
assert.Len(t, snapshot.peersByNode, 3)
assert.Len(t, snapshot.nodesByUser, 2)
// Each node should have 2 peers (all others, but not itself)
assert.Len(t, snapshot.peersByNode[1], 2)
assert.Len(t, snapshot.peersByNode[2], 2)
assert.Len(t, snapshot.peersByNode[3], 2)
// User groupings
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
},
},
{
name: "odd-even peers filtering",
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
2: createTestNode(2, 2, "user2", "node2"),
3: createTestNode(3, 3, "user3", "node3"),
4: createTestNode(4, 4, "user4", "node4"),
}
peersFunc := oddEvenPeersFunc
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
assert.Len(t, snapshot.nodesByID, 4)
assert.Len(t, snapshot.allNodes, 4)
assert.Len(t, snapshot.peersByNode, 4)
assert.Len(t, snapshot.nodesByUser, 4)
// Odd nodes should only see other odd nodes as peers
require.Len(t, snapshot.peersByNode[1], 1)
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
require.Len(t, snapshot.peersByNode[3], 1)
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
// Even nodes should only see other even nodes as peers
require.Len(t, snapshot.peersByNode[2], 1)
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
require.Len(t, snapshot.peersByNode[4], 1)
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nodes, peersFunc := tt.setupFunc()
snapshot := snapshotFromNodes(nodes, peersFunc)
tt.validate(t, nodes, snapshot)
})
}
}
// Helper functions
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
now := time.Now()
machineKey := key.NewMachine()
nodeKey := key.NewNode()
discoKey := key.NewDisco()
ipv4 := netip.MustParseAddr("100.64.0.1")
ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
return types.Node{
ID: nodeID,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: hostname,
GivenName: hostname,
UserID: userID,
User: types.User{
Name: username,
DisplayName: username,
},
RegisterMethod: "test",
IPv4: &ipv4,
IPv6: &ipv6,
CreatedAt: now,
UpdatedAt: now,
}
}
// Peer functions
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
for _, n := range nodes {
if n.ID() != node.ID() {
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
return ret
}
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
nodeIsOdd := node.ID()%2 == 1
for _, n := range nodes {
if n.ID() == node.ID() {
continue
}
peerIsOdd := n.ID()%2 == 1
// Only add peer if both are odd or both are even
if nodeIsOdd == peerIsOdd {
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
return ret
}
func TestNodeStoreOperations(t *testing.T) {
tests := []struct {
name string
setupFunc func(t *testing.T) *NodeStore
steps []testStep
}{
{
name: "create empty store and add single node",
setupFunc: func(t *testing.T) *NodeStore {
return NewNodeStore(nil, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify empty store",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Empty(t, snapshot.nodesByID)
assert.Empty(t, snapshot.allNodes)
assert.Empty(t, snapshot.peersByNode)
assert.Empty(t, snapshot.nodesByUser)
},
},
{
name: "add first node",
action: func(store *NodeStore) {
node := createTestNode(1, 1, "user1", "node1")
store.PutNode(node)
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 1)
assert.Len(t, snapshot.allNodes, 1)
assert.Len(t, snapshot.peersByNode, 1)
assert.Len(t, snapshot.nodesByUser, 1)
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
assert.Len(t, snapshot.nodesByUser[1], 1)
},
},
},
},
{
name: "create store with initial node and add more",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
initialNodes := types.Nodes{&node1}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify initial state",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 1)
assert.Len(t, snapshot.allNodes, 1)
assert.Len(t, snapshot.peersByNode, 1)
assert.Len(t, snapshot.nodesByUser, 1)
assert.Empty(t, snapshot.peersByNode[1])
},
},
{
name: "add second node same user",
action: func(store *NodeStore) {
node2 := createTestNode(2, 1, "user1", "node2")
store.PutNode(node2)
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
assert.Len(t, snapshot.allNodes, 2)
assert.Len(t, snapshot.peersByNode, 2)
assert.Len(t, snapshot.nodesByUser, 1)
// Now both nodes should see each other as peers
assert.Len(t, snapshot.peersByNode[1], 1)
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
assert.Len(t, snapshot.peersByNode[2], 1)
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
assert.Len(t, snapshot.nodesByUser[1], 2)
},
},
{
name: "add third node different user",
action: func(store *NodeStore) {
node3 := createTestNode(3, 2, "user2", "node3")
store.PutNode(node3)
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
assert.Len(t, snapshot.allNodes, 3)
assert.Len(t, snapshot.peersByNode, 3)
assert.Len(t, snapshot.nodesByUser, 2)
// All nodes should see the other 2 as peers
assert.Len(t, snapshot.peersByNode[1], 2)
assert.Len(t, snapshot.peersByNode[2], 2)
assert.Len(t, snapshot.peersByNode[3], 2)
// User groupings
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
},
},
},
},
{
name: "test node deletion",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 1, "user1", "node2")
node3 := createTestNode(3, 2, "user2", "node3")
initialNodes := types.Nodes{&node1, &node2, &node3}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify initial 3 nodes",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
assert.Len(t, snapshot.allNodes, 3)
assert.Len(t, snapshot.peersByNode, 3)
assert.Len(t, snapshot.nodesByUser, 2)
},
},
{
name: "delete middle node",
action: func(store *NodeStore) {
store.DeleteNode(2)
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
assert.Len(t, snapshot.allNodes, 2)
assert.Len(t, snapshot.peersByNode, 2)
assert.Len(t, snapshot.nodesByUser, 2)
// Node 2 should be gone
assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
// Remaining nodes should see each other as peers
assert.Len(t, snapshot.peersByNode[1], 1)
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
assert.Len(t, snapshot.peersByNode[3], 1)
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
// User groupings updated
assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
},
},
{
name: "delete all remaining nodes",
action: func(store *NodeStore) {
store.DeleteNode(1)
store.DeleteNode(3)
snapshot := store.data.Load()
assert.Empty(t, snapshot.nodesByID)
assert.Empty(t, snapshot.allNodes)
assert.Empty(t, snapshot.peersByNode)
assert.Empty(t, snapshot.nodesByUser)
},
},
},
},
{
name: "test node updates",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 1, "user1", "node2")
initialNodes := types.Nodes{&node1, &node2}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify initial hostnames",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
},
},
{
name: "update node hostname",
action: func(store *NodeStore) {
store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "updated-node1"
n.GivenName = "updated-node1"
})
snapshot := store.data.Load()
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
// Peers should still work correctly
assert.Len(t, snapshot.peersByNode[1], 1)
assert.Len(t, snapshot.peersByNode[2], 1)
},
},
},
},
{
name: "test with odd-even peers filtering",
setupFunc: func(t *testing.T) *NodeStore {
return NewNodeStore(nil, oddEvenPeersFunc)
},
steps: []testStep{
{
name: "add nodes with odd-even filtering",
action: func(store *NodeStore) {
// Add nodes in sequence
store.PutNode(createTestNode(1, 1, "user1", "node1"))
store.PutNode(createTestNode(2, 2, "user2", "node2"))
store.PutNode(createTestNode(3, 3, "user3", "node3"))
store.PutNode(createTestNode(4, 4, "user4", "node4"))
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 4)
// Verify odd-even peer relationships
require.Len(t, snapshot.peersByNode[1], 1)
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
require.Len(t, snapshot.peersByNode[2], 1)
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
require.Len(t, snapshot.peersByNode[3], 1)
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
require.Len(t, snapshot.peersByNode[4], 1)
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
},
},
{
name: "delete odd node and verify even nodes unaffected",
action: func(store *NodeStore) {
store.DeleteNode(1)
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
// Node 3 (odd) should now have no peers
assert.Empty(t, snapshot.peersByNode[3])
// Even nodes should still see each other
require.Len(t, snapshot.peersByNode[2], 1)
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
require.Len(t, snapshot.peersByNode[4], 1)
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := tt.setupFunc(t)
store.Start()
defer store.Stop()
for _, step := range tt.steps {
t.Run(step.name, func(t *testing.T) {
step.action(store)
})
}
})
}
}
type testStep struct {
name string
action func(store *NodeStore)
}

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate: case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
return true return true
} }
return false return false
} }

View File

@ -13,6 +13,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx" "go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -355,6 +356,7 @@ func (node *Node) Proto() *v1.Node {
GivenName: node.GivenName, GivenName: node.GivenName,
User: node.User.Proto(), User: node.User.Proto(),
ForcedTags: node.ForcedTags, ForcedTags: node.ForcedTags,
Online: node.IsOnline != nil && *node.IsOnline,
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
// to be populated manually with PrimaryRoute, to ensure it includes the // to be populated manually with PrimaryRoute, to ensure it includes the
@ -419,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
} }
// SubnetRoutes returns the list of routes that the node announces and are approved. // SubnetRoutes returns the list of routes that the node announces and are approved.
//
// IMPORTANT: This method is used for internal data structures and should NOT be used
// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
// with PrimaryRoutes to ensure it includes only routes actively served by the node.
// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
func (node *Node) SubnetRoutes() []netip.Prefix { func (node *Node) SubnetRoutes() []netip.Prefix {
var routes []netip.Prefix var routes []netip.Prefix
@ -511,11 +518,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
} }
if node.Hostname != hostInfo.Hostname { if node.Hostname != hostInfo.Hostname {
log.Trace().
Str("node.id", node.ID.String()).
Str("old_hostname", node.Hostname).
Str("new_hostname", hostInfo.Hostname).
Str("old_given_name", node.GivenName).
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
Msg("Updating hostname from hostinfo")
if node.GivenNameHasBeenChanged() { if node.GivenNameHasBeenChanged() {
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
} }
node.Hostname = hostInfo.Hostname node.Hostname = hostInfo.Hostname
log.Trace().
Str("node.id", node.ID.String()).
Str("new_hostname", node.Hostname).
Str("new_given_name", node.GivenName).
Msg("Hostname updated")
} }
} }
@ -759,6 +780,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix {
return v.ж.ExitRoutes() return v.ж.ExitRoutes()
} }
// RequestTags returns the ACL tags that the node is requesting.
func (v NodeView) RequestTags() []string {
if !v.Valid() || !v.Hostinfo().Valid() {
return []string{}
}
return v.Hostinfo().RequestTags().AsSlice()
}
// Proto converts the NodeView to a protobuf representation.
func (v NodeView) Proto() *v1.Node {
if !v.Valid() {
return nil
}
return v.ж.Proto()
}
// HasIP reports if a node has a given IP address. // HasIP reports if a node has a given IP address.
func (v NodeView) HasIP(i netip.Addr) bool { func (v NodeView) HasIP(i netip.Addr) bool {
if !v.Valid() { if !v.Valid() {

View File

@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
} }
// Parse each hop line // Parse each hop line
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
for i := 1; i < len(lines); i++ { for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i]) matches := hopRegex.FindStringSubmatch(lines[i])

View File

@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
assert.NoError(ct, err) assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState) assert.Equal(ct, "NeedsLogin", status.BackendState)
} }
assertTailscaleNodesLogout(t, allClients)
}, shortAccessTTL+10*time.Second, 5*time.Second) }, shortAccessTTL+10*time.Second, 5*time.Second)
} }

View File

@ -547,6 +547,8 @@ func TestUpdateHostnameFromClient(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
var nodes []*v1.Node var nodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) { assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err := executeAndUnmarshal( err := executeAndUnmarshal(
@ -642,27 +644,34 @@ func TestUpdateHostnameFromClient(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
err = executeAndUnmarshal( // Wait for nodestore batch processing to complete
headscale, // NodeStore batching timeout is 500ms, so we wait up to 1 second
[]string{ assert.Eventually(t, func() bool {
"headscale", err = executeAndUnmarshal(
"node", headscale,
"list", []string{
"--output", "headscale",
"json", "node",
}, "list",
&nodes, "--output",
) "json",
},
&nodes,
)
assertNoErr(t, err) if err != nil || len(nodes) != 3 {
assert.Len(t, nodes, 3) return false
}
for _, node := range nodes { for _, node := range nodes {
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
givenName := fmt.Sprintf("%d-givenname", node.GetId()) givenName := fmt.Sprintf("%d-givenname", node.GetId())
assert.Equal(t, hostname+"NEW", node.GetName()) if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
assert.Equal(t, givenName, node.GetGivenName()) return false
} }
}
return true
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
} }
func TestExpireNode(t *testing.T) { func TestExpireNode(t *testing.T) {

View File

@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, node.GetSubnetRoutes(), 1) assert.Len(t, node.GetSubnetRoutes(), 1)
} }
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assert.NoError(c, err)
// Verify that the clients can see the new routes for _, peerKey := range status.Peers() {
for _, client := range allClients { peerStatus := status.Peer[peerKey]
status, err := client.Status()
require.NoError(t, err)
for _, peerKey := range status.Peers() { assert.NotNil(c, peerStatus.PrimaryRoutes)
peerStatus := status.Peer[peerKey] assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
assert.NotNil(t, peerStatus.PrimaryRoutes) }
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
} }
} }, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
_, err = headscale.ApproveRoutes( _, err = headscale.ApproveRoutes(
1, 1,
@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
nodes, err = headscale.ListNodes() for _, node := range nodes {
require.NoError(t, err) if node.GetId() == 1 {
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
for _, node := range nodes { assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
if node.GetId() == 1 { assert.Empty(c, node.GetSubnetRoutes())
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 } else if node.GetId() == 2 {
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
assert.Empty(t, node.GetSubnetRoutes()) assert.Empty(c, node.GetApprovedRoutes())
} else if node.GetId() == 2 { assert.Empty(c, node.GetSubnetRoutes())
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 } else {
assert.Empty(t, node.GetApprovedRoutes()) assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
assert.Empty(t, node.GetSubnetRoutes()) assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
} else { assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 }
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
} }
} }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
// Verify that the clients can see the new routes // Verify that the clients can see the new routes
for _, client := range allClients { for _, client := range allClients {
@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
time.Sleep(3 * time.Second) // Wait for route configuration changes after advertising routes
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err := headscale.ListNodes() requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved")
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that no routes has been sent to the client, // Verify that no routes has been sent to the client,
// they are not yet enabled. // they are not yet enabled.
@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(3 * time.Second) // Wait for route approval on first subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route")
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the client has routes from the primary machine and can access // Verify that the client has routes from the primary machine and can access
// the webservice. // the webservice.
@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(3 * time.Second) // Wait for route approval on second subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus() srs1 = subRouter1.MustStatus()
@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(3 * time.Second) // Wait for route approval on third subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
}, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 1, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus() srs1 = subRouter1.MustStatus()
@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, result, 13) assert.Len(t, result, 13)
tr, err = client.Traceroute(webip) // Wait for traceroute to work correctly through the expected router
require.NoError(t, err) assert.EventuallyWithT(t, func(c *assert.CollectT) {
assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) tr, err := client.Traceroute(webip)
assert.NoError(c, err)
// Get the expected router IP - use a more robust approach to handle temporary disconnections
ips, err := subRouter1.IPs()
assert.NoError(c, err)
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
var expectedIP netip.Addr
for _, ip := range ips {
if ip.Is4() {
expectedIP = ip
break
}
}
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
}, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
// Take down the current primary // Take down the current primary
t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname()) t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter1.Down() err = subRouter1.Down()
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for router status changes after r1 goes down
assert.EventuallyWithT(t, func(c *assert.CollectT) {
srs2 = subRouter2.MustStatus()
clientStatus = client.MustStatus()
srs2 = subRouter2.MustStatus() srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
clientStatus = client.MustStatus() srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] assert.True(c, srs2PeerStatus.Online, "r2 should be online")
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
require.NotNil(t, srs2PeerStatus.PrimaryRoutes) require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter2.Down() err = subRouter2.Down()
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for router status changes after r2 goes down
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// TODO(kradalby): Check client status srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
// Both are expected to be down srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
// Verify that the route is not presented from either router assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
clientStatus, err = client.Status() assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
require.NoError(t, err) assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter1.Up() err = subRouter1.Up()
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for router status changes after r1 comes back up
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// Verify that the route is announced from subnet router 1 srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
clientStatus, err = client.Status() srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
require.NoError(t, err) srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter2.Up() err = subRouter2.Up()
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for nodestore batch processing to complete and online status to be updated
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// Verify that the route is announced from subnet router 1 srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
clientStatus, err = client.Status() srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
require.NoError(t, err) srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] assert.True(c, srs1PeerStatus.Online, "r1 should be online")
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] assert.True(c, srs2PeerStatus.Online, "r2 should be online")
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname()) t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{}) _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second) // Wait for nodestore batch processing and route state changes to complete
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() // After disabling route on r3, r1 should become primary with 1 subnet route
require.NoError(t, err) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) }, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname()) t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{}) _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second) // Wait for nodestore batch processing and route state changes to complete
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() // After disabling route on r1, r2 should become primary with 1 subnet route
require.NoError(t, err) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) }, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()), util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
) )
time.Sleep(5 * time.Second) // Wait for route state changes after re-enabling r1
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes() requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
assert.Len(t, nodes, 6) requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
}, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
nodes, err = headscale.ListNodes() requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
require.NoError(t, err) requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
require.Len(t, nodes, 2) }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 0, 0, 0)
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status() srs1, _ := subRouter1.Status()
@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
requireNodeRouteCount(t, nodes[0], 2, 2, 2) requireNodeRouteCount(t, nodes[0], 2, 2, 2)
requireNodeRouteCount(t, nodes[1], 2, 2, 2) requireNodeRouteCount(t, nodes[1], 2, 2, 2)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assert.NoError(c, err)
// Verify that the clients can see the new routes for _, peerKey := range status.Peers() {
for _, client := range allClients { peerStatus := status.Peer[peerKey]
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() { assert.NotNil(c, peerStatus.AllowedIPs)
peerStatus := status.Peer[peerKey] assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
require.NotNil(t, peerStatus.AllowedIPs) assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4) }
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
} }
} }, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
} }
// TestSubnetRouterMultiNetwork is an evolution of the subnet router test. // TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to nodes and clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
nodes, err = headscale.ListNodes() // Verify that the routes have been sent to the client
require.NoError(t, err) status, err = user2c.Status()
assert.Len(t, nodes, 2) assert.NoError(c, err)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
// Verify that the routes have been sent to the client. for _, peerKey := range status.Peers() {
status, err = user2c.Status() peerStatus := status.Peer[peerKey]
require.NoError(t, err)
for _, peerKey := range status.Peers() { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
peerStatus := status.Peer[peerKey] requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
}
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref) }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
}
usernet1, err := scenario.Network("usernet1") usernet1, err := scenario.Network("usernet1")
require.NoError(t, err) require.NoError(t, err)
@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()}) _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate to nodes and clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
nodes, err = headscale.ListNodes() // Verify that the routes have been sent to the client
require.NoError(t, err) status, err = user2c.Status()
assert.Len(t, nodes, 2) assert.NoError(c, err)
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
// Verify that the routes have been sent to the client. for _, peerKey := range status.Peers() {
status, err = user2c.Status() peerStatus := status.Peer[peerKey]
require.NoError(t, err)
for _, peerKey := range status.Peers() { requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
peerStatus := status.Peer[peerKey] }
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
}
// Tell user2c to use user1c as an exit node. // Tell user2c to use user1c as an exit node.
command = []string{ command = []string{
@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to create scenario: %s", err) require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t) defer scenario.ShutdownAssertNoPanics(t)
var nodes []*v1.Node
opts := []hsic.Option{ opts := []hsic.Option{
hsic.WithTestName("autoapprovemulti"), hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(), hsic.WithEmbeddedDERPServerOnly(),
@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to advertise route: %s", err) require.NoErrorf(t, err, "failed to advertise route: %s", err)
} }
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err := headscale.ListNodes() nodes, err := headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err := client.Status() status, err := client.Status()
@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol) err = headscale.SetPolicy(tt.pol)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol) err = headscale.SetPolicy(tt.pol)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command) _, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err) require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 1) requireNodeRouteCount(t, nodes[1], 1, 1, 1)
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command) _, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err) require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route // These route should auto approve, so the node is expected to have a route
// for all counts. // for all counts.
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 0, 0, 0) requireNodeRouteCount(t, nodes[2], 0, 0, 0)
@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerExitNode.Execute(command) _, _, err = routerExitNode.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err) require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second) // Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes() nodes, err = headscale.ListNodes()
require.NoError(t, err) assert.NoError(c, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 2, 2, 2) requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client. // Verify that the routes have been sent to the client.
status, err = client.Status() status, err = client.Status()
@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
require.Equal(t, tr.Route[0].IP, ip) require.Equal(t, tr.Route[0].IP, ip)
} }
// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
assert.NotNil(c, tr)
assert.True(c, tr.Success)
assert.NoError(c, tr.Err)
assert.NotEmpty(c, tr.Route)
assert.Equal(c, tr.Route[0].IP, ip)
}
// requirePeerSubnetRoutes asserts that the peer has the expected subnet routes. // requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) { func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
t.Helper() t.Helper()
@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
} }
} }
func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
return
}
if len(expected) == 0 {
expected = []netip.Prefix{}
}
got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
if tsaddr.IsExitRoute(p) {
return true
}
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
})
if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
}
}
func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) { func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
t.Helper() t.Helper()
require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
} }
func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}
// TestSubnetRouteACLFiltering tests that a node can only access subnet routes // TestSubnetRouteACLFiltering tests that a node can only access subnet routes
// that are explicitly allowed in the ACL. // that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) { func TestSubnetRouteACLFiltering(t *testing.T) {
@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
// Give some time for the routes to propagate // Wait for route state changes to propagate
time.Sleep(5 * time.Second) assert.EventuallyWithT(t, func(c *assert.CollectT) {
// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
// List nodes and verify the router has 3 available routes // Find the router node
nodes, err = headscale.NodesByUser() routerNode = nodes[routerUser][0]
require.NoError(t, err)
require.Len(t, nodes, 2)
// Find the router node // Check that the router has 3 routes now approved and available
routerNode = nodes[routerUser][0] requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Check that the router has 3 routes now approved and available
requireNodeRouteCount(t, routerNode, 3, 3, 3)
// Now check the client node status // Now check the client node status
nodeStatus, err := nodeClient.Status() nodeStatus, err := nodeClient.Status()

View File

@ -14,7 +14,6 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
return nil, fmt.Errorf("no network named: %s", name) return nil, fmt.Errorf("no network named: %s", name)
} }
for _, ipam := range net.Network.IPAM.Config { if len(net.Network.IPAM.Config) == 0 {
pref, err := netip.ParsePrefix(ipam.Subnet) return nil, fmt.Errorf("no IPAM config found in network: %s", name)
if err != nil {
return nil, err
}
return &pref, nil
} }
return nil, fmt.Errorf("no prefix found in network: %s", name) pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
if err != nil {
return nil, err
}
return &pref, nil
} }
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
return err return err
} }
sort.Strings(s.spec.Users)
for _, user := range s.spec.Users { for _, user := range s.spec.Users {
u, err := s.CreateUser(user) u, err := s.CreateUser(user)
if err != nil { if err != nil {