mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
state/nodestore: in memory representation of nodes
Initial work on a nodestore which stores all of the nodes and their relations in memory with relationship for peers precalculated. It is a copy-on-write structure, replacing the "snapshot" when a change to the structure occurs. It is optimised for reads, and while batches are not fast, they are grouped together to do less of the expensive peer calculation if there are many changes rapidly. Writes will block until commited, while reads are never blocked. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
b155f30ef6
commit
2e20652fdf
@ -551,13 +551,12 @@ be assigned to nodes.`,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if confirm || force {
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force })
|
||||
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force})
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
@ -137,9 +137,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
|
||||
// Initialize ephemeral garbage collector
|
||||
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||
node, err := app.state.GetNodeByID(ni)
|
||||
if err != nil {
|
||||
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
|
||||
node, ok := app.state.GetNodeByID(ni)
|
||||
if !ok {
|
||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
|
||||
return
|
||||
}
|
||||
|
||||
@ -379,15 +380,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("HTTP authentication invoked")
|
||||
|
||||
authHeader := req.Header.Get("authorization")
|
||||
|
||||
if !strings.HasPrefix(authHeader, AuthPrefix) {
|
||||
log.Error().
|
||||
if err := func() error {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
@ -501,11 +495,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
var err error
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
|
||||
if profilingEnabled {
|
||||
if profilingPath != "" {
|
||||
err := os.MkdirAll(profilingPath, os.ModePerm)
|
||||
err = os.MkdirAll(profilingPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to create profiling directory")
|
||||
}
|
||||
@ -559,12 +554,9 @@ func (h *Headscale) Serve() error {
|
||||
// around between restarts, they will reconnect and the GC will
|
||||
// be cancelled.
|
||||
go h.ephemeralGC.Start()
|
||||
ephmNodes, err := h.state.ListEphemeralNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
|
||||
}
|
||||
for _, node := range ephmNodes {
|
||||
h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
|
||||
ephmNodes := h.state.ListEphemeralNodes()
|
||||
for _, node := range ephmNodes.All() {
|
||||
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
||||
}
|
||||
|
||||
if h.cfg.DNSConfig.ExtraRecordsPath != "" {
|
||||
@ -794,23 +786,14 @@ func (h *Headscale) Serve() error {
|
||||
continue
|
||||
}
|
||||
|
||||
changed, err := h.state.ReloadPolicy()
|
||||
changes, err := h.state.ReloadPolicy()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("reloading policy")
|
||||
continue
|
||||
}
|
||||
|
||||
if changed {
|
||||
log.Info().
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
h.Change(changes...)
|
||||
|
||||
err = h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to approve routes after new policy")
|
||||
}
|
||||
|
||||
h.Change(change.PolicySet)
|
||||
}
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
log.Info().
|
||||
@ -1020,6 +1003,6 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
// Change is used to send changes to nodes.
|
||||
// All change should be enqueued here and empty will be automatically
|
||||
// ignored.
|
||||
func (h *Headscale) Change(c change.ChangeSet) {
|
||||
h.mapBatcher.AddWork(c)
|
||||
func (h *Headscale) Change(cs ...change.ChangeSet) {
|
||||
h.mapBatcher.AddWork(cs...)
|
||||
}
|
||||
|
@ -12,7 +12,6 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@ -29,28 +28,10 @@ func (h *Headscale) handleRegister(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("looking up node in database: %w", err)
|
||||
}
|
||||
node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
|
||||
|
||||
if node != nil {
|
||||
// If an existing node is trying to register with an auth key,
|
||||
// we need to validate the auth key even for existing nodes
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||
if ok {
|
||||
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||
}
|
||||
@ -70,6 +51,7 @@ func (h *Headscale) handleRegister(
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("handling register with auth key: %w", err)
|
||||
}
|
||||
|
||||
@ -89,13 +71,22 @@ func (h *Headscale) handleExistingNode(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
|
||||
if node.MachineKey != machineKey {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
|
||||
}
|
||||
|
||||
expired := node.IsExpired()
|
||||
|
||||
// If the node is expired and this is not a re-authentication attempt,
|
||||
// force the client to re-authenticate
|
||||
if expired && regReq.Auth == nil {
|
||||
return &tailcfg.RegisterResponse{
|
||||
NodeKeyExpired: true,
|
||||
MachineAuthorized: false,
|
||||
AuthURL: "", // Client will need to re-authenticate
|
||||
}, nil
|
||||
}
|
||||
|
||||
if !expired && !regReq.Expiry.IsZero() {
|
||||
requestExpiry := regReq.Expiry
|
||||
|
||||
@ -107,7 +98,7 @@ func (h *Headscale) handleExistingNode(
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
c, err := h.state.DeleteNode(node)
|
||||
c, err := h.state.DeleteNode(node.View())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
@ -118,15 +109,19 @@ func (h *Headscale) handleExistingNode(
|
||||
}
|
||||
}
|
||||
|
||||
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
h.Change(c)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
// CRITICAL: Use the updated node view for the response
|
||||
// The original node object has stale expiry information
|
||||
node = updatedNode.AsStruct()
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
|
||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
@ -177,7 +172,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
)
|
||||
@ -193,8 +188,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If node is nil, it means an ephemeral node was deleted during logout
|
||||
if node == nil {
|
||||
// If node is not valid, it means an ephemeral node was deleted during logout
|
||||
if !node.Valid() {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
}
|
||||
@ -213,26 +208,30 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// TODO(kradalby): This needs to be ran as part of the batcher maybe?
|
||||
// now since we dont update the node/pol here anymore
|
||||
routeChange := h.state.AutoApproveRoutes(node)
|
||||
|
||||
if _, _, err := h.state.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
if routeChange && changed.Empty() {
|
||||
changed = change.NodeAdded(node.ID)
|
||||
changed = change.NodeAdded(node.ID())
|
||||
}
|
||||
h.Change(changed)
|
||||
|
||||
// If policy changed due to node registration, send a separate policy change
|
||||
if policyChanged {
|
||||
policyChange := change.PolicyChange()
|
||||
h.Change(policyChange)
|
||||
}
|
||||
// TODO(kradalby): I think this is covered above, but we need to validate that.
|
||||
// // If policy changed due to node registration, send a separate policy change
|
||||
// if policyChanged {
|
||||
// policyChange := change.PolicyChange()
|
||||
// h.Change(policyChange)
|
||||
// }
|
||||
|
||||
user := node.User()
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
MachineAuthorized: true,
|
||||
NodeKeyExpired: node.IsExpired(),
|
||||
User: *node.User.TailscaleUser(),
|
||||
Login: *node.User.TailscaleLogin(),
|
||||
User: *user.TailscaleUser(),
|
||||
Login: *user.TailscaleLogin(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -266,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
)
|
||||
|
||||
log.Info().Msgf("Starting node registration using key: %s", registrationId)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(registrationId),
|
||||
}, nil
|
||||
|
@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
||||
}
|
||||
|
||||
// RenameNode takes a Node struct and a new GivenName for the nodes
|
||||
// and renames it. If the name is not unique, it will return an error.
|
||||
// and renames it. Validation should be done in the state layer before calling this function.
|
||||
func RenameNode(tx *gorm.DB,
|
||||
nodeID types.NodeID, newName string,
|
||||
) error {
|
||||
err := util.CheckForFQDNRules(
|
||||
newName,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming node: %w", err)
|
||||
// Check if the new name is unique
|
||||
var count int64
|
||||
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("failed to check name uniqueness: %w", err)
|
||||
}
|
||||
|
||||
uniq, err := isUniqueName(tx, newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if name is unique: %w", err)
|
||||
}
|
||||
|
||||
if !uniq {
|
||||
return fmt.Errorf("name is not unique: %s", newName)
|
||||
if count > 0 {
|
||||
return errors.New("name is not unique")
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
@ -333,108 +327,19 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
})
|
||||
}
|
||||
|
||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
||||
// with a registrationID to register or reauthenticate a node.
|
||||
// If the node found in the registration cache is not already registered,
|
||||
// it will be registered with the user and the node will be removed from the cache.
|
||||
// If the node is already registered, the expiry will be updated.
|
||||
// The node, and a boolean indicating if it was a new node or not, will be returned.
|
||||
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationID types.RegistrationID,
|
||||
userID types.UserID,
|
||||
nodeExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
var nodeChange change.ChangeSet
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
user, err := GetUserByID(tx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
|
||||
// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
|
||||
func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
if !testing.Testing() {
|
||||
panic("RegisterNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registration_id", registrationID.String()).
|
||||
Str("username", user.Username()).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||
Msg("Registering node from API/CLI or auth callback")
|
||||
|
||||
// TODO(kradalby): This looks quite wrong? why ID 0?
|
||||
// Why not always?
|
||||
// Registration of expired node with different user
|
||||
if reg.Node.ID != 0 &&
|
||||
reg.Node.UserID != user.ID {
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
reg.Node.UserID = user.ID
|
||||
reg.Node.User = *user
|
||||
reg.Node.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
reg.Node.Expiry = nodeExpiry
|
||||
}
|
||||
|
||||
node, err := RegisterNode(
|
||||
tx,
|
||||
reg.Node,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
hsdb.regCache.Delete(registrationID)
|
||||
}
|
||||
|
||||
// Signal to waiting clients that the machine has been registered.
|
||||
select {
|
||||
case reg.Registered <- node:
|
||||
default:
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
nodeChange = change.NodeAdded(node.ID)
|
||||
|
||||
return node, err
|
||||
} else {
|
||||
// If the node is already registered, this is a refresh.
|
||||
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeChange = change.KeyExpiry(node.ID)
|
||||
|
||||
return node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, nodeChange, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
return RegisterNode(tx, node, ipv4, ipv6)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
|
||||
func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Registering node")
|
||||
Msg("Registering test node")
|
||||
|
||||
// If the a new node is registered with the same machine key, to the same user,
|
||||
// update the existing node.
|
||||
@ -445,8 +350,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
node.ID = oldNode.ID
|
||||
node.GivenName = oldNode.GivenName
|
||||
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
||||
ipv4 = oldNode.IPv4
|
||||
ipv6 = oldNode.IPv6
|
||||
// Don't overwrite the provided IPs with old ones when they exist
|
||||
if ipv4 == nil {
|
||||
ipv4 = oldNode.IPv4
|
||||
}
|
||||
if ipv6 == nil {
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
}
|
||||
|
||||
// If the node exists and it already has IP(s), we just save it
|
||||
@ -463,7 +373,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Node authorized again")
|
||||
Msg("Test node authorized again")
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
@ -472,7 +382,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
node.IPv6 = ipv6
|
||||
|
||||
if node.GivenName == "" {
|
||||
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
|
||||
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
}
|
||||
@ -487,7 +397,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Msg("Node registered with the database")
|
||||
Msg("Test node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
@ -560,7 +470,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
||||
return len(nodes) == 0, nil
|
||||
}
|
||||
|
||||
func ensureUniqueGivenName(
|
||||
// EnsureUniqueGivenName generates a unique given name for a node based on its hostname.
|
||||
func EnsureUniqueGivenName(
|
||||
tx *gorm.DB,
|
||||
name string,
|
||||
) (string, error) {
|
||||
@ -781,19 +692,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
||||
|
||||
node := hsdb.CreateNodeForTest(user, hostname...)
|
||||
|
||||
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, *node, nil, nil)
|
||||
// Allocate IPs for the test node using the database's IP allocator
|
||||
// This is a simplified allocation for testing - in production this would use State.ipAlloc
|
||||
ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
|
||||
}
|
||||
|
||||
var registeredNode *types.Node
|
||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to register test node: %v", err))
|
||||
}
|
||||
|
||||
registeredNode, err := hsdb.GetNodeByID(node.ID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to get registered test node: %v", err))
|
||||
}
|
||||
|
||||
return registeredNode
|
||||
}
|
||||
|
||||
@ -842,3 +757,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
// allocateTestIPs allocates sequential test IPs for nodes during testing.
|
||||
func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
|
||||
if !testing.Testing() {
|
||||
panic("allocateTestIPs can only be called during tests")
|
||||
}
|
||||
|
||||
// Use simple sequential allocation for tests
|
||||
// IPv4: 100.64.0.x (where x is nodeID)
|
||||
// IPv6: fd7a:115c:a1e0::x (where x is nodeID)
|
||||
|
||||
if nodeID > 254 {
|
||||
return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
|
||||
}
|
||||
|
||||
ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
|
||||
ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})
|
||||
|
||||
return &ipv4, &ipv6, nil
|
||||
}
|
||||
|
@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
|
||||
func TestAutoApproveRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
expectChange bool // whether to expect route changes
|
||||
}{
|
||||
{
|
||||
name: "no-auto-approvers-empty-policy",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admins"],
|
||||
"dst": ["group:admins:*"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
want: []netip.Prefix{}, // Should be empty - no auto-approvers
|
||||
want2: []netip.Prefix{}, // Should be empty - no auto-approvers
|
||||
expectChange: false, // No changes expected
|
||||
},
|
||||
{
|
||||
name: "no-auto-approvers-explicit-empty",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admins"],
|
||||
"dst": ["group:admins:*"]
|
||||
}
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {},
|
||||
"exitNode": []
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
|
||||
want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
|
||||
expectChange: false, // No changes expected
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-kube",
|
||||
acl: `
|
||||
@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
expectChange: true, // Routes should be approved
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-exit-tag",
|
||||
@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
expectChange: true, // Routes should be approved
|
||||
},
|
||||
}
|
||||
|
||||
@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
changed1 := policy.AutoApproveRoutes(pm, &node)
|
||||
assert.True(t, changed1)
|
||||
newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
|
||||
assert.Equal(t, tt.expectChange, changed1)
|
||||
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
if changed1 {
|
||||
err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_ = policy.AutoApproveRoutes(pm, &nodeTagged)
|
||||
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
|
||||
if changed2 {
|
||||
err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
node1ByID, err := adb.GetNodeByID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
// For empty auto-approvers tests, handle nil vs empty slice comparison
|
||||
expectedRoutes1 := tt.want
|
||||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
node2ByID, err := adb.GetNodeByID(2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
expectedRoutes2 := tt.want2
|
||||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@ -620,11 +679,11 @@ func TestRenameNode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
@ -721,11 +780,11 @@ func TestListPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node1, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
|
||||
// No parameter means no filter, should return all peers
|
||||
nodes, err = db.ListPeers(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Empty node list should return all peers
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs, but node ID is still filtered out
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
}
|
||||
|
||||
@ -806,11 +865,11 @@ func TestListNodes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node1, nil, nil)
|
||||
_, err := RegisterNodeForTest(tx, node1, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
})
|
||||
@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return all nodes
|
||||
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -47,8 +48,9 @@ func CreatePreAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove duplicates
|
||||
// Remove duplicates and sort for consistency
|
||||
aclTags = set.SetOf(aclTags).Slice()
|
||||
slices.Sort(aclTags)
|
||||
|
||||
// TODO(kradalby): factor out and create a reusable tag validation,
|
||||
// check if there is one in Tailscale's lib.
|
||||
|
@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
|
||||
}
|
||||
|
||||
// AssignNodeToUser assigns a Node to a user.
|
||||
// Note: Validation should be done in the state layer before calling this function.
|
||||
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
|
||||
node, err := GetNodeByID(tx, nodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
// Check if the user exists
|
||||
var userExists bool
|
||||
if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
|
||||
return fmt.Errorf("failed to check if user exists: %w", err)
|
||||
}
|
||||
user, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if !userExists {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
node.User = *user
|
||||
node.UserID = user.ID
|
||||
if result := tx.Save(&node); result.Error != nil {
|
||||
return result.Error
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
|
||||
return fmt.Errorf("failed to assign node to user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
}
|
||||
|
||||
sshPol := make(map[string]*tailcfg.SSHPolicy)
|
||||
for _, node := range nodes {
|
||||
pol, err := h.state.SSHPolicy(node.View())
|
||||
for _, node := range nodes.All() {
|
||||
pol, err := h.state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol
|
||||
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
|
||||
}
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPol, "", " ")
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -25,6 +24,7 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
}
|
||||
|
||||
|
||||
c := change.UserAdded(types.UserID(user.ID))
|
||||
if policyChanged {
|
||||
|
||||
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
|
||||
if !policyChanged.Empty() {
|
||||
c.Change = change.Policy
|
||||
}
|
||||
|
||||
@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||
_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
api.h.Change(c)
|
||||
|
||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||
if err != nil {
|
||||
@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode(
|
||||
ctx context.Context,
|
||||
request *v1.GetNodeRequest,
|
||||
) (*v1.GetNodeResponse, error) {
|
||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||
}
|
||||
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
||||
|
||||
return &v1.GetNodeResponse{Node: resp}, nil
|
||||
}
|
||||
|
||||
@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
Strs("tags", request.GetTags()).
|
||||
Msg("Changing tags of node")
|
||||
|
||||
@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.SetApprovedRoutesRequest,
|
||||
) (*v1.SetApprovedRoutesResponse, error) {
|
||||
var routes []netip.Prefix
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", request.GetNodeId()).
|
||||
Strs("requestedRoutes", request.GetRoutes()).
|
||||
Msg("gRPC SetApprovedRoutes called")
|
||||
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range request.GetRoutes() {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
// If the prefix is an exit route, add both. The client expect both
|
||||
// to annotate the node as an exit node.
|
||||
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
|
||||
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||
newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||
} else {
|
||||
routes = append(routes, prefix)
|
||||
newApproved = append(newApproved, prefix)
|
||||
}
|
||||
}
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
// If routes changed, propagate those changes too
|
||||
if !routeChange.Empty() {
|
||||
api.h.Change(routeChange)
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
|
||||
// Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
|
||||
// routes that are actively served from the node (per architectural requirement in types/node.go)
|
||||
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
|
||||
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
||||
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
||||
Msg("gRPC SetApprovedRoutes completed")
|
||||
|
||||
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
||||
}
|
||||
@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||
ctx context.Context,
|
||||
request *v1.DeleteNodeRequest,
|
||||
) (*v1.DeleteNodeResponse, error) {
|
||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||
}
|
||||
|
||||
nodeChange, err := api.h.state.DeleteNode(node)
|
||||
@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Time("expiry", *node.Expiry).
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
Time("expiry", *node.AsStruct().Expiry).
|
||||
Msg("node expired")
|
||||
|
||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||
@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
Str("new_name", request.GetNewName()).
|
||||
Msg("node renamed")
|
||||
|
||||
@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
// the filtering of nodes by user, vs nodes as a whole can
|
||||
// probably be done once.
|
||||
// TODO(kradalby): This should be done in one tx.
|
||||
|
||||
IsConnected := api.h.mapBatcher.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
response := nodesToProto(api.h.state, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
nodes, err := api.h.state.ListNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes := api.h.state.ListNodes()
|
||||
|
||||
sort.Slice(nodes, func(i, j int) bool {
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
response := nodesToProto(api.h.state, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node {
|
||||
response := make([]*v1.Node, nodes.Len())
|
||||
for index, node := range nodes.All() {
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
if state.NodeCanHaveTag(node.View(), tag) {
|
||||
if state.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
|
||||
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
|
||||
response[index] = resp
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
})
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
// a scenario where they might be allowed if the server has no nodes
|
||||
// yet, but it should help for the general case and for hot reloading
|
||||
// configurations.
|
||||
nodes, err := api.h.state.ListNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||
}
|
||||
changed, err := api.h.state.SetPolicy([]byte(p))
|
||||
nodes := api.h.state.ListNodes()
|
||||
|
||||
_, err := api.h.state.SetPolicy([]byte(p))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting policy: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
_, err = api.h.state.SSHPolicy(nodes[0].View())
|
||||
if nodes.Len() > 0 {
|
||||
_, err = api.h.state.SSHPolicy(nodes.At(0))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||
}
|
||||
@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only send update if the packet filter has changed.
|
||||
if changed {
|
||||
err = api.h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
|
||||
// This ensures that routes are re-evaluated for auto-approval in cases where routes
|
||||
// were manually disabled but could now be auto-approved with the current policy.
|
||||
cs, err := api.h.state.ReloadPolicy()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reloading policy: %w", err)
|
||||
}
|
||||
|
||||
api.h.Change(change.PolicyChange())
|
||||
if len(cs) > 0 {
|
||||
api.h.Change(cs...)
|
||||
} else {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Msg("No policy changes to distribute because ReloadPolicy returned empty changeset")
|
||||
}
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
|
@ -94,13 +94,19 @@ func (h *Headscale) handleVerifyRequest(
|
||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
|
||||
}
|
||||
|
||||
nodes, err := h.state.ListNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot list nodes: %w", err)
|
||||
nodes := h.state.ListNodes()
|
||||
|
||||
// Check if any node has the requested NodeKey
|
||||
var nodeKeyFound bool
|
||||
for _, node := range nodes.All() {
|
||||
if node.NodeKey() == derpAdmitClientRequest.NodePublic {
|
||||
nodeKeyFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp := &tailcfg.DERPAdmitClientResponse{
|
||||
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
|
||||
Allow: nodeKeyFound,
|
||||
}
|
||||
|
||||
return json.NewEncoder(writer).Encode(resp)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
return errors.New("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
@ -21,8 +21,7 @@ type LockFreeBatcher struct {
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
@ -32,7 +31,6 @@ type LockFreeBatcher struct {
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
@ -45,65 +43,63 @@ type LockFreeBatcher struct {
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
addNodeStart := time.Now()
|
||||
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
// Generate connection ID
|
||||
connID := generateConnectionID()
|
||||
|
||||
// Create new connection entry
|
||||
now := time.Now()
|
||||
newEntry := &connectionEntry{
|
||||
id: connID,
|
||||
c: c,
|
||||
version: version,
|
||||
created: now,
|
||||
}
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
|
||||
var conn *nodeConn
|
||||
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||
// Update existing connection
|
||||
existing.updateConnection(c, version)
|
||||
conn = existing
|
||||
} else {
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
if err := conn.send(initialMap); err != nil {
|
||||
// Clean up the connection state on send failure
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
// Use a blocking send with timeout for initial map since the channel should be ready
|
||||
// and we want to avoid the race condition where the receiver isn't ready yet
|
||||
select {
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
|
||||
Msg("Initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches the current one, closes the connection,
|
||||
// and notifies other nodes that this node has gone offline.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||
// Check if this is the current connection and mark it as closed
|
||||
if existing, ok := b.nodes.Load(id); ok {
|
||||
if !existing.matchesChannel(c) {
|
||||
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||
return // Not the current connection, not an error
|
||||
}
|
||||
// It validates the connection channel matches one of the current connections, closes that specific connection,
|
||||
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
||||
// Reports if the node still has active connections after removal.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
nodeConn, exists := b.nodes.Load(id)
|
||||
if !exists {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
|
||||
return false
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||
// Check if node has any remaining active connections
|
||||
if nodeConn.hasActiveConnections() {
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection removed but keeping online because other connections remain")
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
// Notify other nodes that this node went offline
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
// Check if this connection is still active before processing
|
||||
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply change to node - this will handle offline nodes gracefully
|
||||
// and queue work for when they reconnect
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow asynchronous work processing")
|
||||
}
|
||||
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
// For critical changes that need immediate processing, send directly
|
||||
if b.shouldProcessImmediately(c) {
|
||||
if c.SelfUpdateOnly {
|
||||
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
|
||||
return
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For non-critical changes, add to batch
|
||||
b.addToBatch(c)
|
||||
func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
|
||||
b.addToBatch(c...)
|
||||
}
|
||||
|
||||
// queueWork safely queues work
|
||||
// queueWork safely queues work.
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) {
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
// addToBatch adds a change to the pending batch.
|
||||
func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
|
||||
// Short circuit if any of the changes is a full update, which
|
||||
// means we can skip sending individual changes.
|
||||
if change.HasFull(c) {
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
|
||||
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds a change to the pending batch
|
||||
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if c.SelfUpdateOnly {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(c.NodeID, changes)
|
||||
return
|
||||
}
|
||||
|
||||
@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes
|
||||
// processBatchedChanges processes all pending batched changes.
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
|
||||
// Clear the pending changes for this node
|
||||
b.pendingChanges.Delete(nodeID)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
if val, ok := b.connected.Load(id); ok {
|
||||
// nil means connected
|
||||
return val == nil
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check disconnected timestamp with grace period
|
||||
val, ok := b.connected.Load(id)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// nil means connected
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret := xsync.NewMap[types.NodeID, bool]()
|
||||
|
||||
// First, add all nodes with active connections
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Then add all entries from the connected map
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// nil means connected
|
||||
ret.Store(id, val == nil)
|
||||
// Only add if not already added as connected above
|
||||
if _, exists := ret.Load(id); !exists {
|
||||
if val == nil {
|
||||
// nil means connected
|
||||
ret.Store(id, true)
|
||||
} else {
|
||||
// timestamp means disconnected
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
|
||||
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
|
||||
// the channel. That might mean that we are sending to a node that has gone offline, but
|
||||
// the channel is still open.
|
||||
connData.c <- data
|
||||
nc.updateCount.Add(1)
|
||||
return nil
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// Only add if not already processed above
|
||||
if _, exists := result[id]; !exists {
|
||||
// Use immediate connection status for debug (no grace period)
|
||||
connected := (val == nil) // nil means connected, timestamp means disconnected
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
|
@ -27,6 +27,60 @@ type batcherTestCase struct {
|
||||
fn batcherFunc
|
||||
}
|
||||
|
||||
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
// Mark node as online in state before AddNode to match production behavior
|
||||
// This ensures the NodeStore has correct online status for change processing
|
||||
if t.state != nil {
|
||||
// Use Connect to properly mark node online in NodeStore but don't send its changes
|
||||
_ = t.state.Connect(id)
|
||||
}
|
||||
|
||||
// First add the node to the real batcher
|
||||
err := t.Batcher.AddNode(id, c, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send the online notification that poll.go would normally send
|
||||
// This ensures other nodes get notified about this node coming online
|
||||
t.AddWork(change.NodeOnline(id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
// Mark node as offline in state BEFORE removing from batcher
|
||||
// This ensures the NodeStore has correct offline status when the change is processed
|
||||
if t.state != nil {
|
||||
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
|
||||
_, _ = t.state.Disconnect(id)
|
||||
}
|
||||
|
||||
// Send the offline notification that poll.go would normally send
|
||||
// Do this BEFORE removing from batcher so the change can be processed
|
||||
t.AddWork(change.NodeOffline(id))
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
|
||||
return &testBatcherWrapper{Batcher: b, state: state}
|
||||
}
|
||||
|
||||
// allBatcherFunctions contains all batcher implementations to test.
|
||||
var allBatcherFunctions = []batcherTestCase{
|
||||
{"LockFree", NewBatcherAndMapper},
|
||||
@ -183,8 +237,8 @@ func setupBatcherWithTestData(
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"users": ["*"],
|
||||
"ports": ["*:*"]
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
@ -194,8 +248,8 @@ func setupBatcherWithTestData(
|
||||
t.Fatalf("Failed to set allow-all policy: %v", err)
|
||||
}
|
||||
|
||||
// Create batcher with the state
|
||||
batcher := bf(cfg, state)
|
||||
// Create batcher with the state and wrap it for testing
|
||||
batcher := wrapBatcherForTest(bf(cfg, state), state)
|
||||
batcher.Start()
|
||||
|
||||
testData := &TestData{
|
||||
@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
time.Sleep(100 * time.Millisecond) // Let connection settle
|
||||
|
||||
// Generate some work
|
||||
@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullSet)
|
||||
@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Disconnect all nodes
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for final updates to process
|
||||
@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
tn2 := testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
}
|
||||
@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, true)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Disconnect the second node
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
|
||||
assert.False(t, batcher.IsConnected(tn2.n.ID))
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch)
|
||||
// Note: IsConnected may return true during grace period for DNS resolution
|
||||
|
||||
// First node should get update that second has disconnected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, false)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
// }
|
||||
|
||||
// Test RemoveNode
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch, false)
|
||||
if batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be disconnected after RemoveNode")
|
||||
}
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch)
|
||||
// Note: IsConnected may return true during grace period for DNS resolution
|
||||
// The node is actually removed from active connections but grace period allows DNS lookups
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
|
||||
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Add real work during connection chaos
|
||||
@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPSet)
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
// Rapid removal creates race between worker and removal
|
||||
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Give workers time to process and close channels
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for _, node := range stableNodes {
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
churningChannelsMutex.Lock()
|
||||
churningChannels[nodeID] = ch
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch, exists := churningChannels[nodeID]
|
||||
churningChannelsMutex.Unlock()
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch, false)
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
}(node.n.ID)
|
||||
}
|
||||
@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[node.n.ID] = true
|
||||
connectedNodesMutex.Unlock()
|
||||
@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel, false)
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[nodeID] = false
|
||||
connectedNodesMutex.Unlock()
|
||||
@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Now disconnect all nodes from batcher to stop new updates
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for enhanced tracking goroutines to process any remaining data in channels
|
||||
@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Connect nodes one at a time to avoid overwhelming the work queue
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
// Small delay between connections to allow NodeCameOnline processing
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
peers, err := testData.State.ListPeers(node.n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error listing peers for node %d: %v", i, err)
|
||||
} else {
|
||||
t.Logf("Node %d should see %d peers from state", i, len(peers))
|
||||
}
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
|
||||
// Send a full update - this should generate full peer lists
|
||||
@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
foundFullUpdate := false
|
||||
|
||||
// Read all available updates for each node
|
||||
for i := range len(allNodes) {
|
||||
for i := range allNodes {
|
||||
nodeUpdates := 0
|
||||
t.Logf("Reading updates for node %d:", i)
|
||||
|
||||
@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
|
||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||
|
||||
// Connect first node
|
||||
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d", nodes[0].n.ID)
|
||||
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||
|
||||
// Wait for initial NodeCameOnline to be processed
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
t.Errorf("ERROR: Received unknown update type!")
|
||||
}
|
||||
|
||||
// Check if there should be peers available
|
||||
peers, err := testData.State.ListPeers(nodes[0].n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error getting peers from state: %v", err)
|
||||
} else {
|
||||
t.Logf("State shows %d peers available for this node", len(peers))
|
||||
if len(peers) > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
}
|
||||
|
||||
// Connect second node for comparison
|
||||
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node2: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info, exists := debugInfo[node2.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 1 {
|
||||
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Send update and verify ALL connections receive it
|
||||
t.Logf("Phase 5: Testing update distribution to all connections...")
|
||||
|
||||
// Clear any existing updates from all channels
|
||||
clearChannel := func(ch chan *tailcfg.MapResponse) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
// drain
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clearChannel(node1.ch)
|
||||
clearChannel(secondChannel)
|
||||
clearChannel(thirdChannel)
|
||||
clearChannel(node2.ch)
|
||||
|
||||
// Send a change notification from node2 (so node1 should receive it on all connections)
|
||||
testChangeSet := change.ChangeSet{
|
||||
NodeID: node2.n.ID,
|
||||
Change: change.NodeNewOrUpdate,
|
||||
SelfUpdateOnly: false,
|
||||
}
|
||||
|
||||
batcher.AddWork(testChangeSet)
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Let updates propagate
|
||||
|
||||
// Verify all three connections for node1 receive the update
|
||||
connection1Received := false
|
||||
connection2Received := false
|
||||
connection3Received := false
|
||||
|
||||
select {
|
||||
case mapResp := <-node1.ch:
|
||||
connection1Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 1 received update: %t", connection1Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 1 did not receive update")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-secondChannel:
|
||||
connection2Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 2 received update: %t", connection2Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 2 did not receive update")
|
||||
}
|
||||
|
||||
select {
|
||||
case mapResp := <-thirdChannel:
|
||||
connection3Received = (mapResp != nil)
|
||||
t.Logf("Node1 connection 3 received update: %t", connection3Received)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Node1 connection 3 did not receive update")
|
||||
}
|
||||
|
||||
if connection1Received && connection2Received && connection3Received {
|
||||
t.Logf("SUCCESS: All three connections for node1 received the update")
|
||||
} else {
|
||||
t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t",
|
||||
connection1Received, connection2Received, connection3Received)
|
||||
}
|
||||
|
||||
// Phase 6: Test connection removal and verify remaining connections still work
|
||||
t.Logf("Phase 6: Testing connection removal...")
|
||||
|
||||
// Remove the second connection
|
||||
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
|
||||
if !removed {
|
||||
t.Errorf("Failed to remove second connection for node1")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify debug status shows 2 connections now
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 2 {
|
||||
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@ -12,7 +13,7 @@ import (
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
|
||||
type MapResponseBuilder struct {
|
||||
resp *tailcfg.MapResponse
|
||||
mapper *mapper
|
||||
@ -21,7 +22,17 @@ type MapResponseBuilder struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||
type debugType string
|
||||
|
||||
const (
|
||||
fullResponseDebug debugType = "full"
|
||||
patchResponseDebug debugType = "patch"
|
||||
removeResponseDebug debugType = "remove"
|
||||
changeResponseDebug debugType = "change"
|
||||
derpResponseDebug debugType = "derp"
|
||||
)
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
return &MapResponseBuilder{
|
||||
@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
// addError adds an error to the builder's error list.
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
// hasErrors returns true if the builder has accumulated any errors.
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the capability version for the response
|
||||
// WithCapabilityVersion sets the capability version for the response.
|
||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||
b.capVer = capVer
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response
|
||||
// WithSelfNode adds the requesting node to the response.
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
// Always use batcher's view of online status for self node
|
||||
// The batcher respects grace periods for logout scenarios
|
||||
node := nodeView.AsStruct()
|
||||
// if b.mapper.batcher != nil {
|
||||
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||
// }
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.Node = tailnode
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response
|
||||
func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder {
|
||||
if debugDumpMapResponsePath != "" {
|
||||
b.debugType = t
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response.
|
||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||
b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDomain adds the domain configuration
|
||||
// WithDomain adds the domain configuration.
|
||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||
b.resp.Domain = b.mapper.cfg.Domain()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCollectServicesDisabled sets the collect services flag to false
|
||||
// WithCollectServicesDisabled sets the collect services flag to false.
|
||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||
b.resp.CollectServices.Set(false)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
// It disables log tailing if the mapper's LogTail is not enabled.
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.SSHPolicy = sshPolicy
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDNSConfig adds DNS configuration for the requesting node
|
||||
// WithDNSConfig adds DNS configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers.
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules based on policy
|
||||
// WithPacketFilters adds packet filter rules based on policy.
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response).
|
||||
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.Peers = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
|
||||
}
|
||||
|
||||
b.resp.PeersChanged = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
}
|
||||
|
||||
filter, matchers := b.mapper.state.Filter()
|
||||
@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
|
||||
// access each-other at all and remove them from the peers.
|
||||
var changedViews views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||
changedViews = policy.ReduceNodes(node, peers, matchers)
|
||||
} else {
|
||||
changedViews = peers.ViewSlice()
|
||||
changedViews = peers
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changedViews, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
|
||||
return tailPeers, nil
|
||||
}
|
||||
|
||||
// WithPeerChangedPatch adds peer change patches
|
||||
// WithPeerChangedPatch adds peer change patches.
|
||||
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||
b.resp.PeersChangedPatch = changes
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
if debugDumpMapResponsePath != "" {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeDebugMapResponse(b.resp, node)
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
|
||||
return b.resp, nil
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -69,16 +70,18 @@ func newMapper(
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@ -95,7 +98,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@ -115,12 +118,12 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
}
|
||||
|
||||
if len(node.IPs()) > 0 {
|
||||
@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers := m.state.ListPeers(nodeID, changedNodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse(
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
node *types.Node,
|
||||
t debugType,
|
||||
nodeID types.NodeID,
|
||||
) {
|
||||
body, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
@ -236,25 +234,6 @@ func writeDebugMapResponse(
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notification.
|
||||
for _, peer := range peers {
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
|
@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1,
|
||||
nodeInShared1.View(),
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
|
@ -133,13 +133,12 @@ func tailNode(
|
||||
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
|
||||
}
|
||||
|
||||
if !node.IsOnline().Valid() || !node.IsOnline().Get() {
|
||||
// LastSeen is only set when node is
|
||||
// not connected to the control server.
|
||||
if node.LastSeen().Valid() {
|
||||
lastSeen := node.LastSeen().Get()
|
||||
tNode.LastSeen = &lastSeen
|
||||
}
|
||||
// Set LastSeen only for offline nodes to avoid confusing Tailscale clients
|
||||
// during rapid reconnection cycles. Online nodes should not have LastSeen set
|
||||
// as this can make clients interpret them as "not online" despite Online=true.
|
||||
if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() {
|
||||
lastSeen := node.LastSeen().Get()
|
||||
tNode.LastSeen = &lastSeen
|
||||
}
|
||||
|
||||
return &tNode, nil
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/control/controlhttp/controlhttpserver"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -296,16 +295,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||
// and validates that it matches the MachineKey from the Noise session.
|
||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
}
|
||||
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||
nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if !ok {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
}
|
||||
|
||||
nv := node.View()
|
||||
|
||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||
if ns.machineKey != nv.MachineKey() {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||
|
@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||
}
|
||||
|
||||
// The user claims are now updated from the the userinfo endpoint so we can verify the user a
|
||||
// The user claims are now updated from the userinfo endpoint so we can verify the user
|
||||
// against allowed emails, email domains, and groups.
|
||||
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
||||
httpError(writer, err)
|
||||
@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
user, c, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
a.h.Change(change.PolicyChange())
|
||||
}
|
||||
a.h.Change(c)
|
||||
|
||||
// TODO(kradalby): Is this comment right?
|
||||
// If the node exists, then the node should be reauthenticated,
|
||||
@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
|
||||
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
claims *types.OIDCClaims,
|
||||
) (*types.User, bool, error) {
|
||||
) (*types.User, change.ChangeSet, error) {
|
||||
var user *types.User
|
||||
var err error
|
||||
var newUser bool
|
||||
var policyChanged bool
|
||||
var c change.ChangeSet
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, false, fmt.Errorf("creating or updating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
// if the user is still not found, create a new empty user.
|
||||
@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
user.FromClaim(claims)
|
||||
|
||||
if newUser {
|
||||
user, policyChanged, err = a.h.state.CreateUser(*user)
|
||||
user, c, err = a.h.state.CreateUser(*user)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
*u = *user
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("updating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return user, policyChanged, nil
|
||||
return user, c, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) handleRegistration(
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -138,39 +139,74 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
return ret
|
||||
}
|
||||
|
||||
// AutoApproveRoutes approves any route that can be autoapproved from
|
||||
// the nodes perspective according to the given policy.
|
||||
// It reports true if any routes were approved.
|
||||
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
|
||||
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
// ApproveRoutesWithPolicy checks if the node can approve the announced routes
|
||||
// and returns the new list of approved routes.
|
||||
// The approved routes will include:
|
||||
// 1. ALL previously approved routes (regardless of whether they're still advertised)
|
||||
// 2. New routes from announcedRoutes that can be auto-approved by policy
|
||||
// This ensures that:
|
||||
// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
|
||||
// - New routes can be auto-approved according to policy
|
||||
// - Routes can only be removed by explicit admin action (not by auto-approval).
|
||||
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
|
||||
if pm == nil {
|
||||
return false
|
||||
return currentApproved, false
|
||||
}
|
||||
nodeView := node.View()
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range nodeView.AnnouncedRoutes() {
|
||||
if pm.NodeCanApproveRoute(nodeView, route) {
|
||||
|
||||
// Start with ALL currently approved routes - we never remove approved routes
|
||||
newApproved := make([]netip.Prefix, len(currentApproved))
|
||||
copy(newApproved, currentApproved)
|
||||
|
||||
// Then, check for new routes that can be auto-approved
|
||||
for _, route := range announcedRoutes {
|
||||
// Skip if already approved
|
||||
if slices.Contains(newApproved, route) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this new route can be auto-approved by policy
|
||||
canApprove := pm.NodeCanApproveRoute(nv, route)
|
||||
if canApprove {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
|
||||
// Only modify ApprovedRoutes if we have new routes to approve.
|
||||
// This prevents clearing existing approved routes when nodes
|
||||
// temporarily don't have announced routes during policy changes.
|
||||
if len(newApproved) > 0 {
|
||||
combined := append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(combined)
|
||||
combined = slices.Compact(combined)
|
||||
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
// Sort and deduplicate
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
|
||||
// Only update if the routes actually changed
|
||||
if !slices.Equal(node.ApprovedRoutes, combined) {
|
||||
node.ApprovedRoutes = combined
|
||||
return true
|
||||
// Sort the current approved for comparison
|
||||
sortedCurrent := make([]netip.Prefix, len(currentApproved))
|
||||
copy(sortedCurrent, currentApproved)
|
||||
tsaddr.SortPrefixes(sortedCurrent)
|
||||
|
||||
// Only update if the routes actually changed
|
||||
if !slices.Equal(sortedCurrent, newApproved) {
|
||||
// Log what changed
|
||||
var added, kept []netip.Prefix
|
||||
for _, route := range newApproved {
|
||||
if !slices.Contains(sortedCurrent, route) {
|
||||
added = append(added, route)
|
||||
} else {
|
||||
kept = append(kept, route)
|
||||
}
|
||||
}
|
||||
|
||||
if len(added) > 0 {
|
||||
log.Debug().
|
||||
Uint64("node.id", nv.ID().Uint64()).
|
||||
Str("node.name", nv.Hostname()).
|
||||
Strs("routes.added", util.PrefixesToString(added)).
|
||||
Strs("routes.kept", util.PrefixesToString(kept)).
|
||||
Int("routes.total", len(newApproved)).
|
||||
Msg("Routes auto-approved by policy")
|
||||
}
|
||||
|
||||
return newApproved, true
|
||||
}
|
||||
|
||||
return false
|
||||
return newApproved, false
|
||||
}
|
||||
|
339
hscontrol/policy/policy_autoapprove_test.go
Normal file
339
hscontrol/policy/policy_autoapprove_test.go
Normal file
@ -0,0 +1,339 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
user1 := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "testuser@",
|
||||
}
|
||||
user2 := types.User{
|
||||
Model: gorm.Model{ID: 2},
|
||||
Name: "otheruser@",
|
||||
}
|
||||
users := []types.User{user1, user2}
|
||||
|
||||
node1 := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test-node",
|
||||
UserID: user1.ID,
|
||||
User: user1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ForcedTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
node2 := &types.Node{
|
||||
ID: 2,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "other-node",
|
||||
UserID: user2.ID,
|
||||
User: user2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// Create a policy that auto-approves specific routes
|
||||
policyJSON := `{
|
||||
"groups": {
|
||||
"group:test": ["testuser@"]
|
||||
},
|
||||
"tagOwners": {
|
||||
"tag:test": ["testuser@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/8": ["testuser@", "tag:test"],
|
||||
"10.1.0.0/24": ["testuser@"],
|
||||
"10.2.0.0/24": ["testuser@"],
|
||||
"192.168.0.0/24": ["tag:test"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||
assert.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "previously_approved_route_no_longer_advertised_should_remain",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "Previously approved routes should never be removed even when no longer advertised",
|
||||
},
|
||||
{
|
||||
name: "add_new_auto_approved_route_keeps_old_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
|
||||
netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
|
||||
},
|
||||
wantChanged: true,
|
||||
description: "New auto-approved routes should be added while keeping old approved routes",
|
||||
},
|
||||
{
|
||||
name: "no_announced_routes_keeps_all_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "All approved routes should remain when no routes are announced",
|
||||
},
|
||||
{
|
||||
name: "no_changes_when_announced_equals_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "No changes should occur when announced routes match approved routes",
|
||||
},
|
||||
{
|
||||
name: "auto_approve_multiple_new_routes",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8)
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("172.16.0.0/24"), // Original kept
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
|
||||
},
|
||||
wantChanged: true,
|
||||
description: "Multiple new routes should be auto-approved while keeping existing approved routes",
|
||||
},
|
||||
{
|
||||
name: "node_without_permission_no_auto_approval",
|
||||
node: node2, // Different node without the tag
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "Routes should not be auto-approved for nodes without proper permissions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
||||
|
||||
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
||||
|
||||
// Verify that all previously approved routes are still present
|
||||
for _, prevRoute := range tt.currentApproved {
|
||||
assert.Contains(t, gotApproved, prevRoute,
|
||||
"previously approved route %s was removed - this should never happen", prevRoute)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
||||
// Create a basic policy for edge case testing
|
||||
aclPolicy := `
|
||||
{
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.1.0.0/24": ["test@"],
|
||||
},
|
||||
},
|
||||
}`
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
}{
|
||||
{
|
||||
name: "nil_policy_manager",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "nil_current_approved",
|
||||
currentApproved: nil,
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "nil_announced_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: nil,
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate_approved_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "empty_slices",
|
||||
currentApproved: []netip.Prefix{},
|
||||
announcedRoutes: []netip.Prefix{},
|
||||
wantApproved: []netip.Prefix{},
|
||||
wantChanged: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
// Create test node
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
pm = nil
|
||||
}
|
||||
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")
|
||||
|
||||
// Handle nil vs empty slice comparison
|
||||
if tt.wantApproved == nil {
|
||||
assert.Nil(t, gotApproved, "expected nil approved routes")
|
||||
} else {
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
361
hscontrol/policy/policy_route_approval_test.go
Normal file
361
hscontrol/policy/policy_route_approval_test.go
Normal file
@ -0,0 +1,361 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
// Test policy that allows specific routes to be auto-approved
|
||||
aclPolicy := `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"],
|
||||
},
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/24": ["test@"],
|
||||
"192.168.0.0/24": ["group:admins"],
|
||||
"172.16.0.0/16": ["tag:approved"],
|
||||
},
|
||||
},
|
||||
"tagOwners": {
|
||||
"tag:approved": ["test@"],
|
||||
},
|
||||
}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
nodeHostname string
|
||||
nodeUser string
|
||||
nodeTags []string
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
|
||||
}{
|
||||
{
|
||||
name: "previously_approved_route_no_longer_advertised_remains",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
|
||||
},
|
||||
{
|
||||
name: "add_new_auto_approved_route_keeps_existing",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Still advertised
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New route
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "no_announced_routes_keeps_all_approved",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "manually_approved_route_not_in_policy_remains",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "tagged_node_gets_tag_approved_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
|
||||
},
|
||||
nodeUser: "test",
|
||||
nodeTags: []string{"tag:approved"},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "complex_scenario_multiple_changes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag)
|
||||
netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
}
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: tt.nodeUser,
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
// Create test node
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: tt.nodeHostname,
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
ForcedTags: tt.nodeTags,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
// Test ApproveRoutesWithPolicy
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(
|
||||
pm,
|
||||
node.View(),
|
||||
tt.currentApproved,
|
||||
tt.announcedRoutes,
|
||||
)
|
||||
|
||||
// Check change flag
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")
|
||||
|
||||
// Check approved routes match expected
|
||||
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
|
||||
t.Logf("Want: %v", tt.wantApproved)
|
||||
t.Logf("Got: %v", gotApproved)
|
||||
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// Verify all previously approved routes are still present
|
||||
for _, prevRoute := range tt.currentApproved {
|
||||
assert.Contains(t, gotApproved, prevRoute,
|
||||
"previously approved route %s was removed - this should NEVER happen", prevRoute)
|
||||
}
|
||||
|
||||
// Verify no routes were incorrectly removed
|
||||
for _, removedRoute := range tt.wantRemovedRoutes {
|
||||
assert.NotContains(t, gotApproved, removedRoute,
|
||||
"route %s should have been removed but wasn't", removedRoute)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
||||
aclPolicy := `
|
||||
{
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/8": ["test@"],
|
||||
},
|
||||
},
|
||||
}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
}{
|
||||
{
|
||||
name: "nil_current_approved",
|
||||
currentApproved: nil,
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "empty_current_approved",
|
||||
currentApproved: []netip.Prefix{},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate_routes_handled",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true, // Duplicates are removed, so it's a change
|
||||
},
|
||||
}
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(
|
||||
pm,
|
||||
node.View(),
|
||||
tt.currentApproved,
|
||||
tt.announcedRoutes,
|
||||
)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged)
|
||||
|
||||
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
currentApproved := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
}
|
||||
announcedRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
}
|
||||
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: currentApproved,
|
||||
}
|
||||
|
||||
// With nil policy manager, should return current approved unchanged
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)
|
||||
|
||||
assert.False(t, gotChanged)
|
||||
assert.Equal(t, currentApproved, gotApproved)
|
||||
}
|
@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
|
||||
canApprove: false,
|
||||
},
|
||||
{
|
||||
name: "policy-without-autoApprovers-section",
|
||||
node: normalNode,
|
||||
route: p("10.33.0.0/16"),
|
||||
policy: `{
|
||||
"groups": {
|
||||
"group:admin": ["user1@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admin"],
|
||||
"dst": ["group:admin:*"]
|
||||
},
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admin"],
|
||||
"dst": ["10.33.0.0/16:*"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
canApprove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
||||
// The fast path is that a node requests to approve a prefix
|
||||
// where there is an exact entry, e.g. 10.0.0.0/8, then
|
||||
// check and return quickly
|
||||
if _, ok := pm.autoApproveMap[route]; ok {
|
||||
if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
|
||||
if approvers, ok := pm.autoApproveMap[route]; ok {
|
||||
canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
|
||||
if canApprove {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
||||
// Check if prefix is larger (so containing) and then overlaps
|
||||
// the route to see if the node can approve a subset of an autoapprover
|
||||
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
|
||||
if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
|
||||
canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
|
||||
if canApprove {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
@ -112,6 +111,15 @@ func (m *mapSession) serve() {
|
||||
// This is the mechanism where the node gives us information about its
|
||||
// current configuration.
|
||||
//
|
||||
// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
|
||||
c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
|
||||
if err != nil {
|
||||
httpError(m.w, err)
|
||||
return
|
||||
}
|
||||
|
||||
m.h.Change(c)
|
||||
|
||||
// If OmitPeers is true and Stream is false
|
||||
// then the server will let clients update their endpoints without
|
||||
// breaking existing long-polling (Stream == true) connections.
|
||||
@ -122,14 +130,6 @@ func (m *mapSession) serve() {
|
||||
// the response and just wants a 200.
|
||||
// !req.stream && req.OmitPeers
|
||||
if m.isEndpointUpdate() {
|
||||
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
|
||||
if err != nil {
|
||||
httpError(m.w, err)
|
||||
return
|
||||
}
|
||||
|
||||
m.h.Change(c)
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
}
|
||||
@ -142,6 +142,8 @@ func (m *mapSession) serve() {
|
||||
func (m *mapSession) serveLongPoll() {
|
||||
m.beforeServeLongPoll()
|
||||
|
||||
log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected")
|
||||
|
||||
// Clean up the session when the client disconnects
|
||||
defer func() {
|
||||
m.cancelChMu.Lock()
|
||||
@ -149,18 +151,38 @@ func (m *mapSession) serveLongPoll() {
|
||||
close(m.cancelCh)
|
||||
m.cancelChMu.Unlock()
|
||||
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
disconnectChange, err := m.h.state.Disconnect(m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
|
||||
_ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
|
||||
|
||||
// When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather).
|
||||
// Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects.
|
||||
// If it does reconnect, the existing mapSession will be replaced and the node remains online.
|
||||
// If it doesn't reconnect within the timeout, we mark it as offline.
|
||||
//
|
||||
// This avoids flapping nodes in the UI and unnecessary churn in the network.
|
||||
// This is not my favourite solution, but it kind of works in our eventually consistent world.
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
disconnected := true
|
||||
// Wait up to 10 seconds for the node to reconnect.
|
||||
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
|
||||
for range 10 {
|
||||
if m.h.mapBatcher.IsConnected(m.node.ID) {
|
||||
disconnected = false
|
||||
break
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
m.h.Change(disconnectChange)
|
||||
|
||||
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
|
||||
if disconnected {
|
||||
disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
|
||||
}
|
||||
|
||||
m.afterServeLongPoll()
|
||||
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
|
||||
m.h.Change(disconnectChanges...)
|
||||
m.afterServeLongPoll()
|
||||
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set up the client stream
|
||||
@ -172,25 +194,25 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
m.keepAliveTicker = time.NewTicker(m.keepAlive)
|
||||
|
||||
// Add node to batcher BEFORE sending Connect change to prevent race condition
|
||||
// where the change is sent before the node is in the batcher's node map
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
// Send empty response to client to fail fast for invalid/non-existent nodes
|
||||
select {
|
||||
case m.ch <- &tailcfg.MapResponse{}:
|
||||
default:
|
||||
// Channel might be closed
|
||||
}
|
||||
// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
|
||||
// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
|
||||
// synchronized. When nodes reconnect, they send their hostinfo with announced routes
|
||||
// in the MapRequest. We need this data in NodeStore before Connect() sets up the
|
||||
// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
|
||||
// from AvailableRoutes.
|
||||
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
|
||||
if err != nil {
|
||||
m.errf(err, "failed to update node from initial MapRequest")
|
||||
return
|
||||
}
|
||||
|
||||
// Now send the Connect change - the batcher handles NodeCameOnline internally
|
||||
// but we still need to update routes and other state-level changes
|
||||
connectChange := m.h.state.Connect(m.node)
|
||||
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
|
||||
m.h.Change(connectChange)
|
||||
}
|
||||
// Connect the node after its state has been updated.
|
||||
// We send two separate change notifications because these are distinct operations:
|
||||
// 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
|
||||
// 2. Connect: marks the node online and recalculates primary routes based on the updated state
|
||||
// While this results in two notifications, it ensures route data is synchronized before
|
||||
// primary route selection occurs, which is critical for proper HA subnet router failover.
|
||||
connectChanges := m.h.state.Connect(m.node.ID)
|
||||
|
||||
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
|
||||
|
||||
@ -235,6 +257,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
403
hscontrol/state/node_store.go
Normal file
403
hscontrol/state/node_store.go
Normal file
@ -0,0 +1,403 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
batchSize = 10
|
||||
batchTimeout = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
const (
|
||||
put = 1
|
||||
del = 2
|
||||
update = 3
|
||||
)
|
||||
|
||||
const prometheusNamespace = "headscale"
|
||||
|
||||
var (
|
||||
nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_operations_total",
|
||||
Help: "Total number of NodeStore operations",
|
||||
}, []string{"operation"})
|
||||
nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_operation_duration_seconds",
|
||||
Help: "Duration of NodeStore operations",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
}, []string{"operation"})
|
||||
nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_batch_size",
|
||||
Help: "Size of NodeStore write batches",
|
||||
Buckets: []float64{1, 2, 5, 10, 20, 50, 100},
|
||||
})
|
||||
nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_batch_duration_seconds",
|
||||
Help: "Duration of NodeStore batch processing",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
})
|
||||
nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_snapshot_build_duration_seconds",
|
||||
Help: "Duration of NodeStore snapshot building from nodes",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
})
|
||||
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_nodes_total",
|
||||
Help: "Total number of nodes in the NodeStore",
|
||||
})
|
||||
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_peers_calculation_duration_seconds",
|
||||
Help: "Duration of peers calculation in NodeStore",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
})
|
||||
nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_queue_depth",
|
||||
Help: "Current depth of NodeStore write queue",
|
||||
})
|
||||
)
|
||||
|
||||
// NodeStore is a thread-safe store for nodes.
|
||||
// It is a copy-on-write structure, replacing the "snapshot"
|
||||
// when a change to the structure occurs. It is optimised for reads,
|
||||
// and while batches are not fast, they are grouped together
|
||||
// to do less of the expensive peer calculation if there are many
|
||||
// changes rapidly.
|
||||
//
|
||||
// Writes will block until committed, while reads are never
|
||||
// blocked. This means that the caller of a write operation
|
||||
// is responsible for ensuring an update depending on a write
|
||||
// is not issued before the write is complete.
|
||||
type NodeStore struct {
|
||||
data atomic.Pointer[Snapshot]
|
||||
|
||||
peersFunc PeersFunc
|
||||
writeQueue chan work
|
||||
}
|
||||
|
||||
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
|
||||
nodes := make(map[types.NodeID]types.Node, len(allNodes))
|
||||
for _, n := range allNodes {
|
||||
nodes[n.ID] = *n
|
||||
}
|
||||
snap := snapshotFromNodes(nodes, peersFunc)
|
||||
|
||||
store := &NodeStore{
|
||||
peersFunc: peersFunc,
|
||||
}
|
||||
store.data.Store(&snap)
|
||||
|
||||
// Initialize node count gauge
|
||||
nodeStoreNodesCount.Set(float64(len(nodes)))
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Snapshot is the representation of the current state of the NodeStore.
|
||||
// It contains all nodes and their relationships.
|
||||
// It is a copy-on-write structure, meaning that when a write occurs,
|
||||
// a new Snapshot is created with the updated state,
|
||||
// and replaces the old one atomically.
|
||||
type Snapshot struct {
|
||||
// nodesByID is the main source of truth for nodes.
|
||||
nodesByID map[types.NodeID]types.Node
|
||||
|
||||
// calculated from nodesByID
|
||||
nodesByNodeKey map[key.NodePublic]types.NodeView
|
||||
peersByNode map[types.NodeID][]types.NodeView
|
||||
nodesByUser map[types.UserID][]types.NodeView
|
||||
allNodes []types.NodeView
|
||||
}
|
||||
|
||||
// PeersFunc is a function that takes a list of nodes and returns a map
|
||||
// with the relationships between nodes and their peers.
|
||||
// This will typically be used to calculate which nodes can see each other
|
||||
// based on the current policy.
|
||||
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
||||
|
||||
// work represents a single operation to be performed on the NodeStore.
|
||||
type work struct {
|
||||
op int
|
||||
nodeID types.NodeID
|
||||
node types.Node
|
||||
updateFn UpdateNodeFunc
|
||||
result chan struct{}
|
||||
}
|
||||
|
||||
// PutNode adds or updates a node in the store.
|
||||
// If the node already exists, it will be replaced.
|
||||
// If the node does not exist, it will be added.
|
||||
// This is a blocking operation that waits for the write to complete.
|
||||
func (s *NodeStore) PutNode(n types.Node) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
work := work{
|
||||
op: put,
|
||||
nodeID: n.ID,
|
||||
node: n,
|
||||
result: make(chan struct{}),
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
s.writeQueue <- work
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||
}
|
||||
|
||||
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
|
||||
type UpdateNodeFunc func(n *types.Node)
|
||||
|
||||
// UpdateNode applies a function to modify a specific node in the store.
|
||||
// This is a blocking operation that waits for the write to complete.
|
||||
// This is analogous to a database "transaction", or, the caller should
|
||||
// rather collect all data they want to change, and then call this function.
|
||||
// Fewer calls are better.
|
||||
//
|
||||
// TODO(kradalby): Technically we could have a version of this that modifies the node
|
||||
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
|
||||
// This is because the main nodesByID map contains the struct, and every other map is using a
|
||||
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
|
||||
// a lock around the nodesByID map to ensure that no other writes are happening
|
||||
// while we are modifying the node. Which mean we would need to implement read-write locks
|
||||
// on all read operations.
|
||||
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
work := work{
|
||||
op: update,
|
||||
nodeID: nodeID,
|
||||
updateFn: updateFn,
|
||||
result: make(chan struct{}),
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
s.writeQueue <- work
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||
}
|
||||
|
||||
// DeleteNode removes a node from the store by its ID.
|
||||
// This is a blocking operation that waits for the write to complete.
|
||||
func (s *NodeStore) DeleteNode(id types.NodeID) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
work := work{
|
||||
op: del,
|
||||
nodeID: id,
|
||||
result: make(chan struct{}),
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
s.writeQueue <- work
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("delete").Inc()
|
||||
}
|
||||
|
||||
// Start initializes the NodeStore and starts processing the write queue.
|
||||
func (s *NodeStore) Start() {
|
||||
s.writeQueue = make(chan work)
|
||||
go s.processWrite()
|
||||
}
|
||||
|
||||
// Stop stops the NodeStore.
|
||||
func (s *NodeStore) Stop() {
|
||||
close(s.writeQueue)
|
||||
}
|
||||
|
||||
// processWrite processes the write queue in batches.
|
||||
func (s *NodeStore) processWrite() {
|
||||
c := time.NewTicker(batchTimeout)
|
||||
defer c.Stop()
|
||||
batch := make([]work, 0, batchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case w, ok := <-s.writeQueue:
|
||||
if !ok {
|
||||
// Channel closed, apply any remaining batch and exit
|
||||
if len(batch) != 0 {
|
||||
s.applyBatch(batch)
|
||||
}
|
||||
return
|
||||
}
|
||||
batch = append(batch, w)
|
||||
if len(batch) >= batchSize {
|
||||
s.applyBatch(batch)
|
||||
batch = batch[:0]
|
||||
c.Reset(batchTimeout)
|
||||
}
|
||||
case <-c.C:
|
||||
if len(batch) != 0 {
|
||||
s.applyBatch(batch)
|
||||
batch = batch[:0]
|
||||
}
|
||||
c.Reset(batchTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyBatch applies a batch of work to the node store.
|
||||
// This means that it takes a copy of the current nodes,
|
||||
// then applies the batch of operations to that copy,
|
||||
// runs any precomputation needed (like calculating peers),
|
||||
// and finally replaces the snapshot in the store with the new one.
|
||||
// The replacement of the snapshot is atomic, ensuring that reads
|
||||
// are never blocked by writes.
|
||||
// Each write item is blocked until the batch is applied to ensure
|
||||
// the caller knows the operation is complete and do not send any
|
||||
// updates that are dependent on a read that is yet to be written.
|
||||
func (s *NodeStore) applyBatch(batch []work) {
|
||||
timer := prometheus.NewTimer(nodeStoreBatchDuration)
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreBatchSize.Observe(float64(len(batch)))
|
||||
|
||||
nodes := make(map[types.NodeID]types.Node)
|
||||
maps.Copy(nodes, s.data.Load().nodesByID)
|
||||
|
||||
for _, w := range batch {
|
||||
switch w.op {
|
||||
case put:
|
||||
nodes[w.nodeID] = w.node
|
||||
case update:
|
||||
// Update the specific node identified by nodeID
|
||||
if n, exists := nodes[w.nodeID]; exists {
|
||||
w.updateFn(&n)
|
||||
nodes[w.nodeID] = n
|
||||
}
|
||||
case del:
|
||||
delete(nodes, w.nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
newSnap := snapshotFromNodes(nodes, s.peersFunc)
|
||||
s.data.Store(&newSnap)
|
||||
|
||||
// Update node count gauge
|
||||
nodeStoreNodesCount.Set(float64(len(nodes)))
|
||||
|
||||
for _, w := range batch {
|
||||
close(w.result)
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotFromNodes creates a new Snapshot from the provided nodes.
|
||||
// It builds a lot of "indexes" to make lookups fast for datasets we
|
||||
// that is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
|
||||
// This is not a fast operation, it is the "slow" part of our copy-on-write
|
||||
// structure, but it allows us to have fast reads and efficient lookups.
|
||||
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
|
||||
timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
allNodes := make([]types.NodeView, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
allNodes = append(allNodes, n.View())
|
||||
}
|
||||
|
||||
newSnap := Snapshot{
|
||||
nodesByID: nodes,
|
||||
allNodes: allNodes,
|
||||
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
|
||||
|
||||
// peersByNode is most likely the most expensive operation,
|
||||
// it will use the list of all nodes, combined with the
|
||||
// current policy to precalculate which nodes are peers and
|
||||
// can see each other.
|
||||
peersByNode: func() map[types.NodeID][]types.NodeView {
|
||||
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
|
||||
defer peersTimer.ObserveDuration()
|
||||
return peersFunc(allNodes)
|
||||
}(),
|
||||
nodesByUser: make(map[types.UserID][]types.NodeView),
|
||||
}
|
||||
|
||||
// Build nodesByUser and nodesByNodeKey maps
|
||||
for _, n := range nodes {
|
||||
nodeView := n.View()
|
||||
newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
|
||||
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
|
||||
}
|
||||
|
||||
return newSnap
|
||||
}
|
||||
|
||||
// GetNode retrieves a node by its ID.
|
||||
// The bool indicates if the node exists or is available (like "err not found").
|
||||
// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
|
||||
// it isn't an invalid node (this is more of a node error or node is broken).
|
||||
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("get").Inc()
|
||||
|
||||
n, exists := s.data.Load().nodesByID[id]
|
||||
if !exists {
|
||||
return types.NodeView{}, false
|
||||
}
|
||||
|
||||
return n.View(), true
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey retrieves a node by its NodeKey.
|
||||
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
|
||||
return s.data.Load().nodesByNodeKey[nodeKey]
|
||||
}
|
||||
|
||||
// ListNodes returns a slice of all nodes in the store.
|
||||
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("list").Inc()
|
||||
|
||||
return views.SliceOf(s.data.Load().allNodes)
|
||||
}
|
||||
|
||||
// ListPeers returns a slice of all peers for a given node ID.
|
||||
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("list_peers").Inc()
|
||||
|
||||
return views.SliceOf(s.data.Load().peersByNode[id])
|
||||
}
|
||||
|
||||
// ListNodesByUser returns a slice of all nodes for a given user ID.
|
||||
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
nodeStoreOperations.WithLabelValues("list_by_user").Inc()
|
||||
|
||||
return views.SliceOf(s.data.Load().nodesByUser[uid])
|
||||
}
|
501
hscontrol/state/node_store_test.go
Normal file
501
hscontrol/state/node_store_test.go
Normal file
@ -0,0 +1,501 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestSnapshotFromNodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
|
||||
validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
|
||||
}{
|
||||
{
|
||||
name: "empty nodes",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := make(map[types.NodeID]types.Node)
|
||||
peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
return make(map[types.NodeID][]types.NodeView)
|
||||
}
|
||||
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single node",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
}
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||
assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
|
||||
assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple nodes same user",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 1, "user1", "node2"),
|
||||
}
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
// Each node sees the other as peer (but not itself)
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple nodes different users",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 2, "user2", "node2"),
|
||||
3: createTestNode(3, 1, "user1", "node3"),
|
||||
}
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// Each node should have 2 peers (all others, but not itself)
|
||||
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||
|
||||
// User groupings
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "odd-even peers filtering",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 2, "user2", "node2"),
|
||||
3: createTestNode(3, 3, "user3", "node3"),
|
||||
4: createTestNode(4, 4, "user4", "node4"),
|
||||
}
|
||||
peersFunc := oddEvenPeersFunc
|
||||
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
assert.Len(t, snapshot.allNodes, 4)
|
||||
assert.Len(t, snapshot.peersByNode, 4)
|
||||
assert.Len(t, snapshot.nodesByUser, 4)
|
||||
|
||||
// Odd nodes should only see other odd nodes as peers
|
||||
require.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
// Even nodes should only see other even nodes as peers
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nodes, peersFunc := tt.setupFunc()
|
||||
snapshot := snapshotFromNodes(nodes, peersFunc)
|
||||
tt.validate(t, nodes, snapshot)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
|
||||
now := time.Now()
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
ipv4 := netip.MustParseAddr("100.64.0.1")
|
||||
ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
||||
|
||||
return types.Node{
|
||||
ID: nodeID,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
DiscoKey: discoKey.Public(),
|
||||
Hostname: hostname,
|
||||
GivenName: hostname,
|
||||
UserID: userID,
|
||||
User: types.User{
|
||||
Name: username,
|
||||
DisplayName: username,
|
||||
},
|
||||
RegisterMethod: "test",
|
||||
IPv4: &ipv4,
|
||||
IPv6: &ipv6,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// Peer functions
|
||||
|
||||
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
for _, n := range nodes {
|
||||
if n.ID() != node.ID() {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
nodeIsOdd := node.ID()%2 == 1
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() == node.ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIsOdd := n.ID()%2 == 1
|
||||
|
||||
// Only add peer if both are odd or both are even
|
||||
if nodeIsOdd == peerIsOdd {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestNodeStoreOperations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(t *testing.T) *NodeStore
|
||||
steps []testStep
|
||||
}{
|
||||
{
|
||||
name: "create empty store and add single node",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
return NewNodeStore(nil, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify empty store",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add first node",
|
||||
action: func(store *NodeStore) {
|
||||
node := createTestNode(1, 1, "user1", "node1")
|
||||
store.PutNode(node)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||
assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
|
||||
assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create store with initial node and add more",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
initialNodes := types.Nodes{&node1}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial state",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
assert.Empty(t, snapshot.peersByNode[1])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add second node same user",
|
||||
action: func(store *NodeStore) {
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
store.PutNode(node2)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
// Now both nodes should see each other as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add third node different user",
|
||||
action: func(store *NodeStore) {
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
store.PutNode(node3)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// All nodes should see the other 2 as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||
|
||||
// User groupings
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test node deletion",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
initialNodes := types.Nodes{&node1, &node2, &node3}
|
||||
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial 3 nodes",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete middle node",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(2)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// Node 2 should be gone
|
||||
assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
|
||||
|
||||
// Remaining nodes should see each other as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
// User groupings updated
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete all remaining nodes",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(1)
|
||||
store.DeleteNode(3)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test node updates",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial hostnames",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update node hostname",
|
||||
action: func(store *NodeStore) {
|
||||
store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "updated-node1"
|
||||
n.GivenName = "updated-node1"
|
||||
})
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
|
||||
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
|
||||
|
||||
// Peers should still work correctly
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test with odd-even peers filtering",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
return NewNodeStore(nil, oddEvenPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "add nodes with odd-even filtering",
|
||||
action: func(store *NodeStore) {
|
||||
// Add nodes in sequence
|
||||
store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
|
||||
// Verify odd-even peer relationships
|
||||
require.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete odd node and verify even nodes unaffected",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(1)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
|
||||
// Node 3 (odd) should now have no peers
|
||||
assert.Empty(t, snapshot.peersByNode[3])
|
||||
|
||||
// Even nodes should still see each other
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := tt.setupFunc(t)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
for _, step := range tt.steps {
|
||||
t.Run(step.name, func(t *testing.T) {
|
||||
step.action(store)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testStep struct {
|
||||
name string
|
||||
action func(store *NodeStore)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
|
||||
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@ -355,6 +356,7 @@ func (node *Node) Proto() *v1.Node {
|
||||
GivenName: node.GivenName,
|
||||
User: node.User.Proto(),
|
||||
ForcedTags: node.ForcedTags,
|
||||
Online: node.IsOnline != nil && *node.IsOnline,
|
||||
|
||||
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
|
||||
// to be populated manually with PrimaryRoute, to ensure it includes the
|
||||
@ -419,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
|
||||
}
|
||||
|
||||
// SubnetRoutes returns the list of routes that the node announces and are approved.
|
||||
//
|
||||
// IMPORTANT: This method is used for internal data structures and should NOT be used
|
||||
// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
|
||||
// with PrimaryRoutes to ensure it includes only routes actively served by the node.
|
||||
// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
|
||||
func (node *Node) SubnetRoutes() []netip.Prefix {
|
||||
var routes []netip.Prefix
|
||||
|
||||
@ -511,11 +518,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
|
||||
}
|
||||
|
||||
if node.Hostname != hostInfo.Hostname {
|
||||
log.Trace().
|
||||
Str("node.id", node.ID.String()).
|
||||
Str("old_hostname", node.Hostname).
|
||||
Str("new_hostname", hostInfo.Hostname).
|
||||
Str("old_given_name", node.GivenName).
|
||||
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
|
||||
Msg("Updating hostname from hostinfo")
|
||||
|
||||
if node.GivenNameHasBeenChanged() {
|
||||
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
|
||||
}
|
||||
|
||||
node.Hostname = hostInfo.Hostname
|
||||
|
||||
log.Trace().
|
||||
Str("node.id", node.ID.String()).
|
||||
Str("new_hostname", node.Hostname).
|
||||
Str("new_given_name", node.GivenName).
|
||||
Msg("Hostname updated")
|
||||
}
|
||||
}
|
||||
|
||||
@ -759,6 +780,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix {
|
||||
return v.ж.ExitRoutes()
|
||||
}
|
||||
|
||||
// RequestTags returns the ACL tags that the node is requesting.
|
||||
func (v NodeView) RequestTags() []string {
|
||||
if !v.Valid() || !v.Hostinfo().Valid() {
|
||||
return []string{}
|
||||
}
|
||||
return v.Hostinfo().RequestTags().AsSlice()
|
||||
}
|
||||
|
||||
// Proto converts the NodeView to a protobuf representation.
|
||||
func (v NodeView) Proto() *v1.Node {
|
||||
if !v.Valid() {
|
||||
return nil
|
||||
}
|
||||
return v.ж.Proto()
|
||||
}
|
||||
|
||||
// HasIP reports if a node has a given IP address.
|
||||
func (v NodeView) HasIP(i netip.Addr) bool {
|
||||
if !v.Valid() {
|
||||
|
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
||||
}
|
||||
|
||||
// Parse each hop line
|
||||
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
|
||||
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
|
||||
|
||||
for i := 1; i < len(lines); i++ {
|
||||
matches := hopRegex.FindStringSubmatch(lines[i])
|
||||
|
@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
||||
}
|
||||
assertTailscaleNodesLogout(t, allClients)
|
||||
}, shortAccessTTL+10*time.Second, 5*time.Second)
|
||||
}
|
||||
|
||||
|
@ -547,6 +547,8 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
var nodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
err := executeAndUnmarshal(
|
||||
@ -642,27 +644,34 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"node",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodes,
|
||||
)
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
assert.Eventually(t, func() bool {
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"node",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodes,
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
if err != nil || len(nodes) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
||||
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
||||
assert.Equal(t, hostname+"NEW", node.GetName())
|
||||
assert.Equal(t, givenName, node.GetGivenName())
|
||||
}
|
||||
for _, node := range nodes {
|
||||
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
||||
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
||||
if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
|
||||
}
|
||||
|
||||
func TestExpireNode(t *testing.T) {
|
||||
|
@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
assert.Len(t, node.GetSubnetRoutes(), 1)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.NotNil(t, peerStatus.PrimaryRoutes)
|
||||
|
||||
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
|
||||
|
||||
_, err = headscale.ApproveRoutes(
|
||||
1,
|
||||
@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.GetId() == 1 {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(t, node.GetSubnetRoutes())
|
||||
} else if node.GetId() == 2 {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(t, node.GetApprovedRoutes())
|
||||
assert.Empty(t, node.GetSubnetRoutes())
|
||||
} else {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
||||
for _, node := range nodes {
|
||||
if node.GetId() == 1 {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
||||
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(c, node.GetSubnetRoutes())
|
||||
} else if node.GetId() == 2 {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(c, node.GetApprovedRoutes())
|
||||
assert.Empty(c, node.GetSubnetRoutes())
|
||||
} else {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route configuration changes after advertising routes
|
||||
var nodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved")
|
||||
|
||||
// Verify that no routes has been sent to the client,
|
||||
// they are not yet enabled.
|
||||
@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on first subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine and can access
|
||||
// the webservice.
|
||||
@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on second subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1 = subRouter1.MustStatus()
|
||||
@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on third subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1 = subRouter1.MustStatus()
|
||||
@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 13)
|
||||
|
||||
tr, err = client.Traceroute(webip)
|
||||
require.NoError(t, err)
|
||||
assertTracerouteViaIP(t, tr, subRouter1.MustIPv4())
|
||||
// Wait for traceroute to work correctly through the expected router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Get the expected router IP - use a more robust approach to handle temporary disconnections
|
||||
ips, err := subRouter1.IPs()
|
||||
assert.NoError(c, err)
|
||||
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
|
||||
|
||||
var expectedIP netip.Addr
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
expectedIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
|
||||
}, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
|
||||
|
||||
// Take down the current primary
|
||||
t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
|
||||
@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter1.Down()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r1 goes down
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
srs2 = subRouter2.MustStatus()
|
||||
clientStatus = client.MustStatus()
|
||||
|
||||
srs2 = subRouter2.MustStatus()
|
||||
clientStatus = client.MustStatus()
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
|
||||
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
|
||||
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter2.Down()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r2 goes down
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// TODO(kradalby): Check client status
|
||||
// Both are expected to be down
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
// Verify that the route is not presented from either router
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
|
||||
assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter1.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r1 comes back up
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
|
||||
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
|
||||
assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
|
||||
assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter2.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing to complete and online status to be updated
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(c, srs1PeerStatus.Online, "r1 should be online")
|
||||
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing and route state changes to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
// After disabling route on r3, r1 should become primary with 1 subnet route
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing and route state changes to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
// After disabling route on r1, r2 should become primary with 1 subnet route
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
|
||||
)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes after re-enabling r1
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 0, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1, _ := subRouter1.Status()
|
||||
@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
|
||||
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
|
||||
requireNodeRouteCount(t, nodes[1], 2, 2, 2)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
require.NotNil(t, peerStatus.AllowedIPs)
|
||||
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4)
|
||||
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
|
||||
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
|
||||
assert.NotNil(c, peerStatus.AllowedIPs)
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
|
||||
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
|
||||
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
|
||||
}
|
||||
|
||||
// TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
|
||||
@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes and clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
// Verify that the routes have been sent to the client
|
||||
status, err = user2c.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = user2c.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
|
||||
}
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
|
||||
|
||||
usernet1, err := scenario.Network("usernet1")
|
||||
require.NoError(t, err)
|
||||
@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
||||
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes and clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
|
||||
// Verify that the routes have been sent to the client
|
||||
status, err = user2c.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = user2c.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
|
||||
}
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
|
||||
|
||||
// Tell user2c to use user1c as an exit node.
|
||||
command = []string{
|
||||
@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
var nodes []*v1.Node
|
||||
opts := []hsic.Option{
|
||||
hsic.WithTestName("autoapprovemulti"),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
require.NoErrorf(t, err, "failed to advertise route: %s", err)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err := headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err := client.Status()
|
||||
@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
err = headscale.SetPolicy(tt.pol)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = client.Status()
|
||||
@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = client.Status()
|
||||
@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
err = headscale.SetPolicy(tt.pol)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = client.Status()
|
||||
@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
_, _, err = routerSubRoute.Execute(command)
|
||||
require.NoErrorf(t, err, "failed to advertise route: %s", err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
_, _, err = routerSubRoute.Execute(command)
|
||||
require.NoErrorf(t, err, "failed to advertise route: %s", err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// These route should auto approve, so the node is expected to have a route
|
||||
// for all counts.
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 0, 0, 0)
|
||||
|
||||
@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
_, _, err = routerExitNode.Execute(command)
|
||||
require.NoErrorf(t, err, "failed to advertise route: %s", err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 2, 2, 2)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = client.Status()
|
||||
@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
|
||||
require.Equal(t, tr.Route[0].IP, ip)
|
||||
}
|
||||
|
||||
// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
|
||||
func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
|
||||
assert.NotNil(c, tr)
|
||||
assert.True(c, tr.Success)
|
||||
assert.NoError(c, tr.Err)
|
||||
assert.NotEmpty(c, tr.Route)
|
||||
assert.Equal(c, tr.Route[0].IP, ip)
|
||||
}
|
||||
|
||||
// requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
|
||||
func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
|
||||
t.Helper()
|
||||
@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
|
||||
}
|
||||
}
|
||||
|
||||
func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
|
||||
if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
|
||||
assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
|
||||
return
|
||||
}
|
||||
|
||||
if len(expected) == 0 {
|
||||
expected = []netip.Prefix{}
|
||||
}
|
||||
|
||||
got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
|
||||
if tsaddr.IsExitRoute(p) {
|
||||
return true
|
||||
}
|
||||
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
|
||||
})
|
||||
|
||||
if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
|
||||
assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
|
||||
}
|
||||
}
|
||||
|
||||
func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
|
||||
t.Helper()
|
||||
require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
|
||||
@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
|
||||
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
|
||||
}
|
||||
|
||||
func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
|
||||
assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
|
||||
assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
|
||||
assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
|
||||
}
|
||||
|
||||
// TestSubnetRouteACLFiltering tests that a node can only access subnet routes
|
||||
// that are explicitly allowed in the ACL.
|
||||
func TestSubnetRouteACLFiltering(t *testing.T) {
|
||||
@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give some time for the routes to propagate
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// List nodes and verify the router has 3 available routes
|
||||
nodes, err = headscale.NodesByUser()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
||||
// List nodes and verify the router has 3 available routes
|
||||
nodes, err = headscale.NodesByUser()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
// Find the router node
|
||||
routerNode = nodes[routerUser][0]
|
||||
|
||||
// Find the router node
|
||||
routerNode = nodes[routerUser][0]
|
||||
|
||||
// Check that the router has 3 routes now approved and available
|
||||
requireNodeRouteCount(t, routerNode, 3, 3, 3)
|
||||
// Check that the router has 3 routes now approved and available
|
||||
requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
|
||||
|
||||
// Now check the client node status
|
||||
nodeStatus, err := nodeClient.Status()
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
|
||||
return nil, fmt.Errorf("no network named: %s", name)
|
||||
}
|
||||
|
||||
for _, ipam := range net.Network.IPAM.Config {
|
||||
pref, err := netip.ParsePrefix(ipam.Subnet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pref, nil
|
||||
if len(net.Network.IPAM.Config) == 0 {
|
||||
return nil, fmt.Errorf("no IPAM config found in network: %s", name)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no prefix found in network: %s", name)
|
||||
pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pref, nil
|
||||
}
|
||||
|
||||
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
|
||||
@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
|
||||
return err
|
||||
}
|
||||
|
||||
sort.Strings(s.spec.Users)
|
||||
for _, user := range s.spec.Users {
|
||||
u, err := s.CreateUser(user)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user