Mirror of https://github.com/juanfont/headscale.git (synced 2025-09-25 17:51:11 +02:00)
state/nodestore: in memory representation of nodes
Initial work on a NodeStore, which keeps all nodes and their relations in memory, with peer relationships precalculated. It is a copy-on-write structure that replaces the "snapshot" whenever the structure changes. It is optimised for reads: while individual writes are not fast, rapid changes are batched together so the expensive peer calculation runs less often. Writes block until committed; reads are never blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent b155f30ef6
commit 2e20652fdf
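For readers unfamiliar with the pattern, the copy-on-write behaviour described in the commit message can be sketched in a few lines of Go. This is an illustrative model only, not the NodeStore API introduced by this commit; the type and method names below are assumptions. Readers load the current immutable snapshot without taking a lock, while writers serialise, rebuild the snapshot (including the precalculated peer relations), and publish it atomically.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// Node is a minimal stand-in for headscale's node type (illustrative only).
type Node struct {
	ID   uint64
	Name string
}

// snapshot is an immutable view of all nodes plus precomputed peer relations.
type snapshot struct {
	nodes map[uint64]Node
	peers map[uint64][]uint64 // precalculated peer lists
}

// Store publishes immutable snapshots: readers never block, writers
// serialise behind a mutex and swap in a new snapshot atomically.
type Store struct {
	mu   sync.Mutex // serialises writers
	snap atomic.Pointer[snapshot]
}

func NewStore() *Store {
	s := &Store{}
	s.snap.Store(&snapshot{nodes: map[uint64]Node{}, peers: map[uint64][]uint64{}})
	return s
}

// GetNode reads from the current snapshot without locking.
func (s *Store) GetNode(id uint64) (Node, bool) {
	n, ok := s.snap.Load().nodes[id]
	return n, ok
}

// PutNode copies the current snapshot, applies the change, recomputes
// peer relations once, and publishes the new snapshot.
func (s *Store) PutNode(n Node) {
	s.mu.Lock()
	defer s.mu.Unlock()

	old := s.snap.Load()
	next := &snapshot{
		nodes: make(map[uint64]Node, len(old.nodes)+1),
		peers: map[uint64][]uint64{},
	}
	for id, v := range old.nodes {
		next.nodes[id] = v
	}
	next.nodes[n.ID] = n

	// Recompute peers; here every node simply peers with every other node.
	for id := range next.nodes {
		for other := range next.nodes {
			if id != other {
				next.peers[id] = append(next.peers[id], other)
			}
		}
	}

	s.snap.Store(next)
}

func main() {
	s := NewStore()
	s.PutNode(Node{ID: 1, Name: "alpha"})
	s.PutNode(Node{ID: 2, Name: "beta"})
	if n, ok := s.GetNode(2); ok {
		fmt.Println("found", n.Name)
	}
}
```

Batching rapid writes, as the commit message describes, would amount to collecting several mutations before running the single peer recalculation; readers keep seeing the previous snapshot until the new one is stored.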
@@ -551,13 +551,12 @@ be assigned to nodes.`,
 		}
 	}
 
 	if confirm || force {
 		ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
 		defer cancel()
 		defer conn.Close()
 
-		changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force })
+		changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force})
 		if err != nil {
 			ErrorOutput(
 				err,
@@ -137,9 +137,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 
 	// Initialize ephemeral garbage collector
 	ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
-		node, err := app.state.GetNodeByID(ni)
-		if err != nil {
-			log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
+		node, ok := app.state.GetNodeByID(ni)
+		if !ok {
+			log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
+			log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
 			return
 		}
 
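The hunk above switches GetNodeByID from a (node, error) return to a comma-ok (node, ok) return: with an in-memory store, the only lookup failure is "not present", so there is no error to propagate. A standalone sketch of that style against a plain map (the types here are illustrative, not headscale's actual state package):

```go
package main

import "fmt"

type NodeID uint64

type NodeView struct{ Hostname string }

// store is a stand-in for the in-memory NodeStore (illustrative).
type store struct{ nodes map[NodeID]NodeView }

// GetNodeByID returns the node and a boolean instead of an error.
func (s *store) GetNodeByID(id NodeID) (NodeView, bool) {
	n, ok := s.nodes[id]
	return n, ok
}

func main() {
	s := &store{nodes: map[NodeID]NodeView{1: {Hostname: "alpha"}}}

	if node, ok := s.GetNodeByID(2); !ok {
		fmt.Println("node not found, skipping deletion")
	} else {
		fmt.Println("deleting", node.Hostname)
	}
}
```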
@@ -379,15 +380,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
 		writer http.ResponseWriter,
 		req *http.Request,
 	) {
-		log.Trace().
-			Caller().
-			Str("client_address", req.RemoteAddr).
-			Msg("HTTP authentication invoked")
-
-		authHeader := req.Header.Get("authorization")
-
-		if !strings.HasPrefix(authHeader, AuthPrefix) {
-			log.Error().
+		if err := func() error {
+			log.Trace().
 				Caller().
 				Str("client_address", req.RemoteAddr).
 				Msg(`missing "Bearer " prefix in "Authorization" header`)
@@ -501,11 +495,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 
 // Serve launches the HTTP and gRPC server service Headscale and the API.
 func (h *Headscale) Serve() error {
+	var err error
 	capver.CanOldCodeBeCleanedUp()
 
 	if profilingEnabled {
 		if profilingPath != "" {
-			err := os.MkdirAll(profilingPath, os.ModePerm)
+			err = os.MkdirAll(profilingPath, os.ModePerm)
 			if err != nil {
 				log.Fatal().Err(err).Msg("failed to create profiling directory")
 			}
@@ -559,12 +554,9 @@ func (h *Headscale) Serve() error {
 	// around between restarts, they will reconnect and the GC will
 	// be cancelled.
 	go h.ephemeralGC.Start()
-	ephmNodes, err := h.state.ListEphemeralNodes()
-	if err != nil {
-		return fmt.Errorf("failed to list ephemeral nodes: %w", err)
-	}
-	for _, node := range ephmNodes {
-		h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
+	ephmNodes := h.state.ListEphemeralNodes()
+	for _, node := range ephmNodes.All() {
+		h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
 	}
 
 	if h.cfg.DNSConfig.ExtraRecordsPath != "" {
@@ -794,23 +786,14 @@ func (h *Headscale) Serve() error {
 				continue
 			}
 
-			changed, err := h.state.ReloadPolicy()
+			changes, err := h.state.ReloadPolicy()
 			if err != nil {
 				log.Error().Err(err).Msgf("reloading policy")
 				continue
 			}
 
-			if changed {
-				log.Info().
-					Msg("ACL policy successfully reloaded, notifying nodes of change")
-
-				err = h.state.AutoApproveNodes()
-				if err != nil {
-					log.Error().Err(err).Msg("failed to approve routes after new policy")
-				}
-
-				h.Change(change.PolicySet)
-			}
+			h.Change(changes...)
 		default:
 			info := func(msg string) { log.Info().Msg(msg) }
 			log.Info().
@@ -1020,6 +1003,6 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 // Change is used to send changes to nodes.
 // All change should be enqueued here and empty will be automatically
 // ignored.
-func (h *Headscale) Change(c change.ChangeSet) {
-	h.mapBatcher.AddWork(c)
+func (h *Headscale) Change(cs ...change.ChangeSet) {
+	h.mapBatcher.AddWork(cs...)
 }
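Making Change variadic, as in the last hunk above, lets callers forward any number of changesets in one call — for example the slice returned by ReloadPolicy in the policy-reload hunk. A minimal standalone sketch of that calling pattern, with hypothetical types standing in for change.ChangeSet and the batcher:

```go
package main

import "fmt"

// ChangeSet is a stand-in for headscale's change.ChangeSet (illustrative).
type ChangeSet struct{ Reason string }

type batcher struct{}

// AddWork accepts a batch of changesets in one call.
func (b *batcher) AddWork(cs ...ChangeSet) {
	for _, c := range cs {
		fmt.Println("queueing work:", c.Reason)
	}
}

type Headscale struct{ mapBatcher *batcher }

// Change forwards zero or more changesets to the batcher; an empty call is a no-op.
func (h *Headscale) Change(cs ...ChangeSet) {
	h.mapBatcher.AddWork(cs...)
}

func main() {
	h := &Headscale{mapBatcher: &batcher{}}

	// Single changeset, as before.
	h.Change(ChangeSet{Reason: "node added"})

	// A slice (e.g. from a policy reload) can be expanded in place.
	changes := []ChangeSet{{Reason: "policy"}, {Reason: "routes"}}
	h.Change(changes...)
}
```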
@@ -12,7 +12,6 @@ import (
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/rs/zerolog/log"
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
@@ -29,28 +28,10 @@ func (h *Headscale) handleRegister(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-	node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
-	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
-		return nil, fmt.Errorf("looking up node in database: %w", err)
-	}
+	node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
 
-	if node != nil {
-		// If an existing node is trying to register with an auth key,
-		// we need to validate the auth key even for existing nodes
-		if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
-			resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
-			if err != nil {
-				// Preserve HTTPError types so they can be handled properly by the HTTP layer
-				var httpErr HTTPError
-				if errors.As(err, &httpErr) {
-					return nil, httpErr
-				}
-				return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
-			}
-			return resp, nil
-		}
-
-		resp, err := h.handleExistingNode(node, regReq, machineKey)
+	if ok {
+		resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
 		if err != nil {
 			return nil, fmt.Errorf("handling existing node: %w", err)
 		}
@@ -70,6 +51,7 @@ func (h *Headscale) handleRegister(
 		if errors.As(err, &httpErr) {
 			return nil, httpErr
 		}
+
 		return nil, fmt.Errorf("handling register with auth key: %w", err)
 	}
 
@@ -89,13 +71,22 @@ func (h *Headscale) handleExistingNode(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
 
 	if node.MachineKey != machineKey {
 		return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
 	}
 
 	expired := node.IsExpired()
 
+	// If the node is expired and this is not a re-authentication attempt,
+	// force the client to re-authenticate
+	if expired && regReq.Auth == nil {
+		return &tailcfg.RegisterResponse{
+			NodeKeyExpired:    true,
+			MachineAuthorized: false,
+			AuthURL:           "", // Client will need to re-authenticate
+		}, nil
+	}
+
 	if !expired && !regReq.Expiry.IsZero() {
 		requestExpiry := regReq.Expiry
 
@@ -107,7 +98,7 @@ func (h *Headscale) handleExistingNode(
 		// If the request expiry is in the past, we consider it a logout.
 		if requestExpiry.Before(time.Now()) {
 			if node.IsEphemeral() {
-				c, err := h.state.DeleteNode(node)
+				c, err := h.state.DeleteNode(node.View())
 				if err != nil {
 					return nil, fmt.Errorf("deleting ephemeral node: %w", err)
 				}
@@ -118,15 +109,19 @@ func (h *Headscale) handleExistingNode(
 			}
 		}
 
-		_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
+		updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
 		if err != nil {
 			return nil, fmt.Errorf("setting node expiry: %w", err)
 		}
 
 		h.Change(c)
-	}
 
-	return nodeToRegisterResponse(node), nil
+		// CRITICAL: Use the updated node view for the response
+		// The original node object has stale expiry information
+		node = updatedNode.AsStruct()
+	}
+
+	return nodeToRegisterResponse(node), nil
 }
 
 func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@@ -177,7 +172,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-	node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
+	node, changed, err := h.state.HandleNodeFromPreAuthKey(
 		regReq,
 		machineKey,
 	)
@@ -193,8 +188,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
 		return nil, err
 	}
 
-	// If node is nil, it means an ephemeral node was deleted during logout
-	if node == nil {
+	// If node is not valid, it means an ephemeral node was deleted during logout
+	if !node.Valid() {
 		h.Change(changed)
 		return nil, nil
 	}
@@ -213,26 +208,30 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	// TODO(kradalby): This needs to be ran as part of the batcher maybe?
 	// now since we dont update the node/pol here anymore
 	routeChange := h.state.AutoApproveRoutes(node)
 
 	if _, _, err := h.state.SaveNode(node); err != nil {
 		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}
 
 	if routeChange && changed.Empty() {
-		changed = change.NodeAdded(node.ID)
+		changed = change.NodeAdded(node.ID())
 	}
 	h.Change(changed)
 
-	// If policy changed due to node registration, send a separate policy change
-	if policyChanged {
-		policyChange := change.PolicyChange()
-		h.Change(policyChange)
-	}
+	// TODO(kradalby): I think this is covered above, but we need to validate that.
+	// // If policy changed due to node registration, send a separate policy change
+	// if policyChanged {
+	// 	policyChange := change.PolicyChange()
+	// 	h.Change(policyChange)
+	// }
+
+	user := node.User()
 
 	return &tailcfg.RegisterResponse{
 		MachineAuthorized: true,
 		NodeKeyExpired:    node.IsExpired(),
-		User:              *node.User.TailscaleUser(),
-		Login:             *node.User.TailscaleLogin(),
+		User:              *user.TailscaleUser(),
+		Login:             *user.TailscaleLogin(),
 	}, nil
 }
 
@@ -266,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive(
 	)
 
 	log.Info().Msgf("Starting node registration using key: %s", registrationId)
 
 	return &tailcfg.RegisterResponse{
 		AuthURL: h.authProvider.AuthURL(registrationId),
 	}, nil
@@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
 }
 
 // RenameNode takes a Node struct and a new GivenName for the nodes
-// and renames it. If the name is not unique, it will return an error.
+// and renames it. Validation should be done in the state layer before calling this function.
 func RenameNode(tx *gorm.DB,
 	nodeID types.NodeID, newName string,
 ) error {
-	err := util.CheckForFQDNRules(
-		newName,
-	)
-	if err != nil {
-		return fmt.Errorf("renaming node: %w", err)
+	// Check if the new name is unique
+	var count int64
+	if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
+		return fmt.Errorf("failed to check name uniqueness: %w", err)
 	}
 
-	uniq, err := isUniqueName(tx, newName)
-	if err != nil {
-		return fmt.Errorf("checking if name is unique: %w", err)
-	}
-
-	if !uniq {
-		return fmt.Errorf("name is not unique: %s", newName)
+	if count > 0 {
+		return errors.New("name is not unique")
 	}
 
 	if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@@ -333,108 +327,19 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
 	})
 }
 
-// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
-// with a registrationID to register or reauthenticate a node.
-// If the node found in the registration cache is not already registered,
-// it will be registered with the user and the node will be removed from the cache.
-// If the node is already registered, the expiry will be updated.
-// The node, and a boolean indicating if it was a new node or not, will be returned.
-func (hsdb *HSDatabase) HandleNodeFromAuthPath(
-	registrationID types.RegistrationID,
-	userID types.UserID,
-	nodeExpiry *time.Time,
-	registrationMethod string,
-	ipv4 *netip.Addr,
-	ipv6 *netip.Addr,
-) (*types.Node, change.ChangeSet, error) {
-	var nodeChange change.ChangeSet
-	node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
-		if reg, ok := hsdb.regCache.Get(registrationID); ok {
-			if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
-				user, err := GetUserByID(tx, userID)
-				if err != nil {
-					return nil, fmt.Errorf(
-						"failed to find user in register node from auth callback, %w",
-						err,
-					)
-				}
-
-				log.Debug().
-					Str("registration_id", registrationID.String()).
-					Str("username", user.Username()).
-					Str("registrationMethod", registrationMethod).
-					Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
-					Msg("Registering node from API/CLI or auth callback")
-
-				// TODO(kradalby): This looks quite wrong? why ID 0?
-				// Why not always?
-				// Registration of expired node with different user
-				if reg.Node.ID != 0 &&
-					reg.Node.UserID != user.ID {
-					return nil, ErrDifferentRegisteredUser
-				}
-
-				reg.Node.UserID = user.ID
-				reg.Node.User = *user
-				reg.Node.RegisterMethod = registrationMethod
-
-				if nodeExpiry != nil {
-					reg.Node.Expiry = nodeExpiry
-				}
-
-				node, err := RegisterNode(
-					tx,
-					reg.Node,
-					ipv4, ipv6,
-				)
-
-				if err == nil {
-					hsdb.regCache.Delete(registrationID)
-				}
-
-				// Signal to waiting clients that the machine has been registered.
-				select {
-				case reg.Registered <- node:
-				default:
-				}
-				close(reg.Registered)
-
-				nodeChange = change.NodeAdded(node.ID)
-
-				return node, err
-			} else {
-				// If the node is already registered, this is a refresh.
-				err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
-				if err != nil {
-					return nil, err
-				}
-
-				nodeChange = change.KeyExpiry(node.ID)
-
-				return node, nil
-			}
-		}
-
-		return nil, ErrNodeNotFoundRegistrationCache
-	})
-
-	return node, nodeChange, err
-}
-
-func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
-	return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
-		return RegisterNode(tx, node, ipv4, ipv6)
-	})
-}
-
-// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
-func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
+// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
+// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
+func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
+	if !testing.Testing() {
+		panic("RegisterNodeForTest can only be called during tests")
+	}
+
 	log.Debug().
 		Str("node", node.Hostname).
 		Str("machine_key", node.MachineKey.ShortString()).
 		Str("node_key", node.NodeKey.ShortString()).
 		Str("user", node.User.Username()).
-		Msg("Registering node")
+		Msg("Registering test node")
 
 	// If the a new node is registered with the same machine key, to the same user,
 	// update the existing node.
|
|||||||
node.ID = oldNode.ID
|
node.ID = oldNode.ID
|
||||||
node.GivenName = oldNode.GivenName
|
node.GivenName = oldNode.GivenName
|
||||||
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
||||||
ipv4 = oldNode.IPv4
|
// Don't overwrite the provided IPs with old ones when they exist
|
||||||
ipv6 = oldNode.IPv6
|
if ipv4 == nil {
|
||||||
|
ipv4 = oldNode.IPv4
|
||||||
|
}
|
||||||
|
if ipv6 == nil {
|
||||||
|
ipv6 = oldNode.IPv6
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the node exists and it already has IP(s), we just save it
|
// If the node exists and it already has IP(s), we just save it
|
||||||
@ -463,7 +373,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str("machine_key", node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str("node_key", node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Username()).
|
Str("user", node.User.Username()).
|
||||||
Msg("Node authorized again")
|
Msg("Test node authorized again")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
}
|
}
|
||||||
@ -472,7 +382,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
node.IPv6 = ipv6
|
node.IPv6 = ipv6
|
||||||
|
|
||||||
if node.GivenName == "" {
|
if node.GivenName == "" {
|
||||||
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
|
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||||
}
|
}
|
||||||
@ -487,7 +397,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("Node registered with the database")
|
Msg("Test node registered with the database")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
}
|
}
|
||||||
@ -560,7 +470,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
|||||||
return len(nodes) == 0, nil
|
return len(nodes) == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureUniqueGivenName(
|
// EnsureUniqueGivenName generates a unique given name for a node based on its hostname.
|
||||||
|
func EnsureUniqueGivenName(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
name string,
|
name string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
@ -781,19 +692,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
|||||||
|
|
||||||
node := hsdb.CreateNodeForTest(user, hostname...)
|
node := hsdb.CreateNodeForTest(user, hostname...)
|
||||||
|
|
||||||
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
// Allocate IPs for the test node using the database's IP allocator
|
||||||
_, err := RegisterNode(tx, *node, nil, nil)
|
// This is a simplified allocation for testing - in production this would use State.ipAlloc
|
||||||
|
ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
var registeredNode *types.Node
|
||||||
|
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
var err error
|
||||||
|
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("failed to register test node: %v", err))
|
panic(fmt.Sprintf("failed to register test node: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
registeredNode, err := hsdb.GetNodeByID(node.ID)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to get registered test node: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return registeredNode
|
return registeredNode
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -842,3 +757,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
|||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allocateTestIPs allocates sequential test IPs for nodes during testing.
|
||||||
|
func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
|
||||||
|
if !testing.Testing() {
|
||||||
|
panic("allocateTestIPs can only be called during tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use simple sequential allocation for tests
|
||||||
|
// IPv4: 100.64.0.x (where x is nodeID)
|
||||||
|
// IPv6: fd7a:115c:a1e0::x (where x is nodeID)
|
||||||
|
|
||||||
|
if nodeID > 254 {
|
||||||
|
return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
|
||||||
|
ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})
|
||||||
|
|
||||||
|
return &ipv4, &ipv6, nil
|
||||||
|
}
|
||||||
|
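As a quick sanity check of the address construction used in allocateTestIPs above, the two netip calls produce 100.64.0.&lt;id&gt; and fd7a:115c:a1e0::&lt;id&gt; for a given node ID. This small program just prints the result for an example ID of 7 (the value is illustrative):

```go
package main

import (
	"fmt"
	"net/netip"
)

func main() {
	nodeID := uint64(7)

	// Mirrors the construction in allocateTestIPs above.
	ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
	ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})

	fmt.Println(ipv4) // 100.64.0.7
	fmt.Println(ipv6) // fd7a:115c:a1e0::7
}
```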
@@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {
 
 func TestAutoApproveRoutes(t *testing.T) {
 	tests := []struct {
 		name   string
 		acl    string
 		routes []netip.Prefix
 		want   []netip.Prefix
 		want2  []netip.Prefix
+		expectChange bool // whether to expect route changes
 	}{
+		{
+			name: "no-auto-approvers-empty-policy",
+			acl: `
+{
+	"groups": {
+		"group:admins": ["test@"]
+	},
+	"acls": [
+		{
+			"action": "accept",
+			"src": ["group:admins"],
+			"dst": ["group:admins:*"]
+		}
+	]
+}`,
+			routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
+			want:   []netip.Prefix{}, // Should be empty - no auto-approvers
+			want2:  []netip.Prefix{}, // Should be empty - no auto-approvers
+			expectChange: false, // No changes expected
+		},
+		{
+			name: "no-auto-approvers-explicit-empty",
+			acl: `
+{
+	"groups": {
+		"group:admins": ["test@"]
+	},
+	"acls": [
+		{
+			"action": "accept",
+			"src": ["group:admins"],
+			"dst": ["group:admins:*"]
+		}
+	],
+	"autoApprovers": {
+		"routes": {},
+		"exitNode": []
+	}
+}`,
+			routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
+			want:   []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
+			want2:  []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
+			expectChange: false, // No changes expected
+		},
 		{
 			name: "2068-approve-issue-sub-kube",
 			acl: `
@@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
 		}
 	}
 }`,
 			routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
 			want:   []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
+			expectChange: true, // Routes should be approved
 		},
 		{
 			name: "2068-approve-issue-sub-exit-tag",
@@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 				tsaddr.AllIPv4(),
 				tsaddr.AllIPv6(),
 			},
+			expectChange: true, // Routes should be approved
 		},
 	}
 
@@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
 		require.NoError(t, err)
 		require.NotNil(t, pm)
 
-		changed1 := policy.AutoApproveRoutes(pm, &node)
-		assert.True(t, changed1)
+		newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
+		assert.Equal(t, tt.expectChange, changed1)
 
-		err = adb.DB.Save(&node).Error
-		require.NoError(t, err)
+		if changed1 {
+			err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
+			require.NoError(t, err)
+		}
 
-		_ = policy.AutoApproveRoutes(pm, &nodeTagged)
-		err = adb.DB.Save(&nodeTagged).Error
-		require.NoError(t, err)
+		newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
+		if changed2 {
+			err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
+			require.NoError(t, err)
+		}
 
 		node1ByID, err := adb.GetNodeByID(1)
 		require.NoError(t, err)
 
-		if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
+		// For empty auto-approvers tests, handle nil vs empty slice comparison
+		expectedRoutes1 := tt.want
+		if len(expectedRoutes1) == 0 {
+			expectedRoutes1 = nil
+		}
+		if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
 			t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
 		}
 
 		node2ByID, err := adb.GetNodeByID(2)
 		require.NoError(t, err)
 
-		if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
+		expectedRoutes2 := tt.want2
+		if len(expectedRoutes2) == 0 {
+			expectedRoutes2 = nil
+		}
+		if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
 			t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
 		}
 	})
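The nil-vs-empty normalisation added to the test above exists because cmp.Diff distinguishes a nil slice from an empty slice by default, so an expectation of []netip.Prefix{} would never match a nil result. A minimal reproduction of that behaviour (using the same go-cmp module the test already imports; the int slices are just placeholders):

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

func main() {
	var got []int      // nil slice, as returned when nothing was approved
	want := []int{}    // empty, non-nil expectation

	// Default cmp behaviour: nil and empty are not equal.
	fmt.Println(cmp.Diff(want, got) == "") // false

	// Normalising an empty expectation to nil, as the test does, makes them match.
	var normalised []int
	if len(want) == 0 {
		normalised = nil
	}
	fmt.Println(cmp.Diff(normalised, got) == "") // true
}
```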
@@ -620,11 +679,11 @@ func TestRenameNode(t *testing.T) {
 	require.NoError(t, err)
 
 	err = db.DB.Transaction(func(tx *gorm.DB) error {
-		_, err := RegisterNode(tx, node, nil, nil)
+		_, err := RegisterNodeForTest(tx, node, nil, nil)
 		if err != nil {
 			return err
 		}
-		_, err = RegisterNode(tx, node2, nil, nil)
+		_, err = RegisterNodeForTest(tx, node2, nil, nil)
 
 		return err
 	})
@@ -721,11 +780,11 @@ func TestListPeers(t *testing.T) {
 	require.NoError(t, err)
 
 	err = db.DB.Transaction(func(tx *gorm.DB) error {
-		_, err := RegisterNode(tx, node1, nil, nil)
+		_, err := RegisterNodeForTest(tx, node1, nil, nil)
 		if err != nil {
 			return err
 		}
-		_, err = RegisterNode(tx, node2, nil, nil)
+		_, err = RegisterNodeForTest(tx, node2, nil, nil)
 
 		return err
 	})
@@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
 	// No parameter means no filter, should return all peers
 	nodes, err = db.ListPeers(1)
 	require.NoError(t, err)
-	assert.Equal(t, 1, len(nodes))
+	assert.Len(t, nodes, 1)
 	assert.Equal(t, "test2", nodes[0].Hostname)
 
 	// Empty node list should return all peers
 	nodes, err = db.ListPeers(1, types.NodeIDs{}...)
 	require.NoError(t, err)
-	assert.Equal(t, 1, len(nodes))
+	assert.Len(t, nodes, 1)
 	assert.Equal(t, "test2", nodes[0].Hostname)
 
 	// No match in IDs should return empty list and no error
@@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
 	// Partial match in IDs
 	nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
 	require.NoError(t, err)
-	assert.Equal(t, 1, len(nodes))
+	assert.Len(t, nodes, 1)
 	assert.Equal(t, "test2", nodes[0].Hostname)
 
 	// Several matched IDs, but node ID is still filtered out
 	nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
 	require.NoError(t, err)
-	assert.Equal(t, 1, len(nodes))
+	assert.Len(t, nodes, 1)
 	assert.Equal(t, "test2", nodes[0].Hostname)
 }
 
@@ -806,11 +865,11 @@ func TestListNodes(t *testing.T) {
 	require.NoError(t, err)
 
 	err = db.DB.Transaction(func(tx *gorm.DB) error {
-		_, err := RegisterNode(tx, node1, nil, nil)
+		_, err := RegisterNodeForTest(tx, node1, nil, nil)
 		if err != nil {
 			return err
 		}
-		_, err = RegisterNode(tx, node2, nil, nil)
+		_, err = RegisterNodeForTest(tx, node2, nil, nil)
 
 		return err
 	})
@@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
 	// No parameter means no filter, should return all nodes
 	nodes, err = db.ListNodes()
 	require.NoError(t, err)
-	assert.Equal(t, 2, len(nodes))
+	assert.Len(t, nodes, 2)
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)
 
 	// Empty node list should return all nodes
 	nodes, err = db.ListNodes(types.NodeIDs{}...)
 	require.NoError(t, err)
-	assert.Equal(t, 2, len(nodes))
+	assert.Len(t, nodes, 2)
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)
 
@@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
 	// Partial match in IDs
 	nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
 	require.NoError(t, err)
-	assert.Equal(t, 1, len(nodes))
+	assert.Len(t, nodes, 1)
 	assert.Equal(t, "test2", nodes[0].Hostname)
 
 	// Several matched IDs
 	nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
 	require.NoError(t, err)
-	assert.Equal(t, 2, len(nodes))
+	assert.Len(t, nodes, 2)
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)
 }
@@ -5,6 +5,7 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"slices"
 	"strings"
 	"time"
 
@@ -47,8 +48,9 @@ func CreatePreAuthKey(
 		return nil, err
 	}
 
-	// Remove duplicates
+	// Remove duplicates and sort for consistency
 	aclTags = set.SetOf(aclTags).Slice()
+	slices.Sort(aclTags)
 
 	// TODO(kradalby): factor out and create a reusable tag validation,
 	// check if there is one in Tailscale's lib.
@@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
 }
 
 // AssignNodeToUser assigns a Node to a user.
+// Note: Validation should be done in the state layer before calling this function.
 func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
-	node, err := GetNodeByID(tx, nodeID)
-	if err != nil {
-		return err
+	// Check if the user exists
+	var userExists bool
+	if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
+		return fmt.Errorf("failed to check if user exists: %w", err)
 	}
-	user, err := GetUserByID(tx, uid)
-	if err != nil {
-		return err
+
+	if !userExists {
+		return ErrUserNotFound
 	}
-	node.User = *user
-	node.UserID = user.ID
-	if result := tx.Save(&node); result.Error != nil {
-		return result.Error
+
+	if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
+		return fmt.Errorf("failed to assign node to user: %w", err)
 	}
 
 	return nil
@@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 		}
 
 		sshPol := make(map[string]*tailcfg.SSHPolicy)
-		for _, node := range nodes {
-			pol, err := h.state.SSHPolicy(node.View())
+		for _, node := range nodes.All() {
+			pol, err := h.state.SSHPolicy(node)
 			if err != nil {
 				httpError(w, err)
 				return
 			}
 
-			sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol
+			sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
 		}
 
 		sshJSON, err := json.MarshalIndent(sshPol, "", "  ")
@ -15,7 +15,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -25,6 +24,7 @@ import (
|
|||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/views"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/state"
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser(
|
|||||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
c := change.UserAdded(types.UserID(user.ID))
|
c := change.UserAdded(types.UserID(user.ID))
|
||||||
if policyChanged {
|
|
||||||
|
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
|
||||||
|
if !policyChanged.Empty() {
|
||||||
c.Change = change.Policy
|
c.Change = change.Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
// Send policy update notifications if needed
|
||||||
if policyChanged {
|
api.h.Change(c)
|
||||||
api.h.Change(change.PolicyChange())
|
|
||||||
}
|
|
||||||
|
|
||||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetNodeRequest,
|
request *v1.GetNodeRequest,
|
||||||
) (*v1.GetNodeResponse, error) {
|
) (*v1.GetNodeResponse, error) {
|
||||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := node.Proto()
|
resp := node.Proto()
|
||||||
|
|
||||||
// Populate the online field based on
|
|
||||||
// currently connected nodes.
|
|
||||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
|
||||||
|
|
||||||
return &v1.GetNodeResponse{Node: resp}, nil
|
return &v1.GetNodeResponse{Node: resp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags(
|
|||||||
api.h.Change(nodeChange)
|
api.h.Change(nodeChange)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Caller().
|
||||||
|
Str("node", node.Hostname()).
|
||||||
Strs("tags", request.GetTags()).
|
Strs("tags", request.GetTags()).
|
||||||
Msg("Changing tags of node")
|
Msg("Changing tags of node")
|
||||||
|
|
||||||
@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.SetApprovedRoutesRequest,
|
request *v1.SetApprovedRoutesRequest,
|
||||||
) (*v1.SetApprovedRoutesResponse, error) {
|
) (*v1.SetApprovedRoutesResponse, error) {
|
||||||
var routes []netip.Prefix
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", request.GetNodeId()).
|
||||||
|
Strs("requestedRoutes", request.GetRoutes()).
|
||||||
|
Msg("gRPC SetApprovedRoutes called")
|
||||||
|
|
||||||
|
var newApproved []netip.Prefix
|
||||||
for _, route := range request.GetRoutes() {
|
for _, route := range request.GetRoutes() {
|
||||||
prefix, err := netip.ParsePrefix(route)
|
prefix, err := netip.ParsePrefix(route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
// If the prefix is an exit route, add both. The client expect both
|
// If the prefix is an exit route, add both. The client expect both
|
||||||
// to annotate the node as an exit node.
|
// to annotate the node as an exit node.
|
||||||
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
|
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
|
||||||
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||||
} else {
|
} else {
|
||||||
routes = append(routes, prefix)
|
newApproved = append(newApproved, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tsaddr.SortPrefixes(routes)
|
tsaddr.SortPrefixes(newApproved)
|
||||||
routes = slices.Compact(routes)
|
newApproved = slices.Compact(newApproved)
|
||||||
|
|
||||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
|
||||||
|
|
||||||
// Always propagate node changes from SetApprovedRoutes
|
// Always propagate node changes from SetApprovedRoutes
|
||||||
api.h.Change(nodeChange)
|
api.h.Change(nodeChange)
|
||||||
|
|
||||||
// If routes changed, propagate those changes too
|
|
||||||
if !routeChange.Empty() {
|
|
||||||
api.h.Change(routeChange)
|
|
||||||
}
|
|
||||||
|
|
||||||
proto := node.Proto()
|
proto := node.Proto()
|
||||||
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
|
// Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
|
||||||
|
// routes that are actively served from the node (per architectural requirement in types/node.go)
|
||||||
|
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
|
||||||
|
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Uint64("node.id", node.ID().Uint64()).
|
||||||
|
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||||
|
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
||||||
|
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
||||||
|
Msg("gRPC SetApprovedRoutes completed")
|
||||||
|
|
||||||
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
||||||
}
|
}
|
||||||
@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteNodeRequest,
|
request *v1.DeleteNodeRequest,
|
||||||
) (*v1.DeleteNodeResponse, error) {
|
) (*v1.DeleteNodeResponse, error) {
|
||||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeChange, err := api.h.state.DeleteNode(node)
|
nodeChange, err := api.h.state.DeleteNode(node)
|
||||||
@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode(
|
|||||||
api.h.Change(nodeChange)
|
api.h.Change(nodeChange)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Caller().
|
||||||
Time("expiry", *node.Expiry).
|
Str("node", node.Hostname()).
|
||||||
|
Time("expiry", *node.AsStruct().Expiry).
|
||||||
Msg("node expired")
|
Msg("node expired")
|
||||||
|
|
||||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||||
@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode(
|
|||||||
api.h.Change(nodeChange)
|
api.h.Change(nodeChange)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Caller().
|
||||||
|
Str("node", node.Hostname()).
|
||||||
Str("new_name", request.GetNewName()).
|
Str("new_name", request.GetNewName()).
|
||||||
Msg("node renamed")
|
Msg("node renamed")
|
||||||
|
|
||||||
@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes(
|
|||||||
// the filtering of nodes by user, vs nodes as a whole can
|
// the filtering of nodes by user, vs nodes as a whole can
|
||||||
// probably be done once.
|
// probably be done once.
|
||||||
// TODO(kradalby): This should be done in one tx.
|
// TODO(kradalby): This should be done in one tx.
|
||||||
|
|
||||||
IsConnected := api.h.mapBatcher.ConnectedMap()
|
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
response := nodesToProto(api.h.state, nodes)
|
||||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := api.h.state.ListNodes()
|
nodes := api.h.state.ListNodes()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(nodes, func(i, j int) bool {
|
response := nodesToProto(api.h.state, nodes)
|
||||||
return nodes[i].ID < nodes[j].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
|
||||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node {
|
||||||
response := make([]*v1.Node, len(nodes))
|
response := make([]*v1.Node, nodes.Len())
|
||||||
for index, node := range nodes {
|
for index, node := range nodes.All() {
|
||||||
resp := node.Proto()
|
resp := node.Proto()
|
||||||
|
|
||||||
// Populate the online field based on
|
|
||||||
// currently connected nodes.
|
|
||||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
|
||||||
resp.Online = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var tags []string
|
var tags []string
|
||||||
for _, tag := range node.RequestTags() {
|
for _, tag := range node.RequestTags() {
|
||||||
if state.NodeCanHaveTag(node.View(), tag) {
|
if state.NodeCanHaveTag(node, tag) {
|
||||||
tags = append(tags, tag)
|
tags = append(tags, tag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
|
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
|
||||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
|
|
||||||
|
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
|
||||||
response[index] = resp
|
response[index] = resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Slice(response, func(i, j int) bool {
|
||||||
|
return response[i].Id < response[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
// a scenario where they might be allowed if the server has no nodes
|
// a scenario where they might be allowed if the server has no nodes
|
||||||
// yet, but it should help for the general case and for hot reloading
|
// yet, but it should help for the general case and for hot reloading
|
||||||
// configurations.
|
// configurations.
|
||||||
nodes, err := api.h.state.ListNodes()
|
nodes := api.h.state.ListNodes()
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
_, err := api.h.state.SetPolicy([]byte(p))
|
||||||
}
|
|
||||||
changed, err := api.h.state.SetPolicy([]byte(p))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setting policy: %w", err)
|
return nil, fmt.Errorf("setting policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if nodes.Len() > 0 {
|
||||||
_, err = api.h.state.SSHPolicy(nodes[0].View())
|
_, err = api.h.state.SSHPolicy(nodes.At(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||||
}
|
}
|
||||||
@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only send update if the packet filter has changed.
|
// Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
|
||||||
if changed {
|
// This ensures that routes are re-evaluated for auto-approval in cases where routes
|
||||||
err = api.h.state.AutoApproveNodes()
|
// were manually disabled but could now be auto-approved with the current policy.
|
||||||
if err != nil {
|
cs, err := api.h.state.ReloadPolicy()
|
||||||
return nil, err
|
if err != nil {
|
||||||
}
|
return nil, fmt.Errorf("reloading policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
api.h.Change(change.PolicyChange())
|
if len(cs) > 0 {
|
||||||
|
api.h.Change(cs...)
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Caller().
|
||||||
|
Msg("No policy changes to distribute because ReloadPolicy returned empty changeset")
|
||||||
}
|
}
|
||||||
|
|
||||||
response := &v1.SetPolicyResponse{
|
response := &v1.SetPolicyResponse{
|
||||||
|
@ -94,13 +94,19 @@ func (h *Headscale) handleVerifyRequest(
 		return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
 	}

-	nodes, err := h.state.ListNodes()
-	if err != nil {
-		return fmt.Errorf("cannot list nodes: %w", err)
+	nodes := h.state.ListNodes()
+
+	// Check if any node has the requested NodeKey
+	var nodeKeyFound bool
+	for _, node := range nodes.All() {
+		if node.NodeKey() == derpAdmitClientRequest.NodePublic {
+			nodeKeyFound = true
+			break
+		}
 	}

 	resp := &tailcfg.DERPAdmitClientResponse{
-		Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
+		Allow: nodeKeyFound,
 	}

 	return json.NewEncoder(writer).Encode(resp)
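The verify handler now walks the in-memory node views rather than querying the database. A self-contained sketch of the same lookup pattern over an immutable snapshot; the nodeView type below is a simplified stand-in for the NodeStore's read-only views.

package main

import "fmt"

// nodeView is a simplified stand-in for a read-only node view: accessors only.
type nodeView struct{ nodeKey string }

func (v nodeView) NodeKey() string { return v.nodeKey }

// containsNodeKey reports whether any view in the snapshot carries the key.
func containsNodeKey(nodes []nodeView, key string) bool {
	for _, n := range nodes {
		if n.NodeKey() == key {
			return true
		}
	}
	return false
}

func main() {
	snapshot := []nodeView{{"nodekey:aa"}, {"nodekey:bb"}}
	fmt.Println(containsNodeKey(snapshot, "nodekey:bb")) // true
	fmt.Println(containsNodeKey(snapshot, "nodekey:cc")) // false
}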
@ -1,6 +1,7 @@
 package mapper

 import (
+	"errors"
 	"fmt"
 	"time"

@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
 type Batcher interface {
 	Start()
 	Close()
-	AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
-	RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
+	AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
+	RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
 	IsConnected(id types.NodeID) bool
 	ConnectedMap() *xsync.Map[types.NodeID, bool]
 	AddWork(c change.ChangeSet)
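A caller-side sketch of the narrowed interface: the isRouter flag is gone, and RemoveNode now reports whether the node still has other live connections. The types and the servePoll function below are trimmed stand-ins to show the intended call pattern, not the real poll handler.

package main

import "fmt"

type nodeID uint64
type mapResponse struct{}

// batcher is a trimmed stand-in for the mapper.Batcher interface above.
type batcher interface {
	AddNode(id nodeID, c chan<- *mapResponse, capVer int) error
	RemoveNode(id nodeID, c chan<- *mapResponse) bool
}

// servePoll shows how a long-poll handler would register and tear down a connection.
func servePoll(b batcher, id nodeID) error {
	ch := make(chan *mapResponse, 16)
	if err := b.AddNode(id, ch, 100); err != nil {
		return err
	}
	defer func() {
		// The return value says whether another connection for the same node
		// is still active, so "offline" handling only happens when it is false.
		if stillConnected := b.RemoveNode(id, ch); !stillConnected {
			fmt.Println("node fully disconnected:", id)
		}
	}()

	// ... stream responses from ch to the client ...
	return nil
}

func main() {}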
@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
 // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
 func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
 	if nc == nil {
-		return fmt.Errorf("nodeConnection is nil")
+		return errors.New("nodeConnection is nil")
 	}

 	nodeID := nc.nodeID()
@ -21,8 +21,7 @@ type LockFreeBatcher struct {
 	mapper  *mapper
 	workers int

-	// Lock-free concurrent maps
-	nodes     *xsync.Map[types.NodeID, *nodeConn]
+	nodes     *xsync.Map[types.NodeID, *multiChannelNodeConn]
 	connected *xsync.Map[types.NodeID, *time.Time]

 	// Work queue channel
@ -32,7 +31,6 @@ type LockFreeBatcher struct {

 	// Batching state
 	pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
-	batchMutex     sync.RWMutex

 	// Metrics
 	totalNodes atomic.Int64
@ -45,65 +43,63 @@ type LockFreeBatcher struct {
 // AddNode registers a new node connection with the batcher and sends an initial map response.
 // It creates or updates the node's connection data, validates the initial map generation,
 // and notifies other nodes that this node has come online.
-// TODO(kradalby): See if we can move the isRouter argument somewhere else.
-func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
-	// First validate that we can generate initial map before doing anything else
-	fullSelfChange := change.FullSelf(id)
+func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
+	addNodeStart := time.Now()

-	// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
-	// This currently means that the goroutine for the node connection will do the processing
-	// which means that we might have uncontrolled concurrency.
-	// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
-	// it to be processed in a more controlled manner.
-	initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
-	if err != nil {
-		return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
+	// Generate connection ID
+	connID := generateConnectionID()
+
+	// Create new connection entry
+	now := time.Now()
+	newEntry := &connectionEntry{
+		id:      connID,
+		c:       c,
+		version: version,
+		created: now,
 	}

 	// Only after validation succeeds, create or update node connection
 	newConn := newNodeConn(id, c, version, b.mapper)

-	var conn *nodeConn
-	if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
-		// Update existing connection
-		existing.updateConnection(c, version)
-		conn = existing
-	} else {
+	if !loaded {
 		b.totalNodes.Add(1)
 		conn = newConn
 	}

-	// Mark as connected only after validation succeeds
 	b.connected.Store(id, nil) // nil = connected

-	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
+	if err != nil {
+		log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
+		nodeConn.removeConnectionByChannel(c)
+		return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
+	}

-	// Send the validated initial map
-	if initialMap != nil {
-		if err := conn.send(initialMap); err != nil {
-			// Clean up the connection state on send failure
-			b.nodes.Delete(id)
-			b.connected.Delete(id)
-			return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
-		}
-
-		// Notify other nodes that this node came online
-		b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
+	// Use a blocking send with timeout for initial map since the channel should be ready
+	// and we want to avoid the race condition where the receiver isn't ready yet
+	select {
+	case c <- initialMap:
+		// Success
+	case <-time.After(5 * time.Second):
+		log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
+		log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
+			Msg("Initial map send timed out because channel was blocked or receiver not ready")
+		nodeConn.removeConnectionByChannel(c)
+		return fmt.Errorf("failed to send initial map to node %d: timeout", id)
 	}

 	return nil
 }

 // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
-// It validates the connection channel matches the current one, closes the connection,
-// and notifies other nodes that this node has gone offline.
-func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
-	// Check if this is the current connection and mark it as closed
-	if existing, ok := b.nodes.Load(id); ok {
-		if !existing.matchesChannel(c) {
-			log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
-			return // Not the current connection, not an error
+// It validates the connection channel matches one of the current connections, closes that specific connection,
+// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
+// Reports if the node still has active connections after removal.
+func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
+	nodeConn, exists := b.nodes.Load(id)
+	if !exists {
+		log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
+		return false
 	}

 	// Mark the connection as closed to prevent further sends
 	if connData := existing.connData.Load(); connData != nil {
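The initial map is now delivered with a bounded blocking send. Below is a minimal, standalone version of that select-with-timeout pattern; the 5-second limit mirrors the diff but is otherwise arbitrary, and the helper name is illustrative.

package main

import (
	"errors"
	"fmt"
	"time"
)

// sendWithTimeout blocks until the receiver takes the value or the deadline passes.
func sendWithTimeout[T any](ch chan<- T, v T, d time.Duration) error {
	select {
	case ch <- v:
		return nil
	case <-time.After(d):
		return errors.New("send timed out: receiver not ready or channel blocked")
	}
}

func main() {
	ch := make(chan string) // unbuffered, so the send needs a reader
	go func() { fmt.Println("received:", <-ch) }()

	if err := sendWithTimeout(ch, "initial map", 5*time.Second); err != nil {
		fmt.Println("error:", err)
	}
	time.Sleep(10 * time.Millisecond) // let the goroutine print before exit
}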
@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
 		}
 	}

-	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
+	// Check if node has any remaining active connections
+	if nodeConn.hasActiveConnections() {
+		log.Debug().Caller().Uint64("node.id", id.Uint64()).
+			Int("active.connections", nodeConn.getActiveConnectionCount()).
+			Msg("Node connection removed but keeping online because other connections remain")
+		return true // Node still has active connections
+	}

 	// Remove node and mark disconnected atomically
 	b.nodes.Delete(id)
 	b.connected.Store(id, ptr.To(time.Now()))
 	b.totalNodes.Add(-1)

-	// Notify other nodes that this node went offline
-	b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
+	return false
 }

 // AddWork queues a change to be processed by the batcher.
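RemoveNode now keeps the node entry alive while other channels remain attached. A compact sketch of that per-node multi-connection bookkeeping follows; multiConn is a simplified analogue of the diff's multiChannelNodeConn, not the real struct.

package main

import (
	"fmt"
	"sync"
)

type mapResponse struct{}

// multiConn tracks every open channel for one node.
type multiConn struct {
	mu    sync.Mutex
	conns []chan<- *mapResponse
}

func (m *multiConn) add(c chan<- *mapResponse) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.conns = append(m.conns, c)
}

// removeByChannel drops one specific connection and reports whether any remain.
func (m *multiConn) removeByChannel(c chan<- *mapResponse) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	for i, existing := range m.conns {
		if existing == c {
			m.conns = append(m.conns[:i], m.conns[i+1:]...)
			break
		}
	}
	return len(m.conns) > 0
}

func (m *multiConn) hasActiveConnections() bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	return len(m.conns) > 0
}

func main() {
	var mc multiConn
	a, b := make(chan *mapResponse, 1), make(chan *mapResponse, 1)
	mc.add(a)
	mc.add(b)
	fmt.Println(mc.removeByChannel(a))     // true: one connection left
	fmt.Println(mc.hasActiveConnections()) // true
	fmt.Println(mc.removeByChannel(b))     // false: node fully disconnected
}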
@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
 				return
 			}

-			duration := time.Since(startTime)
-			if duration > 100*time.Millisecond {
-				log.Warn().
-					Int("workerID", workerID).
-					Uint64("node.id", w.nodeID.Uint64()).
-					Str("change", w.c.Change.String()).
-					Dur("duration", duration).
-					Msg("slow synchronous work processing")
-			}
 			continue
 		}

@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
 		// that should be processed and sent to the node instead of
 		// returned to the caller.
 		if nc, exists := b.nodes.Load(w.nodeID); exists {
-			// Check if this connection is still active before processing
-			if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
-				log.Debug().
-					Int("workerID", workerID).
-					Uint64("node.id", w.nodeID.Uint64()).
-					Str("change", w.c.Change.String()).
-					Msg("skipping work for closed connection")
-				continue
-			}
+			// Apply change to node - this will handle offline nodes gracefully
+			// and queue work for when they reconnect

 			err := nc.change(w.c)
 			if err != nil {
 				b.workErrors.Add(1)
@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) {
 					Str("change", w.c.Change.String()).
 					Msg("failed to apply change")
 			}
-		} else {
-			log.Debug().
-				Int("workerID", workerID).
-				Uint64("node.id", w.nodeID.Uint64()).
-				Str("change", w.c.Change.String()).
-				Msg("node not found for asynchronous work - node may have disconnected")
 		}

-		duration := time.Since(startTime)
-		if duration > 100*time.Millisecond {
-			log.Warn().
-				Int("workerID", workerID).
-				Uint64("node.id", w.nodeID.Uint64()).
-				Str("change", w.c.Change.String()).
-				Dur("duration", duration).
-				Msg("slow asynchronous work processing")
-		}
-
 		case <-b.ctx.Done():
 			return
 		}
 	}
 }

-func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
-	// For critical changes that need immediate processing, send directly
-	if b.shouldProcessImmediately(c) {
-		if c.SelfUpdateOnly {
-			b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
-			return
-		}
-		b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
-			if c.NodeID == nodeID && !c.AlsoSelf() {
-				return true
-			}
-			b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
-			return true
-		})
-		return
-	}
-
-	// For non-critical changes, add to batch
-	b.addToBatch(c)
+func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
+	b.addToBatch(c...)
 }

-// queueWork safely queues work
+// queueWork safely queues work.
 func (b *LockFreeBatcher) queueWork(w work) {
 	b.workQueuedCount.Add(1)

@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) {
 	}
 }

-// shouldProcessImmediately determines if a change should bypass batching
-func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
-	// Process these changes immediately to avoid delaying critical functionality
-	switch c.Change {
-	case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
-		return true
-	default:
-		return false
+// addToBatch adds a change to the pending batch.
+func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
+	// Short circuit if any of the changes is a full update, which
+	// means we can skip sending individual changes.
+	if change.HasFull(c) {
+		b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
+			b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
+
+			return true
+		})
+
+		return
 	}
 }

-// addToBatch adds a change to the pending batch
-func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
-	b.batchMutex.Lock()
-	defer b.batchMutex.Unlock()

-	if c.SelfUpdateOnly {
-		changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
-		changes = append(changes, c)
-		b.pendingChanges.Store(c.NodeID, changes)
 		return
 	}

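addToBatch now collapses a batch to a single full update as soon as any queued change is a full change. A standalone sketch of that coalescing rule; the change kinds and helper names here are illustrative, not the project's change package.

package main

import "fmt"

type changeKind int

const (
	changeDERP changeKind = iota
	changeNode
	changeFull
)

type changeSet struct{ kind changeKind }

func hasFull(cs []changeSet) bool {
	for _, c := range cs {
		if c.kind == changeFull {
			return true
		}
	}
	return false
}

// coalesce returns the per-node pending list: either a single full update or
// the individual changes, mirroring the short-circuit in addToBatch.
func coalesce(pending []changeSet, incoming ...changeSet) []changeSet {
	if hasFull(incoming) {
		return []changeSet{{kind: changeFull}}
	}
	return append(pending, incoming...)
}

func main() {
	pending := []changeSet{{kind: changeDERP}}
	pending = coalesce(pending, changeSet{kind: changeNode})
	fmt.Println(len(pending)) // 2: still individual changes
	pending = coalesce(pending, changeSet{kind: changeFull})
	fmt.Println(len(pending)) // 1: collapsed to a single full update
}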
@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
 		changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
 		changes = append(changes, c)
 		b.pendingChanges.Store(nodeID, changes)

 		return true
 	})
 }

-// processBatchedChanges processes all pending batched changes
+// processBatchedChanges processes all pending batched changes.
 func (b *LockFreeBatcher) processBatchedChanges() {
-	b.batchMutex.Lock()
-	defer b.batchMutex.Unlock()
-
 	if b.pendingChanges == nil {
 		return
 	}
@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() {

 		// Clear the pending changes for this node
 		b.pendingChanges.Delete(nodeID)

 		return true
 	})
 }

 // IsConnected is lock-free read.
 func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
-	if val, ok := b.connected.Load(id); ok {
-		// nil means connected
-		return val == nil
+	// First check if we have active connections for this node
+	if nodeConn, exists := b.nodes.Load(id); exists {
+		if nodeConn.hasActiveConnections() {
+			return true
+		}
 	}

+	// Check disconnected timestamp with grace period
+	val, ok := b.connected.Load(id)
+	if !ok {
+		return false
+	}
+
+	// nil means connected
+	if val == nil {
+		return true
+	}
+
 	return false
 }

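IsConnected now falls back to the disconnect timestamp, where a nil entry means connected and a non-nil timestamp records when the node went away; the grace period mentioned in the tests would compare that timestamp against the current time. A small sketch of that convention follows; the 10-second window is an assumption for illustration, not the project's actual value.

package main

import (
	"fmt"
	"time"
)

// isConnected follows the diff's convention: nil means connected,
// a non-nil pointer records when the node disconnected.
func isConnected(disconnectedAt *time.Time, grace time.Duration) bool {
	if disconnectedAt == nil {
		return true
	}
	// Within the grace period the node is still treated as reachable,
	// which keeps DNS and peer maps stable across quick reconnects.
	return time.Since(*disconnectedAt) < grace
}

func main() {
	now := time.Now()
	recent := now.Add(-2 * time.Second)
	old := now.Add(-1 * time.Minute)

	fmt.Println(isConnected(nil, 10*time.Second))     // true: still connected
	fmt.Println(isConnected(&recent, 10*time.Second)) // true: inside grace period
	fmt.Println(isConnected(&old, 10*time.Second))    // false: grace expired
}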
@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
 func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
 	ret := xsync.NewMap[types.NodeID, bool]()

+	// First, add all nodes with active connections
+	b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
+		if nodeConn.hasActiveConnections() {
+			ret.Store(id, true)
+		}
+		return true
+	})
+
+	// Then add all entries from the connected map
 	b.connected.Range(func(id types.NodeID, val *time.Time) bool {
-		// nil means connected
-		ret.Store(id, val == nil)
+		// Only add if not already added as connected above
+		if _, exists := ret.Load(id); !exists {
+			if val == nil {
+				// nil means connected
+				ret.Store(id, true)
+			} else {
+				// timestamp means disconnected
+				ret.Store(id, false)
+			}
+		}
 		return true
 	})

@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
 		return fmt.Errorf("node %d: connection closed", nc.id)
 	}

-	// TODO(kradalby): We might need some sort of timeout here if the client is not reading
-	// the channel. That might mean that we are sending to a node that has gone offline, but
-	// the channel is still open.
-	connData.c <- data
-	nc.updateCount.Add(1)
-	return nil
+	// Add all entries from the connected map to capture both connected and disconnected nodes
+	b.connected.Range(func(id types.NodeID, val *time.Time) bool {
+		// Only add if not already processed above
+		if _, exists := result[id]; !exists {
+			// Use immediate connection status for debug (no grace period)
+			connected := (val == nil) // nil means connected, timestamp means disconnected
+			result[id] = DebugNodeInfo{
+				Connected:         connected,
+				ActiveConnections: 0,
+			}
+		}
+		return true
+	})
+
+	return result
 }

 func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
@ -27,6 +27,60 @@ type batcherTestCase struct {
|
|||||||
fn batcherFunc
|
fn batcherFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||||
|
// that would normally be sent by poll.go in production.
|
||||||
|
type testBatcherWrapper struct {
|
||||||
|
Batcher
|
||||||
|
state *state.State
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||||
|
// Mark node as online in state before AddNode to match production behavior
|
||||||
|
// This ensures the NodeStore has correct online status for change processing
|
||||||
|
if t.state != nil {
|
||||||
|
// Use Connect to properly mark node online in NodeStore but don't send its changes
|
||||||
|
_ = t.state.Connect(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First add the node to the real batcher
|
||||||
|
err := t.Batcher.AddNode(id, c, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the online notification that poll.go would normally send
|
||||||
|
// This ensures other nodes get notified about this node coming online
|
||||||
|
t.AddWork(change.NodeOnline(id))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||||
|
// Mark node as offline in state BEFORE removing from batcher
|
||||||
|
// This ensures the NodeStore has correct offline status when the change is processed
|
||||||
|
if t.state != nil {
|
||||||
|
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
|
||||||
|
_, _ = t.state.Disconnect(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the offline notification that poll.go would normally send
|
||||||
|
// Do this BEFORE removing from batcher so the change can be processed
|
||||||
|
t.AddWork(change.NodeOffline(id))
|
||||||
|
|
||||||
|
// Finally remove from the real batcher
|
||||||
|
removed := t.Batcher.RemoveNode(id, c)
|
||||||
|
if !removed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||||
|
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
|
||||||
|
return &testBatcherWrapper{Batcher: b, state: state}
|
||||||
|
}
|
||||||
|
|
||||||
// allBatcherFunctions contains all batcher implementations to test.
|
// allBatcherFunctions contains all batcher implementations to test.
|
||||||
var allBatcherFunctions = []batcherTestCase{
|
var allBatcherFunctions = []batcherTestCase{
|
||||||
{"LockFree", NewBatcherAndMapper},
|
{"LockFree", NewBatcherAndMapper},
|
||||||
@ -183,8 +237,8 @@ func setupBatcherWithTestData(
|
|||||||
"acls": [
|
"acls": [
|
||||||
{
|
{
|
||||||
"action": "accept",
|
"action": "accept",
|
||||||
"users": ["*"],
|
"src": ["*"],
|
||||||
"ports": ["*:*"]
|
"dst": ["*:*"]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`
|
}`
|
||||||
@ -194,8 +248,8 @@ func setupBatcherWithTestData(
|
|||||||
t.Fatalf("Failed to set allow-all policy: %v", err)
|
t.Fatalf("Failed to set allow-all policy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create batcher with the state
|
// Create batcher with the state and wrap it for testing
|
||||||
batcher := bf(cfg, state)
|
batcher := wrapBatcherForTest(bf(cfg, state), state)
|
||||||
batcher.Start()
|
batcher.Start()
|
||||||
|
|
||||||
testData := &TestData{
|
testData := &TestData{
|
||||||
@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
|||||||
testNode.start()
|
testNode.start()
|
||||||
|
|
||||||
// Connect the node to the batcher
|
// Connect the node to the batcher
|
||||||
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||||
time.Sleep(100 * time.Millisecond) // Let connection settle
|
time.Sleep(100 * time.Millisecond) // Let connection settle
|
||||||
|
|
||||||
// Generate some work
|
// Generate some work
|
||||||
@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||||
for i := range allNodes {
|
for i := range allNodes {
|
||||||
node := &allNodes[i]
|
node := &allNodes[i]
|
||||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Issue full update after each join to ensure connectivity
|
// Issue full update after each join to ensure connectivity
|
||||||
batcher.AddWork(change.FullSet)
|
batcher.AddWork(change.FullSet)
|
||||||
@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||||||
// Disconnect all nodes
|
// Disconnect all nodes
|
||||||
for i := range allNodes {
|
for i := range allNodes {
|
||||||
node := &allNodes[i]
|
node := &allNodes[i]
|
||||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
batcher.RemoveNode(node.n.ID, node.ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Give time for final updates to process
|
// Give time for final updates to process
|
||||||
@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
tn2 := testData.Nodes[1]
|
tn2 := testData.Nodes[1]
|
||||||
|
|
||||||
// Test AddNode with real node ID
|
// Test AddNode with real node ID
|
||||||
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
|
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||||
|
|
||||||
if !batcher.IsConnected(tn.n.ID) {
|
if !batcher.IsConnected(tn.n.ID) {
|
||||||
t.Error("Node should be connected after AddNode")
|
t.Error("Node should be connected after AddNode")
|
||||||
}
|
}
|
||||||
@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||||
|
|
||||||
// Add the second node and verify update message
|
// Add the second node and verify update message
|
||||||
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
|
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||||
|
|
||||||
// First node should get an update that second node has connected.
|
// First node should get an update that second node has connected.
|
||||||
select {
|
select {
|
||||||
case data := <-tn.ch:
|
case data := <-tn.ch:
|
||||||
assertOnlineMapResponse(t, data, true)
|
assertOnlineMapResponse(t, data, true)
|
||||||
case <-time.After(200 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
t.Error("Did not receive expected Online response update")
|
t.Error("Did not receive expected Online response update")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect the second node
|
// Disconnect the second node
|
||||||
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
|
batcher.RemoveNode(tn2.n.ID, tn2.ch)
|
||||||
assert.False(t, batcher.IsConnected(tn2.n.ID))
|
// Note: IsConnected may return true during grace period for DNS resolution
|
||||||
|
|
||||||
// First node should get update that second has disconnected.
|
// First node should get update that second has disconnected.
|
||||||
select {
|
select {
|
||||||
case data := <-tn.ch:
|
case data := <-tn.ch:
|
||||||
assertOnlineMapResponse(t, data, false)
|
assertOnlineMapResponse(t, data, false)
|
||||||
case <-time.After(200 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
t.Error("Did not receive expected Online response update")
|
t.Error("Did not receive expected Online response update")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// Test RemoveNode
|
// Test RemoveNode
|
||||||
batcher.RemoveNode(tn.n.ID, tn.ch, false)
|
batcher.RemoveNode(tn.n.ID, tn.ch)
|
||||||
if batcher.IsConnected(tn.n.ID) {
|
// Note: IsConnected may return true during grace period for DNS resolution
|
||||||
t.Error("Node should be disconnected after RemoveNode")
|
// The node is actually removed from active connections but grace period allows DNS lookups
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
|||||||
testNodes := testData.Nodes
|
testNodes := testData.Nodes
|
||||||
|
|
||||||
ch := make(chan *tailcfg.MapResponse, 10)
|
ch := make(chan *tailcfg.MapResponse, 10)
|
||||||
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Track update content for validation
|
// Track update content for validation
|
||||||
var receivedUpdates []*tailcfg.MapResponse
|
var receivedUpdates []*tailcfg.MapResponse
|
||||||
@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
|
|
||||||
|
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Add real work during connection chaos
|
// Add real work during connection chaos
|
||||||
@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
time.Sleep(1 * time.Microsecond)
|
time.Sleep(1 * time.Microsecond)
|
||||||
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Remove second connection
|
// Remove second connection
|
||||||
@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
time.Sleep(2 * time.Microsecond)
|
time.Sleep(2 * time.Microsecond)
|
||||||
batcher.RemoveNode(testNode.n.ID, ch2, false)
|
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
ch := make(chan *tailcfg.MapResponse, 5)
|
ch := make(chan *tailcfg.MapResponse, 5)
|
||||||
|
|
||||||
// Add node and immediately queue real work
|
// Add node and immediately queue real work
|
||||||
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
batcher.AddWork(change.DERPSet)
|
batcher.AddWork(change.DERPSet)
|
||||||
|
|
||||||
// Consumer goroutine to validate data and detect channel issues
|
// Consumer goroutine to validate data and detect channel issues
|
||||||
@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
|
|
||||||
// Rapid removal creates race between worker and removal
|
// Rapid removal creates race between worker and removal
|
||||||
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
|
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
|
||||||
batcher.RemoveNode(testNode.n.ID, ch, false)
|
batcher.RemoveNode(testNode.n.ID, ch)
|
||||||
|
|
||||||
// Give workers time to process and close channels
|
// Give workers time to process and close channels
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
for _, node := range stableNodes {
|
for _, node := range stableNodes {
|
||||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||||
stableChannels[node.n.ID] = ch
|
stableChannels[node.n.ID] = ch
|
||||||
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Monitor updates for each stable client
|
// Monitor updates for each stable client
|
||||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||||
@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
churningChannelsMutex.Lock()
|
churningChannelsMutex.Lock()
|
||||||
churningChannels[nodeID] = ch
|
churningChannels[nodeID] = ch
|
||||||
churningChannelsMutex.Unlock()
|
churningChannelsMutex.Unlock()
|
||||||
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Consume updates to prevent blocking
|
// Consume updates to prevent blocking
|
||||||
go func() {
|
go func() {
|
||||||
@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
ch, exists := churningChannels[nodeID]
|
ch, exists := churningChannels[nodeID]
|
||||||
churningChannelsMutex.Unlock()
|
churningChannelsMutex.Unlock()
|
||||||
if exists {
|
if exists {
|
||||||
batcher.RemoveNode(nodeID, ch, false)
|
batcher.RemoveNode(nodeID, ch)
|
||||||
}
|
}
|
||||||
}(node.n.ID)
|
}(node.n.ID)
|
||||||
}
|
}
|
||||||
@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
var connectedNodesMutex sync.RWMutex
|
var connectedNodesMutex sync.RWMutex
|
||||||
for i := range testNodes {
|
for i := range testNodes {
|
||||||
node := &testNodes[i]
|
node := &testNodes[i]
|
||||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
connectedNodesMutex.Lock()
|
connectedNodesMutex.Lock()
|
||||||
connectedNodes[node.n.ID] = true
|
connectedNodes[node.n.ID] = true
|
||||||
connectedNodesMutex.Unlock()
|
connectedNodesMutex.Unlock()
|
||||||
@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
connectedNodesMutex.RUnlock()
|
connectedNodesMutex.RUnlock()
|
||||||
|
|
||||||
if isConnected {
|
if isConnected {
|
||||||
batcher.RemoveNode(nodeID, channel, false)
|
batcher.RemoveNode(nodeID, channel)
|
||||||
connectedNodesMutex.Lock()
|
connectedNodesMutex.Lock()
|
||||||
connectedNodes[nodeID] = false
|
connectedNodes[nodeID] = false
|
||||||
connectedNodesMutex.Unlock()
|
connectedNodesMutex.Unlock()
|
||||||
@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
// Now disconnect all nodes from batcher to stop new updates
|
// Now disconnect all nodes from batcher to stop new updates
|
||||||
for i := range testNodes {
|
for i := range testNodes {
|
||||||
node := &testNodes[i]
|
node := &testNodes[i]
|
||||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
batcher.RemoveNode(node.n.ID, node.ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Give time for enhanced tracking goroutines to process any remaining data in channels
|
// Give time for enhanced tracking goroutines to process any remaining data in channels
|
||||||
@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||||||
|
|
||||||
// Connect nodes one at a time to avoid overwhelming the work queue
|
// Connect nodes one at a time to avoid overwhelming the work queue
|
||||||
for i, node := range allNodes {
|
for i, node := range allNodes {
|
||||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||||
// Small delay between connections to allow NodeCameOnline processing
|
// Small delay between connections to allow NodeCameOnline processing
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||||||
|
|
||||||
// Check how many peers each node should see
|
// Check how many peers each node should see
|
||||||
for i, node := range allNodes {
|
for i, node := range allNodes {
|
||||||
peers, err := testData.State.ListPeers(node.n.ID)
|
peers := testData.State.ListPeers(node.n.ID)
|
||||||
if err != nil {
|
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||||
t.Errorf("Error listing peers for node %d: %v", i, err)
|
|
||||||
} else {
|
|
||||||
t.Logf("Node %d should see %d peers from state", i, len(peers))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a full update - this should generate full peer lists
|
// Send a full update - this should generate full peer lists
|
||||||
@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||||||
foundFullUpdate := false
|
foundFullUpdate := false
|
||||||
|
|
||||||
// Read all available updates for each node
|
// Read all available updates for each node
|
||||||
for i := range len(allNodes) {
|
for i := range allNodes {
|
||||||
nodeUpdates := 0
|
nodeUpdates := 0
|
||||||
t.Logf("Reading updates for node %d:", i)
|
t.Logf("Reading updates for node %d:", i)
|
||||||
|
|
||||||
@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||||
|
|
||||||
// Connect first node
|
time.Sleep(100 * time.Millisecond) // Let connections settle
|
||||||
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
|
|
||||||
t.Logf("Connected node %d", nodes[0].n.ID)
|
|
||||||
|
|
||||||
// Wait for initial NodeCameOnline to be processed
|
// Wait for initial NodeCameOnline to be processed
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
|||||||
t.Errorf("ERROR: Received unknown update type!")
|
t.Errorf("ERROR: Received unknown update type!")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if there should be peers available
|
batcher := testData.Batcher
|
||||||
peers, err := testData.State.ListPeers(nodes[0].n.ID)
|
node1 := testData.Nodes[0]
|
||||||
if err != nil {
|
node2 := testData.Nodes[1]
|
||||||
t.Errorf("Error getting peers from state: %v", err)
|
|
||||||
} else {
|
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||||
t.Logf("State shows %d peers available for this node", len(peers))
|
|
||||||
if len(peers) > 0 && len(data.Peers) == 0 {
|
// Phase 1: Connect first node with initial connection
|
||||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
|
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||||
|
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add node1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect second node for comparison
|
||||||
|
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add node2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||||
|
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||||
|
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||||
|
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Phase 3: Add third connection for node1
|
||||||
|
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||||
|
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||||
|
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Phase 4: Verify debug status shows correct connection count
|
||||||
|
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||||
|
if debugBatcher, ok := batcher.(interface {
|
||||||
|
Debug() map[types.NodeID]any
|
||||||
|
}); ok {
|
||||||
|
debugInfo := debugBatcher.Debug()
|
||||||
|
|
||||||
|
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||||
|
t.Logf("Node1 debug info: %+v", info)
|
||||||
|
if infoMap, ok := info.(map[string]any); ok {
|
||||||
|
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||||
|
if activeConnections != 3 {
|
||||||
|
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
|
||||||
|
} else {
|
||||||
|
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||||
|
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if info, exists := debugInfo[node2.n.ID]; exists {
|
||||||
|
if infoMap, ok := info.(map[string]any); ok {
|
||||||
|
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||||
|
if activeConnections != 1 {
|
||||||
|
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5: Send update and verify ALL connections receive it
|
||||||
|
t.Logf("Phase 5: Testing update distribution to all connections...")
|
||||||
|
|
||||||
|
// Clear any existing updates from all channels
|
||||||
|
clearChannel := func(ch chan *tailcfg.MapResponse) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
// drain
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clearChannel(node1.ch)
|
||||||
|
clearChannel(secondChannel)
|
||||||
|
clearChannel(thirdChannel)
|
||||||
|
clearChannel(node2.ch)
|
||||||
|
|
||||||
|
// Send a change notification from node2 (so node1 should receive it on all connections)
|
||||||
|
testChangeSet := change.ChangeSet{
|
||||||
|
NodeID: node2.n.ID,
|
||||||
|
Change: change.NodeNewOrUpdate,
|
||||||
|
SelfUpdateOnly: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
batcher.AddWork(testChangeSet)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond) // Let updates propagate
|
||||||
|
|
||||||
|
// Verify all three connections for node1 receive the update
|
||||||
|
connection1Received := false
|
||||||
|
connection2Received := false
|
||||||
|
connection3Received := false
|
||||||
|
|
||||||
|
select {
|
||||||
|
case mapResp := <-node1.ch:
|
||||||
|
connection1Received = (mapResp != nil)
|
||||||
|
t.Logf("Node1 connection 1 received update: %t", connection1Received)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Errorf("Node1 connection 1 did not receive update")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case mapResp := <-secondChannel:
|
||||||
|
connection2Received = (mapResp != nil)
|
||||||
|
t.Logf("Node1 connection 2 received update: %t", connection2Received)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Errorf("Node1 connection 2 did not receive update")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case mapResp := <-thirdChannel:
|
||||||
|
connection3Received = (mapResp != nil)
|
||||||
|
t.Logf("Node1 connection 3 received update: %t", connection3Received)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Errorf("Node1 connection 3 did not receive update")
|
||||||
|
}
|
||||||
|
|
||||||
|
if connection1Received && connection2Received && connection3Received {
|
||||||
|
t.Logf("SUCCESS: All three connections for node1 received the update")
|
||||||
|
} else {
|
||||||
|
t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t",
|
||||||
|
connection1Received, connection2Received, connection3Received)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 6: Test connection removal and verify remaining connections still work
|
||||||
|
t.Logf("Phase 6: Testing connection removal...")
|
||||||
|
|
||||||
|
// Remove the second connection
|
||||||
|
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
|
||||||
|
if !removed {
|
||||||
|
t.Errorf("Failed to remove second connection for node1")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify debug status shows 2 connections now
|
||||||
|
if debugBatcher, ok := batcher.(interface {
|
||||||
|
Debug() map[types.NodeID]any
|
||||||
|
}); ok {
|
||||||
|
debugInfo := debugBatcher.Debug()
|
||||||
|
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||||
|
if infoMap, ok := info.(map[string]any); ok {
|
||||||
|
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||||
|
if activeConnections != 2 {
|
||||||
|
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
|
||||||
|
} else {
|
||||||
|
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package mapper
|
package mapper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
"tailscale.com/util/multierr"
|
"tailscale.com/util/multierr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
|
||||||
type MapResponseBuilder struct {
|
type MapResponseBuilder struct {
|
||||||
resp *tailcfg.MapResponse
|
resp *tailcfg.MapResponse
|
||||||
mapper *mapper
|
mapper *mapper
|
||||||
@ -21,7 +22,17 @@ type MapResponseBuilder struct {
|
|||||||
errs []error
|
errs []error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
type debugType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
fullResponseDebug debugType = "full"
|
||||||
|
patchResponseDebug debugType = "patch"
|
||||||
|
removeResponseDebug debugType = "remove"
|
||||||
|
changeResponseDebug debugType = "change"
|
||||||
|
derpResponseDebug debugType = "derp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return &MapResponseBuilder{
|
return &MapResponseBuilder{
|
||||||
@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// addError adds an error to the builder's error list
|
// addError adds an error to the builder's error list.
|
||||||
func (b *MapResponseBuilder) addError(err error) {
|
func (b *MapResponseBuilder) addError(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.errs = append(b.errs, err)
|
b.errs = append(b.errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasErrors returns true if the builder has accumulated any errors
|
// hasErrors returns true if the builder has accumulated any errors.
|
||||||
func (b *MapResponseBuilder) hasErrors() bool {
|
func (b *MapResponseBuilder) hasErrors() bool {
|
||||||
return len(b.errs) > 0
|
return len(b.errs) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCapabilityVersion sets the capability version for the response
|
// WithCapabilityVersion sets the capability version for the response.
|
||||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||||
b.capVer = capVer
|
b.capVer = capVer
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSelfNode adds the requesting node to the response
|
// WithSelfNode adds the requesting node to the response.
|
||||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
b.addError(err)
|
b.addError(errors.New("node not found"))
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always use batcher's view of online status for self node
|
||||||
|
// The batcher respects grace periods for logout scenarios
|
||||||
|
node := nodeView.AsStruct()
|
||||||
|
// if b.mapper.batcher != nil {
|
||||||
|
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||||
|
// }
|
||||||
|
|
||||||
_, matchers := b.mapper.state.Filter()
|
_, matchers := b.mapper.state.Filter()
|
||||||
tailnode, err := tailNode(
|
tailnode, err := tailNode(
|
||||||
node.View(), b.capVer, b.mapper.state,
|
node.View(), b.capVer, b.mapper.state,
|
||||||
@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b.resp.Node = tailnode
|
b.resp.Node = tailnode
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDERPMap adds the DERP map to the response
|
func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder {
|
||||||
|
if debugDumpMapResponsePath != "" {
|
||||||
|
b.debugType = t
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDERPMap adds the DERP map to the response.
|
||||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||||
b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct()
|
b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct()
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDomain adds the domain configuration
|
// WithDomain adds the domain configuration.
|
||||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||||
b.resp.Domain = b.mapper.cfg.Domain()
|
b.resp.Domain = b.mapper.cfg.Domain()
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCollectServicesDisabled sets the collect services flag to false
|
// WithCollectServicesDisabled sets the collect services flag to false.
|
||||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||||
b.resp.CollectServices.Set(false)
|
b.resp.CollectServices.Set(false)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDebugConfig adds debug configuration
|
// WithDebugConfig adds debug configuration
|
||||||
// It disables log tailing if the mapper's LogTail is not enabled
|
// It disables log tailing if the mapper's LogTail is not enabled.
|
||||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||||
b.resp.Debug = &tailcfg.Debug{
|
b.resp.Debug = &tailcfg.Debug{
|
||||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||||
@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
// WithSSHPolicy adds SSH policy configuration for the requesting node.
|
||||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if err != nil {
|
if !ok {
|
||||||
b.addError(err)
|
b.addError(errors.New("node not found"))
|
||||||
 		return b
 	}

-	sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
+	sshPolicy, err := b.mapper.state.SSHPolicy(node)
 	if err != nil {
 		b.addError(err)
 		return b
 	}

 	b.resp.SSHPolicy = sshPolicy

 	return b
 }

-// WithDNSConfig adds DNS configuration for the requesting node
+// WithDNSConfig adds DNS configuration for the requesting node.
 func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

 	b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)

 	return b
 }

-// WithUserProfiles adds user profiles for the requesting node and given peers
-func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+// WithUserProfiles adds user profiles for the requesting node and given peers.
+func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

 	b.resp.UserProfiles = generateUserProfiles(node, peers)

 	return b
 }

-// WithPacketFilters adds packet filter rules based on policy
+// WithPacketFilters adds packet filter rules based on policy.
 func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

@@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
 	// new PacketFilters field and "base" allows us to send a full update when we
 	// have to send an empty list, avoiding the hack in the else block.
 	b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
-		"base": policy.ReduceFilterRules(node.View(), filter),
+		"base": policy.ReduceFilterRules(node, filter),
 	}

 	return b
 }

-// WithPeers adds full peer list with policy filtering (for full map response)
-func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
+// WithPeers adds full peer list with policy filtering (for full map response).
+func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
 		b.addError(err)
@@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuil
 	}

 	b.resp.Peers = tailPeers

 	return b
 }

-// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
-func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
+// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
+func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
 		b.addError(err)
@@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
 	}

 	b.resp.PeersChanged = tailPeers

 	return b
 }

-// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
-func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		return nil, err
+// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
+func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		return nil, errors.New("node not found")
 	}

 	filter, matchers := b.mapper.state.Filter()
@@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
 	// access each-other at all and remove them from the peers.
 	var changedViews views.Slice[types.NodeView]
 	if len(filter) > 0 {
-		changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
+		changedViews = policy.ReduceNodes(node, peers, matchers)
 	} else {
-		changedViews = peers.ViewSlice()
+		changedViews = peers
 	}

 	tailPeers, err := tailNodes(
 		changedViews, b.capVer, b.mapper.state,
 		func(id types.NodeID) []netip.Prefix {
-			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
+			return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
 		},
 		b.mapper.cfg)
 	if err != nil {
@@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
 	return tailPeers, nil
 }

-// WithPeerChangedPatch adds peer change patches
+// WithPeerChangedPatch adds peer change patches.
 func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
 	b.resp.PeersChangedPatch = changes
 	return b
 }

-// WithPeersRemoved adds removed peer IDs
+// WithPeersRemoved adds removed peer IDs.
 func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
 	var tailscaleIDs []tailcfg.NodeID
 	for _, id := range removedIDs {
 		tailscaleIDs = append(tailscaleIDs, id.NodeID())
 	}
 	b.resp.PeersRemoved = tailscaleIDs

 	return b
 }

@@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
 		return nil, multierr.New(b.errs...)
 	}
 	if debugDumpMapResponsePath != "" {
-		node, err := b.mapper.state.GetNodeByID(b.nodeID)
-		if err != nil {
-			return nil, err
-		}
-		writeDebugMapResponse(b.resp, node)
+		writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
 	}

 	return b.resp, nil
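Note: the builder above records failures with addError and only surfaces them when Build() is called, so the With* chain never breaks mid-way. The snippet below is a minimal, hypothetical sketch of that error-accumulating builder pattern; it uses the standard library's errors.Join as a stand-in for the multierr package used in the diff, and none of the names are headscale's.

package main

import (
	"errors"
	"fmt"
)

// errCollectingBuilder is a simplified stand-in: each With step records a
// failure instead of returning it, and Build reports them all at once.
type errCollectingBuilder struct {
	fields []string
	errs   []error
}

func (b *errCollectingBuilder) With(name string, err error) *errCollectingBuilder {
	if err != nil {
		b.errs = append(b.errs, fmt.Errorf("%s: %w", name, err))
		return b
	}
	b.fields = append(b.fields, name)
	return b
}

func (b *errCollectingBuilder) Build() ([]string, error) {
	if len(b.errs) > 0 {
		return nil, errors.Join(b.errs...)
	}
	return b.fields, nil
}

func main() {
	out, err := new(errCollectingBuilder).
		With("dns", nil).
		With("peers", errors.New("node not found")).
		Build()
	fmt.Println(out, err)
}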
@@ -19,6 +19,7 @@ import (
 	"tailscale.com/envknob"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/dnstype"
+	"tailscale.com/types/views"
 )

 const (
@@ -69,16 +70,18 @@ func newMapper(
 }

 func generateUserProfiles(
-	node *types.Node,
-	peers types.Nodes,
+	node types.NodeView,
+	peers views.Slice[types.NodeView],
 ) []tailcfg.UserProfile {
 	userMap := make(map[uint]*types.User)
 	ids := make([]uint, 0, len(userMap))
-	userMap[node.User.ID] = &node.User
-	ids = append(ids, node.User.ID)
-	for _, peer := range peers {
-		userMap[peer.User.ID] = &peer.User
-		ids = append(ids, peer.User.ID)
+	user := node.User()
+	userMap[user.ID] = &user
+	ids = append(ids, user.ID)
+	for _, peer := range peers.All() {
+		peerUser := peer.User()
+		userMap[peerUser.ID] = &peerUser
+		ids = append(ids, peerUser.ID)
 	}

 	slices.Sort(ids)
@@ -95,7 +98,7 @@ func generateUserProfiles(

 func generateDNSConfig(
 	cfg *types.Config,
-	node *types.Node,
+	node types.NodeView,
 ) *tailcfg.DNSConfig {
 	if cfg.TailcfgDNSConfig == nil {
 		return nil
@@ -115,12 +118,12 @@ func generateDNSConfig(
 //
 // This will produce a resolver like:
 // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
-func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
+func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
 	for _, resolver := range resolvers {
 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
 			attrs := url.Values{
-				"device_name":  []string{node.Hostname},
-				"device_model": []string{node.Hostinfo.OS},
+				"device_name":  []string{node.Hostname()},
+				"device_model": []string{node.Hostinfo().OS()},
 			}

 			if len(node.IPs()) > 0 {
@@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse(
 	capVer tailcfg.CapabilityVersion,
 	messages ...string,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.listPeers(nodeID)
-	if err != nil {
-		return nil, err
-	}
+	peers := m.state.ListPeers(nodeID)

 	return m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).
@@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse(
 	capVer tailcfg.CapabilityVersion,
 	changedNodeID types.NodeID,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.listPeers(nodeID, changedNodeID)
-	if err != nil {
-		return nil, err
-	}
+	peers := m.state.ListPeers(nodeID, changedNodeID)

 	return m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).
@@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse(

 func writeDebugMapResponse(
 	resp *tailcfg.MapResponse,
-	node *types.Node,
+	t debugType,
+	nodeID types.NodeID,
 ) {
 	body, err := json.MarshalIndent(resp, "", " ")
 	if err != nil {
@@ -236,25 +234,6 @@ func writeDebugMapResponse(
 	}
 }

-// listPeers returns peers of node, regardless of any Policy or if the node is expired.
-// If no peer IDs are given, all peers are returned.
-// If at least one peer ID is given, only these peer nodes will be returned.
-func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
-	peers, err := m.state.ListPeers(nodeID, peerIDs...)
-	if err != nil {
-		return nil, err
-	}
-
-	// TODO(kradalby): Add back online via batcher. This was removed
-	// to avoid a circular dependency between the mapper and the notification.
-	for _, peer := range peers {
-		online := m.batcher.IsConnected(peer.ID)
-		peer.IsOnline = &online
-	}
-
-	return peers, nil
-}
-
 // routeFilterFunc is a function that takes a node ID and returns a list of
 // netip.Prefixes that are allowed for that node. It is used to filter routes
 // from the primary route manager to the node.
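Note: generateUserProfiles above collects user IDs with duplicates, then sorts and compacts them so each profile is emitted once. The snippet below is a small standalone sketch of that dedup approach under simplified, hypothetical types; it is not the headscale implementation.

package main

import (
	"fmt"
	"slices"
)

type user struct {
	ID   uint
	Name string
}

// profiles maps owner and peer users to a deduplicated, ID-ordered list of names.
func profiles(owner user, peers []user) []string {
	byID := map[uint]user{owner.ID: owner}
	ids := []uint{owner.ID}
	for _, p := range peers {
		byID[p.ID] = p
		ids = append(ids, p.ID)
	}

	slices.Sort(ids)
	ids = slices.Compact(ids) // drop duplicate IDs after sorting

	out := make([]string, 0, len(ids))
	for _, id := range ids {
		out = append(out, byID[id].Name)
	}
	return out
}

func main() {
	fmt.Println(profiles(user{1, "alice"}, []user{{2, "bob"}, {1, "alice"}, {2, "bob"}}))
}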
@@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
 			&types.Config{
 				TailcfgDNSConfig: &dnsConfigOrig,
 			},
-			nodeInShared1,
+			nodeInShared1.View(),
 		)

 		if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@@ -133,13 +133,12 @@ func tailNode(
 		tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
 	}

-	if !node.IsOnline().Valid() || !node.IsOnline().Get() {
-		// LastSeen is only set when node is
-		// not connected to the control server.
-		if node.LastSeen().Valid() {
-			lastSeen := node.LastSeen().Get()
-			tNode.LastSeen = &lastSeen
-		}
+	// Set LastSeen only for offline nodes to avoid confusing Tailscale clients
+	// during rapid reconnection cycles. Online nodes should not have LastSeen set
+	// as this can make clients interpret them as "not online" despite Online=true.
+	if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() {
+		lastSeen := node.LastSeen().Get()
+		tNode.LastSeen = &lastSeen
 	}

 	return &tNode, nil
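Note: the rule in this hunk is that LastSeen is only populated for nodes that are known to be offline. A tiny self-contained sketch of that gating, using simplified stand-in types rather than the tailcfg/NodeView types above:

package main

import (
	"fmt"
	"time"
)

// lastSeenForResponse returns a LastSeen value only when the node is offline,
// so a response never carries Online=true together with a LastSeen timestamp.
func lastSeenForResponse(online bool, lastSeen *time.Time) *time.Time {
	if lastSeen != nil && !online {
		return lastSeen
	}
	return nil
}

func main() {
	seen := time.Now().Add(-5 * time.Minute)
	fmt.Println(lastSeenForResponse(true, &seen))  // <nil>: online nodes omit LastSeen
	fmt.Println(lastSeenForResponse(false, &seen)) // timestamp: offline nodes keep it
}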
@@ -13,7 +13,6 @@ import (
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/rs/zerolog/log"
 	"golang.org/x/net/http2"
-	"gorm.io/gorm"
 	"tailscale.com/control/controlbase"
 	"tailscale.com/control/controlhttp/controlhttpserver"
 	"tailscale.com/tailcfg"
@@ -296,16 +295,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 // getAndValidateNode retrieves the node from the database using the NodeKey
 // and validates that it matches the MachineKey from the Noise session.
 func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
-	node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
-	if err != nil {
-		if errors.Is(err, gorm.ErrRecordNotFound) {
-			return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
-		}
-		return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
+	nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
+	if !ok {
+		return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
 	}

-	nv := node.View()
-
 	// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
 	if ns.machineKey != nv.MachineKey() {
 		return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
@@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 		util.LogErr(err, "could not get userinfo; only using claims from id token")
 	}

-	// The user claims are now updated from the the userinfo endpoint so we can verify the user a
+	// The user claims are now updated from the userinfo endpoint so we can verify the user
 	// against allowed emails, email domains, and groups.
 	if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
 		httpError(writer, err)
@@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 		return
 	}

-	user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
+	user, c, err := a.createOrUpdateUserFromClaim(&claims)
 	if err != nil {
 		log.Error().
 			Err(err).
@@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	}

 	// Send policy update notifications if needed
-	if policyChanged {
-		a.h.Change(change.PolicyChange())
-	}
+	a.h.Change(c)

 	// TODO(kradalby): Is this comment right?
 	// If the node exists, then the node should be reauthenticated,
@@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis

 func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	claims *types.OIDCClaims,
-) (*types.User, bool, error) {
+) (*types.User, change.ChangeSet, error) {
 	var user *types.User
 	var err error
 	var newUser bool
-	var policyChanged bool
+	var c change.ChangeSet
 	user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
 	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
-		return nil, false, fmt.Errorf("creating or updating user: %w", err)
+		return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
 	}

 	// if the user is still not found, create a new empty user.
@@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	user.FromClaim(claims)

 	if newUser {
-		user, policyChanged, err = a.h.state.CreateUser(*user)
+		user, c, err = a.h.state.CreateUser(*user)
 		if err != nil {
-			return nil, false, fmt.Errorf("creating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
 		}
 	} else {
-		_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
+		_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
 			*u = *user
 			return nil
 		})
 		if err != nil {
-			return nil, false, fmt.Errorf("updating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
 		}
 	}

-	return user, policyChanged, nil
+	return user, c, nil
 }

 func (a *AuthProviderOIDC) handleRegistration(
@@ -7,6 +7,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/rs/zerolog/log"
 	"github.com/samber/lo"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
@@ -138,39 +139,74 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 	return ret
 }

-// AutoApproveRoutes approves any route that can be autoapproved from
-// the nodes perspective according to the given policy.
-// It reports true if any routes were approved.
-// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
-func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
-	if pm == nil {
-		return false
-	}
-	nodeView := node.View()
-	var newApproved []netip.Prefix
-	for _, route := range nodeView.AnnouncedRoutes() {
-		if pm.NodeCanApproveRoute(nodeView, route) {
-			newApproved = append(newApproved, route)
-		}
-	}
-
-	// Only modify ApprovedRoutes if we have new routes to approve.
-	// This prevents clearing existing approved routes when nodes
-	// temporarily don't have announced routes during policy changes.
-	if len(newApproved) > 0 {
-		combined := append(newApproved, node.ApprovedRoutes...)
-		tsaddr.SortPrefixes(combined)
-		combined = slices.Compact(combined)
-		combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
-			return route.IsValid()
-		})
-
-		// Only update if the routes actually changed
-		if !slices.Equal(node.ApprovedRoutes, combined) {
-			node.ApprovedRoutes = combined
-			return true
-		}
-	}
-
-	return false
-}
+// ApproveRoutesWithPolicy checks if the node can approve the announced routes
+// and returns the new list of approved routes.
+// The approved routes will include:
+// 1. ALL previously approved routes (regardless of whether they're still advertised)
+// 2. New routes from announcedRoutes that can be auto-approved by policy
+// This ensures that:
+// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
+// - New routes can be auto-approved according to policy
+// - Routes can only be removed by explicit admin action (not by auto-approval).
+func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
+	if pm == nil {
+		return currentApproved, false
+	}
+
+	// Start with ALL currently approved routes - we never remove approved routes
+	newApproved := make([]netip.Prefix, len(currentApproved))
+	copy(newApproved, currentApproved)
+
+	// Then, check for new routes that can be auto-approved
+	for _, route := range announcedRoutes {
+		// Skip if already approved
+		if slices.Contains(newApproved, route) {
+			continue
+		}
+
+		// Check if this new route can be auto-approved by policy
+		canApprove := pm.NodeCanApproveRoute(nv, route)
+		if canApprove {
+			newApproved = append(newApproved, route)
+		}
+	}
+
+	// Sort and deduplicate
+	tsaddr.SortPrefixes(newApproved)
+	newApproved = slices.Compact(newApproved)
+	newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
+		return route.IsValid()
+	})
+
+	// Sort the current approved for comparison
+	sortedCurrent := make([]netip.Prefix, len(currentApproved))
+	copy(sortedCurrent, currentApproved)
+	tsaddr.SortPrefixes(sortedCurrent)
+
+	// Only update if the routes actually changed
+	if !slices.Equal(sortedCurrent, newApproved) {
+		// Log what changed
+		var added, kept []netip.Prefix
+		for _, route := range newApproved {
+			if !slices.Contains(sortedCurrent, route) {
+				added = append(added, route)
+			} else {
+				kept = append(kept, route)
+			}
+		}
+
+		if len(added) > 0 {
+			log.Debug().
+				Uint64("node.id", nv.ID().Uint64()).
+				Str("node.name", nv.Hostname()).
+				Strs("routes.added", util.PrefixesToString(added)).
+				Strs("routes.kept", util.PrefixesToString(kept)).
+				Int("routes.total", len(newApproved)).
+				Msg("Routes auto-approved by policy")
+		}
+
+		return newApproved, true
+	}
+
+	return newApproved, false
+}
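Note: the core of ApproveRoutesWithPolicy is a merge that only ever adds approvals. The snippet below is a standalone sketch of that union semantics, not the headscale code: the hypothetical canApprove callback stands in for PolicyManager.NodeCanApproveRoute, and the sort/compact mimics what tsaddr.SortPrefixes plus slices.Compact do above.

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// mergeApproved returns the union of currently approved routes and any
// announced route the policy would auto-approve. Approvals are only added,
// never dropped.
func mergeApproved(current, announced []netip.Prefix, canApprove func(netip.Prefix) bool) []netip.Prefix {
	merged := slices.Clone(current)
	for _, route := range announced {
		if slices.Contains(merged, route) {
			continue
		}
		if canApprove(route) {
			merged = append(merged, route)
		}
	}
	slices.SortFunc(merged, func(a, b netip.Prefix) int {
		if c := a.Addr().Compare(b.Addr()); c != 0 {
			return c
		}
		return a.Bits() - b.Bits()
	})
	return slices.Compact(merged)
}

func main() {
	approved := []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")} // manual approval, no longer advertised
	announced := []netip.Prefix{netip.MustParsePrefix("10.1.0.0/24")}
	inPolicy := func(p netip.Prefix) bool { return netip.MustParsePrefix("10.0.0.0/8").Overlaps(p) }
	fmt.Println(mergeApproved(approved, announced, inPolicy)) // both routes survive
}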
hscontrol/policy/policy_autoapprove_test.go (new file, 339 lines)
@@ -0,0 +1,339 @@
package policy

import (
	"fmt"
	"net/netip"
	"testing"

	policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
	"gorm.io/gorm"
	"tailscale.com/net/tsaddr"
	"tailscale.com/types/key"
	"tailscale.com/types/ptr"
	"tailscale.com/types/views"
)

func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
	user1 := types.User{
		Model: gorm.Model{ID: 1},
		Name:  "testuser@",
	}
	user2 := types.User{
		Model: gorm.Model{ID: 2},
		Name:  "otheruser@",
	}
	users := []types.User{user1, user2}

	node1 := &types.Node{
		ID:             1,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "test-node",
		UserID:         user1.ID,
		User:           user1,
		RegisterMethod: util.RegisterMethodAuthKey,
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
		ForcedTags:     []string{"tag:test"},
	}

	node2 := &types.Node{
		ID:             2,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "other-node",
		UserID:         user2.ID,
		User:           user2,
		RegisterMethod: util.RegisterMethodAuthKey,
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.2")),
	}

	// Create a policy that auto-approves specific routes
	policyJSON := `{
		"groups": {
			"group:test": ["testuser@"]
		},
		"tagOwners": {
			"tag:test": ["testuser@"]
		},
		"acls": [
			{
				"action": "accept",
				"src": ["*"],
				"dst": ["*:*"]
			}
		],
		"autoApprovers": {
			"routes": {
				"10.0.0.0/8": ["testuser@", "tag:test"],
				"10.1.0.0/24": ["testuser@"],
				"10.2.0.0/24": ["testuser@"],
				"192.168.0.0/24": ["tag:test"]
			}
		}
	}`

	pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
	assert.NoError(t, err)

	tests := []struct {
		name            string
		node            *types.Node
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
		description     string
	}{
		{
			name: "previously_approved_route_no_longer_advertised_should_remain",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
			},
			wantChanged: false,
			description: "Previously approved routes should never be removed even when no longer advertised",
		},
		{
			name: "add_new_auto_approved_route_keeps_old_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
				netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
			},
			wantChanged: true,
			description: "New auto-approved routes should be added while keeping old approved routes",
		},
		{
			name: "no_announced_routes_keeps_all_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
			},
			announcedRoutes: []netip.Prefix{}, // No routes announced
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged: false,
			description: "All approved routes should remain when no routes are announced",
		},
		{
			name: "no_changes_when_announced_equals_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
			description: "No changes should occur when announced routes match approved routes",
		},
		{
			name: "auto_approve_multiple_new_routes",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.2.0.0/24"),    // Should be auto-approved (subset of 10.0.0.0/8)
				netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.2.0.0/24"),    // New auto-approved
				netip.MustParsePrefix("172.16.0.0/24"),  // Original kept
				netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
			},
			wantChanged: true,
			description: "Multiple new routes should be auto-approved while keeping existing approved routes",
		},
		{
			name: "node_without_permission_no_auto_approval",
			node: node2, // Different node without the tag
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
			},
			wantChanged: false,
			description: "Routes should not be auto-approved for nodes without proper permissions",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)

			assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)

			// Sort for comparison since ApproveRoutesWithPolicy sorts the results
			tsaddr.SortPrefixes(tt.wantApproved)
			assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)

			// Verify that all previously approved routes are still present
			for _, prevRoute := range tt.currentApproved {
				assert.Contains(t, gotApproved, prevRoute,
					"previously approved route %s was removed - this should never happen", prevRoute)
			}
		})
	}
}

func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
	// Create a basic policy for edge case testing
	aclPolicy := `
{
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.1.0.0/24": ["test@"],
		},
	},
}`

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	tests := []struct {
		name            string
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
	}{
		{
			name: "nil_policy_manager",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name:            "nil_current_approved",
			currentApproved: nil,
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name: "nil_announced_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: nil,
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name: "duplicate_approved_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name:            "empty_slices",
			currentApproved: []netip.Prefix{},
			announcedRoutes: []netip.Prefix{},
			wantApproved:    []netip.Prefix{},
			wantChanged:     false,
		},
	}

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  "test",
				}
				users := []types.User{user}

				// Create test node
				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       "testnode",
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
				}
				nodes := types.Nodes{&node}

				// Create policy manager or use nil if specified
				var pm PolicyManager
				var err error
				if tt.name != "nil_policy_manager" {
					pm, err = pmf(users, nodes.ViewSlice())
					assert.NoError(t, err)
				} else {
					pm = nil
				}

				gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)

				assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")

				// Handle nil vs empty slice comparison
				if tt.wantApproved == nil {
					assert.Nil(t, gotApproved, "expected nil approved routes")
				} else {
					tsaddr.SortPrefixes(tt.wantApproved)
					assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
				}
			})
		}
	}
}
hscontrol/policy/policy_route_approval_test.go (new file, 361 lines)
@@ -0,0 +1,361 @@
package policy

import (
	"fmt"
	"net/netip"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
	"tailscale.com/types/ptr"
)

func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
	// Test policy that allows specific routes to be auto-approved
	aclPolicy := `
{
	"groups": {
		"group:admins": ["test@"],
	},
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.0.0.0/24": ["test@"],
			"192.168.0.0/24": ["group:admins"],
			"172.16.0.0/16": ["tag:approved"],
		},
	},
	"tagOwners": {
		"tag:approved": ["test@"],
	},
}`

	tests := []struct {
		name              string
		currentApproved   []netip.Prefix
		announcedRoutes   []netip.Prefix
		nodeHostname      string
		nodeUser          string
		nodeTags          []string
		wantApproved      []netip.Prefix
		wantChanged       bool
		wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
	}{
		{
			name: "previously_approved_route_no_longer_advertised_remains",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged:       false,
			wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
		},
		{
			name: "add_new_auto_approved_route_keeps_existing",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Still advertised
				netip.MustParsePrefix("192.168.0.0/24"), // New route
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
			},
			wantChanged: true,
		},
		{
			name: "no_announced_routes_keeps_all_approved",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
			},
			announcedRoutes: []netip.Prefix{}, // No routes announced anymore
			nodeUser:        "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"),
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name: "manually_approved_route_not_in_policy_remains",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // New auto-approved
				netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
			},
			wantChanged: true,
		},
		{
			name: "tagged_node_gets_tag_approved_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
			},
			nodeUser: "test",
			nodeTags: []string{"tag:approved"},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
				netip.MustParsePrefix("10.0.0.0/24"),   // Previous approval preserved
			},
			wantChanged: true,
		},
		{
			name: "complex_scenario_multiple_changes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Will not be advertised
				netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"),  // New, auto-approvable
				netip.MustParsePrefix("172.16.0.0/16"),   // New, not approvable (no tag)
				netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Kept despite not advertised
				netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
				netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
			},
			wantChanged: true,
		},
	}

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  tt.nodeUser,
				}
				users := []types.User{user}

				// Create test node
				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       tt.nodeHostname,
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					Hostinfo: &tailcfg.Hostinfo{
						RoutableIPs: tt.announcedRoutes,
					},
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
					ForcedTags:     tt.nodeTags,
				}
				nodes := types.Nodes{&node}

				// Create policy manager
				pm, err := pmf(users, nodes.ViewSlice())
				require.NoError(t, err)
				require.NotNil(t, pm)

				// Test ApproveRoutesWithPolicy
				gotApproved, gotChanged := ApproveRoutesWithPolicy(
					pm,
					node.View(),
					tt.currentApproved,
					tt.announcedRoutes,
				)

				// Check change flag
				assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")

				// Check approved routes match expected
				if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
					t.Logf("Want: %v", tt.wantApproved)
					t.Logf("Got: %v", gotApproved)
					t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
				}

				// Verify all previously approved routes are still present
				for _, prevRoute := range tt.currentApproved {
					assert.Contains(t, gotApproved, prevRoute,
						"previously approved route %s was removed - this should NEVER happen", prevRoute)
				}

				// Verify no routes were incorrectly removed
				for _, removedRoute := range tt.wantRemovedRoutes {
					assert.NotContains(t, gotApproved, removedRoute,
						"route %s should have been removed but wasn't", removedRoute)
				}
			})
		}
	}
}

func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
	aclPolicy := `
{
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.0.0.0/8": ["test@"],
		},
	},
}`

	tests := []struct {
		name            string
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
	}{
		{
			name:            "nil_current_approved",
			currentApproved: nil,
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name:            "empty_current_approved",
			currentApproved: []netip.Prefix{},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name: "duplicate_routes_handled",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true, // Duplicates are removed, so it's a change
		},
	}

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  "test",
				}
				users := []types.User{user}

				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       "testnode",
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					Hostinfo: &tailcfg.Hostinfo{
						RoutableIPs: tt.announcedRoutes,
					},
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
				}
				nodes := types.Nodes{&node}

				pm, err := pmf(users, nodes.ViewSlice())
				require.NoError(t, err)

				gotApproved, gotChanged := ApproveRoutesWithPolicy(
					pm,
					node.View(),
					tt.currentApproved,
					tt.announcedRoutes,
				)

				assert.Equal(t, tt.wantChanged, gotChanged)

				if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
					t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
				}
			})
		}
	}
}

func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
	user := types.User{
		Model: gorm.Model{ID: 1},
		Name:  "test",
	}

	currentApproved := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
	}
	announcedRoutes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
	}

	node := types.Node{
		ID:             1,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "testnode",
		UserID:         user.ID,
		User:           user,
		RegisterMethod: util.RegisterMethodAuthKey,
		Hostinfo: &tailcfg.Hostinfo{
			RoutableIPs: announcedRoutes,
		},
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
		ApprovedRoutes: currentApproved,
	}

	// With nil policy manager, should return current approved unchanged
	gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)

	assert.False(t, gotChanged)
	assert.Equal(t, currentApproved, gotApproved)
}
@@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
 			policy:     `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
 			canApprove: false,
 		},
+		{
+			name:  "policy-without-autoApprovers-section",
+			node:  normalNode,
+			route: p("10.33.0.0/16"),
+			policy: `{
+				"groups": {
+					"group:admin": ["user1@"]
+				},
+				"acls": [
+					{
+						"action": "accept",
+						"src": ["group:admin"],
+						"dst": ["group:admin:*"]
+					},
+					{
+						"action": "accept",
+						"src": ["group:admin"],
+						"dst": ["10.33.0.0/16:*"]
+					}
+				]
+			}`,
+			canApprove: false,
+		},
 	}

 	for _, tt := range tests {
@@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
 	// The fast path is that a node requests to approve a prefix
 	// where there is an exact entry, e.g. 10.0.0.0/8, then
 	// check and return quickly
-	if _, ok := pm.autoApproveMap[route]; ok {
-		if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
+	if approvers, ok := pm.autoApproveMap[route]; ok {
+		canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
+		if canApprove {
 			return true
 		}
 	}
@@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
 		// Check if prefix is larger (so containing) and then overlaps
 		// the route to see if the node can approve a subset of an autoapprover
 		if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
-			if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
+			canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
+			if canApprove {
 				return true
 			}
 		}
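Note: NodeCanApproveRoute above first checks for an exact auto-approver entry, then falls back to any wider (containing) prefix that overlaps the announced route. The snippet below is a simplified, standalone sketch of that containment rule; the map of approver prefixes to a boolean is a hypothetical stand-in for headscale's approver address sets.

package main

import (
	"fmt"
	"net/netip"
)

// canApprove reports whether an announced route is covered by an approver
// prefix, either by exact match or because a wider prefix contains it.
func canApprove(route netip.Prefix, autoApprovers map[netip.Prefix]bool) bool {
	// Fast path: exact entry for the announced route.
	if allowed, ok := autoApprovers[route]; ok && allowed {
		return true
	}
	// Slow path: a wider (fewer bits) approver prefix that overlaps the route.
	for prefix, allowed := range autoApprovers {
		if allowed && prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
			return true
		}
	}
	return false
}

func main() {
	approvers := map[netip.Prefix]bool{netip.MustParsePrefix("10.0.0.0/8"): true}
	fmt.Println(canApprove(netip.MustParsePrefix("10.1.0.0/24"), approvers))    // true: subset of 10.0.0.0/8
	fmt.Println(canApprove(netip.MustParsePrefix("192.168.0.0/24"), approvers)) // false: not covered
}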
@@ -10,7 +10,6 @@ import (
 	"time"

 	"github.com/juanfont/headscale/hscontrol/types"
-	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"github.com/sasha-s/go-deadlock"
@@ -112,6 +111,15 @@ func (m *mapSession) serve() {
 	// This is the mechanism where the node gives us information about its
 	// current configuration.
 	//
+	// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
+	c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
+	if err != nil {
+		httpError(m.w, err)
+		return
+	}
+
+	m.h.Change(c)
+
 	// If OmitPeers is true and Stream is false
 	// then the server will let clients update their endpoints without
 	// breaking existing long-polling (Stream == true) connections.
@@ -122,14 +130,6 @@ func (m *mapSession) serve() {
 	// the response and just wants a 200.
 	// !req.stream && req.OmitPeers
 	if m.isEndpointUpdate() {
-		c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
-		if err != nil {
-			httpError(m.w, err)
-			return
-		}
-
-		m.h.Change(c)
-
 		m.w.WriteHeader(http.StatusOK)
 		mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
 	}
@@ -142,6 +142,8 @@ func (m *mapSession) serve() {
 func (m *mapSession) serveLongPoll() {
 	m.beforeServeLongPoll()

+	log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected")
+
 	// Clean up the session when the client disconnects
 	defer func() {
 		m.cancelChMu.Lock()
@@ -149,18 +151,38 @@ func (m *mapSession) serveLongPoll() {
 		close(m.cancelCh)
 		m.cancelChMu.Unlock()

-		// TODO(kradalby): This can likely be made more effective, but likely most
-		// nodes has access to the same routes, so it might not be a big deal.
-		disconnectChange, err := m.h.state.Disconnect(m.node)
-		if err != nil {
-			m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
-		}
-		m.h.Change(disconnectChange)
-
-		m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
-
-		m.afterServeLongPoll()
-		m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
+		_ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
+
+		// When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather).
+		// Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects.
+		// If it does reconnect, the existing mapSession will be replaced and the node remains online.
+		// If it doesn't reconnect within the timeout, we mark it as offline.
+		//
+		// This avoids flapping nodes in the UI and unnecessary churn in the network.
+		// This is not my favourite solution, but it kind of works in our eventually consistent world.
+		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+		disconnected := true
+		// Wait up to 10 seconds for the node to reconnect.
+		// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
+		for range 10 {
+			if m.h.mapBatcher.IsConnected(m.node.ID) {
+				disconnected = false
+				break
+			}
+			<-ticker.C
+		}
+
+		if disconnected {
+			disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
+			if err != nil {
+				m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
+			}
+
+			m.h.Change(disconnectChanges...)
+			m.afterServeLongPoll()
+			m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
+		}
 	}()

 	// Set up the client stream
@@ -172,25 +194,25 @@ func (m *mapSession) serveLongPoll() {

 	m.keepAliveTicker = time.NewTicker(m.keepAlive)

-	// Add node to batcher BEFORE sending Connect change to prevent race condition
-	// where the change is sent before the node is in the batcher's node map
-	if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
-		m.errf(err, "failed to add node to batcher")
-		// Send empty response to client to fail fast for invalid/non-existent nodes
-		select {
-		case m.ch <- &tailcfg.MapResponse{}:
-		default:
-			// Channel might be closed
-		}
-		return
-	}
-
-	// Now send the Connect change - the batcher handles NodeCameOnline internally
-	// but we still need to update routes and other state-level changes
-	connectChange := m.h.state.Connect(m.node)
-	if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
-		m.h.Change(connectChange)
-	}
+	// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
+	// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
+	// synchronized. When nodes reconnect, they send their hostinfo with announced routes
+	// in the MapRequest. We need this data in NodeStore before Connect() sets up the
+	// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
+	// from AvailableRoutes.
+	mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
+	if err != nil {
+		m.errf(err, "failed to update node from initial MapRequest")
+		return
+	}
+
+	// Connect the node after its state has been updated.
+	// We send two separate change notifications because these are distinct operations:
+	// 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
+	// 2. Connect: marks the node online and recalculates primary routes based on the updated state
+	// While this results in two notifications, it ensures route data is synchronized before
+	// primary route selection occurs, which is critical for proper HA subnet router failover.
+	connectChanges := m.h.state.Connect(m.node.ID)

 	m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)

@@ -235,6 +257,7 @@ func (m *mapSession) serveLongPoll() {
 			mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
 		}
 		mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
+		m.resetKeepAlive()
 	}
 }
||||||
}
|
}
|
||||||
|
403
hscontrol/state/node_store.go
Normal file
403
hscontrol/state/node_store.go
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/views"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
batchSize = 10
|
||||||
|
batchTimeout = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
put = 1
|
||||||
|
del = 2
|
||||||
|
update = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
const prometheusNamespace = "headscale"
|
||||||
|
|
||||||
|
var (
|
||||||
|
nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_operations_total",
|
||||||
|
Help: "Total number of NodeStore operations",
|
||||||
|
}, []string{"operation"})
|
||||||
|
nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_operation_duration_seconds",
|
||||||
|
Help: "Duration of NodeStore operations",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
}, []string{"operation"})
|
||||||
|
nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_batch_size",
|
||||||
|
Help: "Size of NodeStore write batches",
|
||||||
|
Buckets: []float64{1, 2, 5, 10, 20, 50, 100},
|
||||||
|
})
|
||||||
|
nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_batch_duration_seconds",
|
||||||
|
Help: "Duration of NodeStore batch processing",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
})
|
||||||
|
nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_snapshot_build_duration_seconds",
|
||||||
|
Help: "Duration of NodeStore snapshot building from nodes",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
})
|
||||||
|
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_nodes_total",
|
||||||
|
Help: "Total number of nodes in the NodeStore",
|
||||||
|
})
|
||||||
|
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_peers_calculation_duration_seconds",
|
||||||
|
Help: "Duration of peers calculation in NodeStore",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
})
|
||||||
|
nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "nodestore_queue_depth",
|
||||||
|
Help: "Current depth of NodeStore write queue",
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// NodeStore is a thread-safe store for nodes.
|
||||||
|
// It is a copy-on-write structure, replacing the "snapshot"
|
||||||
|
// when a change to the structure occurs. It is optimised for reads,
|
||||||
|
// and while batches are not fast, they are grouped together
|
||||||
|
// to do less of the expensive peer calculation if there are many
|
||||||
|
// changes rapidly.
|
||||||
|
//
|
||||||
|
// Writes will block until committed, while reads are never
|
||||||
|
// blocked. This means that the caller of a write operation
|
||||||
|
// is responsible for ensuring an update depending on a write
|
||||||
|
// is not issued before the write is complete.
|
||||||
|
type NodeStore struct {
|
||||||
|
data atomic.Pointer[Snapshot]
|
||||||
|
|
||||||
|
peersFunc PeersFunc
|
||||||
|
writeQueue chan work
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
|
||||||
|
nodes := make(map[types.NodeID]types.Node, len(allNodes))
|
||||||
|
for _, n := range allNodes {
|
||||||
|
nodes[n.ID] = *n
|
||||||
|
}
|
||||||
|
snap := snapshotFromNodes(nodes, peersFunc)
|
||||||
|
|
||||||
|
store := &NodeStore{
|
||||||
|
peersFunc: peersFunc,
|
||||||
|
}
|
||||||
|
store.data.Store(&snap)
|
||||||
|
|
||||||
|
// Initialize node count gauge
|
||||||
|
nodeStoreNodesCount.Set(float64(len(nodes)))
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot is the representation of the current state of the NodeStore.
|
||||||
|
// It contains all nodes and their relationships.
|
||||||
|
// It is a copy-on-write structure, meaning that when a write occurs,
|
||||||
|
// a new Snapshot is created with the updated state,
|
||||||
|
// and replaces the old one atomically.
|
||||||
|
type Snapshot struct {
|
||||||
|
// nodesByID is the main source of truth for nodes.
|
||||||
|
nodesByID map[types.NodeID]types.Node
|
||||||
|
|
||||||
|
// calculated from nodesByID
|
||||||
|
nodesByNodeKey map[key.NodePublic]types.NodeView
|
||||||
|
peersByNode map[types.NodeID][]types.NodeView
|
||||||
|
nodesByUser map[types.UserID][]types.NodeView
|
||||||
|
allNodes []types.NodeView
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeersFunc is a function that takes a list of nodes and returns a map
|
||||||
|
// with the relationships between nodes and their peers.
|
||||||
|
// This will typically be used to calculate which nodes can see each other
|
||||||
|
// based on the current policy.
|
||||||
|
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
||||||
|
|
||||||
|
// work represents a single operation to be performed on the NodeStore.
|
||||||
|
type work struct {
|
||||||
|
op int
|
||||||
|
nodeID types.NodeID
|
||||||
|
node types.Node
|
||||||
|
updateFn UpdateNodeFunc
|
||||||
|
result chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutNode adds or updates a node in the store.
|
||||||
|
// If the node already exists, it will be replaced.
|
||||||
|
// If the node does not exist, it will be added.
|
||||||
|
// This is a blocking operation that waits for the write to complete.
|
||||||
|
func (s *NodeStore) PutNode(n types.Node) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
work := work{
|
||||||
|
op: put,
|
||||||
|
nodeID: n.ID,
|
||||||
|
node: n,
|
||||||
|
result: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeStoreQueueDepth.Inc()
|
||||||
|
s.writeQueue <- work
|
||||||
|
<-work.result
|
||||||
|
nodeStoreQueueDepth.Dec()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
|
||||||
|
type UpdateNodeFunc func(n *types.Node)
|
||||||
|
|
||||||
|
// UpdateNode applies a function to modify a specific node in the store.
|
||||||
|
// This is a blocking operation that waits for the write to complete.
|
||||||
|
// This is analogous to a database "transaction", or, the caller should
|
||||||
|
// rather collect all data they want to change, and then call this function.
|
||||||
|
// Fewer calls are better.
|
||||||
|
//
|
||||||
|
// TODO(kradalby): Technically we could have a version of this that modifies the node
|
||||||
|
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
|
||||||
|
// This is because the main nodesByID map contains the struct, and every other map is using a
|
||||||
|
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
|
||||||
|
// a lock around the nodesByID map to ensure that no other writes are happening
|
||||||
|
// while we are modifying the node. Which mean we would need to implement read-write locks
|
||||||
|
// on all read operations.
|
||||||
|
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
work := work{
|
||||||
|
op: update,
|
||||||
|
nodeID: nodeID,
|
||||||
|
updateFn: updateFn,
|
||||||
|
result: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeStoreQueueDepth.Inc()
|
||||||
|
s.writeQueue <- work
|
||||||
|
<-work.result
|
||||||
|
nodeStoreQueueDepth.Dec()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNode removes a node from the store by its ID.
|
||||||
|
// This is a blocking operation that waits for the write to complete.
|
||||||
|
func (s *NodeStore) DeleteNode(id types.NodeID) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
work := work{
|
||||||
|
op: del,
|
||||||
|
nodeID: id,
|
||||||
|
result: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeStoreQueueDepth.Inc()
|
||||||
|
s.writeQueue <- work
|
||||||
|
<-work.result
|
||||||
|
nodeStoreQueueDepth.Dec()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("delete").Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the NodeStore and starts processing the write queue.
|
||||||
|
func (s *NodeStore) Start() {
|
||||||
|
s.writeQueue = make(chan work)
|
||||||
|
go s.processWrite()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the NodeStore.
|
||||||
|
func (s *NodeStore) Stop() {
|
||||||
|
close(s.writeQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processWrite processes the write queue in batches.
|
||||||
|
func (s *NodeStore) processWrite() {
|
||||||
|
c := time.NewTicker(batchTimeout)
|
||||||
|
defer c.Stop()
|
||||||
|
batch := make([]work, 0, batchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case w, ok := <-s.writeQueue:
|
||||||
|
if !ok {
|
||||||
|
// Channel closed, apply any remaining batch and exit
|
||||||
|
if len(batch) != 0 {
|
||||||
|
s.applyBatch(batch)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch = append(batch, w)
|
||||||
|
if len(batch) >= batchSize {
|
||||||
|
s.applyBatch(batch)
|
||||||
|
batch = batch[:0]
|
||||||
|
c.Reset(batchTimeout)
|
||||||
|
}
|
||||||
|
case <-c.C:
|
||||||
|
if len(batch) != 0 {
|
||||||
|
s.applyBatch(batch)
|
||||||
|
batch = batch[:0]
|
||||||
|
}
|
||||||
|
c.Reset(batchTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyBatch applies a batch of work to the node store.
|
||||||
|
// This means that it takes a copy of the current nodes,
|
||||||
|
// then applies the batch of operations to that copy,
|
||||||
|
// runs any precomputation needed (like calculating peers),
|
||||||
|
// and finally replaces the snapshot in the store with the new one.
|
||||||
|
// The replacement of the snapshot is atomic, ensuring that reads
|
||||||
|
// are never blocked by writes.
|
||||||
|
// Each write item is blocked until the batch is applied to ensure
|
||||||
|
// the caller knows the operation is complete and do not send any
|
||||||
|
// updates that are dependent on a read that is yet to be written.
|
||||||
|
func (s *NodeStore) applyBatch(batch []work) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreBatchDuration)
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreBatchSize.Observe(float64(len(batch)))
|
||||||
|
|
||||||
|
nodes := make(map[types.NodeID]types.Node)
|
||||||
|
maps.Copy(nodes, s.data.Load().nodesByID)
|
||||||
|
|
||||||
|
for _, w := range batch {
|
||||||
|
switch w.op {
|
||||||
|
case put:
|
||||||
|
nodes[w.nodeID] = w.node
|
||||||
|
case update:
|
||||||
|
// Update the specific node identified by nodeID
|
||||||
|
if n, exists := nodes[w.nodeID]; exists {
|
||||||
|
w.updateFn(&n)
|
||||||
|
nodes[w.nodeID] = n
|
||||||
|
}
|
||||||
|
case del:
|
||||||
|
delete(nodes, w.nodeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newSnap := snapshotFromNodes(nodes, s.peersFunc)
|
||||||
|
s.data.Store(&newSnap)
|
||||||
|
|
||||||
|
// Update node count gauge
|
||||||
|
nodeStoreNodesCount.Set(float64(len(nodes)))
|
||||||
|
|
||||||
|
for _, w := range batch {
|
||||||
|
close(w.result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshotFromNodes creates a new Snapshot from the provided nodes.
|
||||||
|
// It builds a lot of "indexes" to make lookups fast for datasets we
|
||||||
|
// that is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
|
||||||
|
// This is not a fast operation, it is the "slow" part of our copy-on-write
|
||||||
|
// structure, but it allows us to have fast reads and efficient lookups.
|
||||||
|
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
allNodes := make([]types.NodeView, 0, len(nodes))
|
||||||
|
for _, n := range nodes {
|
||||||
|
allNodes = append(allNodes, n.View())
|
||||||
|
}
|
||||||
|
|
||||||
|
newSnap := Snapshot{
|
||||||
|
nodesByID: nodes,
|
||||||
|
allNodes: allNodes,
|
||||||
|
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
|
||||||
|
|
||||||
|
// peersByNode is most likely the most expensive operation,
|
||||||
|
// it will use the list of all nodes, combined with the
|
||||||
|
// current policy to precalculate which nodes are peers and
|
||||||
|
// can see each other.
|
||||||
|
peersByNode: func() map[types.NodeID][]types.NodeView {
|
||||||
|
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
|
||||||
|
defer peersTimer.ObserveDuration()
|
||||||
|
return peersFunc(allNodes)
|
||||||
|
}(),
|
||||||
|
nodesByUser: make(map[types.UserID][]types.NodeView),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build nodesByUser and nodesByNodeKey maps
|
||||||
|
for _, n := range nodes {
|
||||||
|
nodeView := n.View()
|
||||||
|
newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
|
||||||
|
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
|
||||||
|
}
|
||||||
|
|
||||||
|
return newSnap
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNode retrieves a node by its ID.
|
||||||
|
// The bool indicates if the node exists or is available (like "err not found").
|
||||||
|
// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
|
||||||
|
// it isn't an invalid node (this is more of a node error or node is broken).
|
||||||
|
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("get").Inc()
|
||||||
|
|
||||||
|
n, exists := s.data.Load().nodesByID[id]
|
||||||
|
if !exists {
|
||||||
|
return types.NodeView{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return n.View(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodeByNodeKey retrieves a node by its NodeKey.
|
||||||
|
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
|
||||||
|
return s.data.Load().nodesByNodeKey[nodeKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNodes returns a slice of all nodes in the store.
|
||||||
|
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("list").Inc()
|
||||||
|
|
||||||
|
return views.SliceOf(s.data.Load().allNodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPeers returns a slice of all peers for a given node ID.
|
||||||
|
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("list_peers").Inc()
|
||||||
|
|
||||||
|
return views.SliceOf(s.data.Load().peersByNode[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNodesByUser returns a slice of all nodes for a given user ID.
|
||||||
|
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
|
||||||
|
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
|
||||||
|
defer timer.ObserveDuration()
|
||||||
|
|
||||||
|
nodeStoreOperations.WithLabelValues("list_by_user").Inc()
|
||||||
|
|
||||||
|
return views.SliceOf(s.data.Load().nodesByUser[uid])
|
||||||
|
}
|
501
hscontrol/state/node_store_test.go
Normal file
501
hscontrol/state/node_store_test.go
Normal file
@ -0,0 +1,501 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSnapshotFromNodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
|
||||||
|
validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty nodes",
|
||||||
|
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||||
|
nodes := make(map[types.NodeID]types.Node)
|
||||||
|
peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||||
|
return make(map[types.NodeID][]types.NodeView)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes, peersFunc
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||||
|
assert.Empty(t, snapshot.nodesByID)
|
||||||
|
assert.Empty(t, snapshot.allNodes)
|
||||||
|
assert.Empty(t, snapshot.peersByNode)
|
||||||
|
assert.Empty(t, snapshot.nodesByUser)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single node",
|
||||||
|
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||||
|
nodes := map[types.NodeID]types.Node{
|
||||||
|
1: createTestNode(1, 1, "user1", "node1"),
|
||||||
|
}
|
||||||
|
return nodes, allowAllPeersFunc
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||||
|
assert.Len(t, snapshot.nodesByID, 1)
|
||||||
|
assert.Len(t, snapshot.allNodes, 1)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 1)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 1)
|
||||||
|
|
||||||
|
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||||
|
assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
|
||||||
|
assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple nodes same user",
|
||||||
|
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||||
|
nodes := map[types.NodeID]types.Node{
|
||||||
|
1: createTestNode(1, 1, "user1", "node1"),
|
||||||
|
2: createTestNode(2, 1, "user1", "node2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes, allowAllPeersFunc
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||||
|
assert.Len(t, snapshot.nodesByID, 2)
|
||||||
|
assert.Len(t, snapshot.allNodes, 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 2)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 1)
|
||||||
|
|
||||||
|
// Each node sees the other as peer (but not itself)
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||||
|
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple nodes different users",
|
||||||
|
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||||
|
nodes := map[types.NodeID]types.Node{
|
||||||
|
1: createTestNode(1, 1, "user1", "node1"),
|
||||||
|
2: createTestNode(2, 2, "user2", "node2"),
|
||||||
|
3: createTestNode(3, 1, "user1", "node3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes, allowAllPeersFunc
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||||
|
assert.Len(t, snapshot.nodesByID, 3)
|
||||||
|
assert.Len(t, snapshot.allNodes, 3)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 3)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 2)
|
||||||
|
|
||||||
|
// Each node should have 2 peers (all others, but not itself)
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||||
|
|
||||||
|
// User groupings
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
|
||||||
|
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "odd-even peers filtering",
|
||||||
|
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||||
|
nodes := map[types.NodeID]types.Node{
|
||||||
|
1: createTestNode(1, 1, "user1", "node1"),
|
||||||
|
2: createTestNode(2, 2, "user2", "node2"),
|
||||||
|
3: createTestNode(3, 3, "user3", "node3"),
|
||||||
|
4: createTestNode(4, 4, "user4", "node4"),
|
||||||
|
}
|
||||||
|
peersFunc := oddEvenPeersFunc
|
||||||
|
|
||||||
|
return nodes, peersFunc
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||||
|
assert.Len(t, snapshot.nodesByID, 4)
|
||||||
|
assert.Len(t, snapshot.allNodes, 4)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 4)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 4)
|
||||||
|
|
||||||
|
// Odd nodes should only see other odd nodes as peers
|
||||||
|
require.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||||
|
|
||||||
|
require.Len(t, snapshot.peersByNode[3], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||||
|
|
||||||
|
// Even nodes should only see other even nodes as peers
|
||||||
|
require.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||||
|
|
||||||
|
require.Len(t, snapshot.peersByNode[4], 1)
|
||||||
|
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
nodes, peersFunc := tt.setupFunc()
|
||||||
|
snapshot := snapshotFromNodes(nodes, peersFunc)
|
||||||
|
tt.validate(t, nodes, snapshot)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
|
||||||
|
now := time.Now()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
|
||||||
|
ipv4 := netip.MustParseAddr("100.64.0.1")
|
||||||
|
ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
||||||
|
|
||||||
|
return types.Node{
|
||||||
|
ID: nodeID,
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: hostname,
|
||||||
|
GivenName: hostname,
|
||||||
|
UserID: userID,
|
||||||
|
User: types.User{
|
||||||
|
Name: username,
|
||||||
|
DisplayName: username,
|
||||||
|
},
|
||||||
|
RegisterMethod: "test",
|
||||||
|
IPv4: &ipv4,
|
||||||
|
IPv6: &ipv6,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer functions
|
||||||
|
|
||||||
|
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||||
|
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||||
|
for _, node := range nodes {
|
||||||
|
var peers []types.NodeView
|
||||||
|
for _, n := range nodes {
|
||||||
|
if n.ID() != node.ID() {
|
||||||
|
peers = append(peers, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret[node.ID()] = peers
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||||
|
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||||
|
for _, node := range nodes {
|
||||||
|
var peers []types.NodeView
|
||||||
|
nodeIsOdd := node.ID()%2 == 1
|
||||||
|
|
||||||
|
for _, n := range nodes {
|
||||||
|
if n.ID() == node.ID() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIsOdd := n.ID()%2 == 1
|
||||||
|
|
||||||
|
// Only add peer if both are odd or both are even
|
||||||
|
if nodeIsOdd == peerIsOdd {
|
||||||
|
peers = append(peers, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret[node.ID()] = peers
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeStoreOperations(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFunc func(t *testing.T) *NodeStore
|
||||||
|
steps []testStep
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create empty store and add single node",
|
||||||
|
setupFunc: func(t *testing.T) *NodeStore {
|
||||||
|
return NewNodeStore(nil, allowAllPeersFunc)
|
||||||
|
},
|
||||||
|
steps: []testStep{
|
||||||
|
{
|
||||||
|
name: "verify empty store",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Empty(t, snapshot.nodesByID)
|
||||||
|
assert.Empty(t, snapshot.allNodes)
|
||||||
|
assert.Empty(t, snapshot.peersByNode)
|
||||||
|
assert.Empty(t, snapshot.nodesByUser)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add first node",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
node := createTestNode(1, 1, "user1", "node1")
|
||||||
|
store.PutNode(node)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 1)
|
||||||
|
assert.Len(t, snapshot.allNodes, 1)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 1)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 1)
|
||||||
|
|
||||||
|
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||||
|
assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
|
||||||
|
assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create store with initial node and add more",
|
||||||
|
setupFunc: func(t *testing.T) *NodeStore {
|
||||||
|
node1 := createTestNode(1, 1, "user1", "node1")
|
||||||
|
initialNodes := types.Nodes{&node1}
|
||||||
|
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||||
|
},
|
||||||
|
steps: []testStep{
|
||||||
|
{
|
||||||
|
name: "verify initial state",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 1)
|
||||||
|
assert.Len(t, snapshot.allNodes, 1)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 1)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 1)
|
||||||
|
assert.Empty(t, snapshot.peersByNode[1])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add second node same user",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
node2 := createTestNode(2, 1, "user1", "node2")
|
||||||
|
store.PutNode(node2)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 2)
|
||||||
|
assert.Len(t, snapshot.allNodes, 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 2)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 1)
|
||||||
|
|
||||||
|
// Now both nodes should see each other as peers
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||||
|
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add third node different user",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
node3 := createTestNode(3, 2, "user2", "node3")
|
||||||
|
store.PutNode(node3)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 3)
|
||||||
|
assert.Len(t, snapshot.allNodes, 3)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 3)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 2)
|
||||||
|
|
||||||
|
// All nodes should see the other 2 as peers
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||||
|
|
||||||
|
// User groupings
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
|
||||||
|
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test node deletion",
|
||||||
|
setupFunc: func(t *testing.T) *NodeStore {
|
||||||
|
node1 := createTestNode(1, 1, "user1", "node1")
|
||||||
|
node2 := createTestNode(2, 1, "user1", "node2")
|
||||||
|
node3 := createTestNode(3, 2, "user2", "node3")
|
||||||
|
initialNodes := types.Nodes{&node1, &node2, &node3}
|
||||||
|
|
||||||
|
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||||
|
},
|
||||||
|
steps: []testStep{
|
||||||
|
{
|
||||||
|
name: "verify initial 3 nodes",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 3)
|
||||||
|
assert.Len(t, snapshot.allNodes, 3)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 3)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete middle node",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
store.DeleteNode(2)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 2)
|
||||||
|
assert.Len(t, snapshot.allNodes, 2)
|
||||||
|
assert.Len(t, snapshot.peersByNode, 2)
|
||||||
|
assert.Len(t, snapshot.nodesByUser, 2)
|
||||||
|
|
||||||
|
// Node 2 should be gone
|
||||||
|
assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
|
||||||
|
|
||||||
|
// Remaining nodes should see each other as peers
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||||
|
assert.Len(t, snapshot.peersByNode[3], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||||
|
|
||||||
|
// User groupings updated
|
||||||
|
assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
|
||||||
|
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete all remaining nodes",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
store.DeleteNode(1)
|
||||||
|
store.DeleteNode(3)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Empty(t, snapshot.nodesByID)
|
||||||
|
assert.Empty(t, snapshot.allNodes)
|
||||||
|
assert.Empty(t, snapshot.peersByNode)
|
||||||
|
assert.Empty(t, snapshot.nodesByUser)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test node updates",
|
||||||
|
setupFunc: func(t *testing.T) *NodeStore {
|
||||||
|
node1 := createTestNode(1, 1, "user1", "node1")
|
||||||
|
node2 := createTestNode(2, 1, "user1", "node2")
|
||||||
|
initialNodes := types.Nodes{&node1, &node2}
|
||||||
|
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||||
|
},
|
||||||
|
steps: []testStep{
|
||||||
|
{
|
||||||
|
name: "verify initial hostnames",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
|
||||||
|
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update node hostname",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
store.UpdateNode(1, func(n *types.Node) {
|
||||||
|
n.Hostname = "updated-node1"
|
||||||
|
n.GivenName = "updated-node1"
|
||||||
|
})
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
|
||||||
|
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
|
||||||
|
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
|
||||||
|
|
||||||
|
// Peers should still work correctly
|
||||||
|
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test with odd-even peers filtering",
|
||||||
|
setupFunc: func(t *testing.T) *NodeStore {
|
||||||
|
return NewNodeStore(nil, oddEvenPeersFunc)
|
||||||
|
},
|
||||||
|
steps: []testStep{
|
||||||
|
{
|
||||||
|
name: "add nodes with odd-even filtering",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
// Add nodes in sequence
|
||||||
|
store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||||
|
store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||||
|
store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||||
|
store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 4)
|
||||||
|
|
||||||
|
// Verify odd-even peer relationships
|
||||||
|
require.Len(t, snapshot.peersByNode[1], 1)
|
||||||
|
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||||
|
|
||||||
|
require.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||||
|
|
||||||
|
require.Len(t, snapshot.peersByNode[3], 1)
|
||||||
|
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||||
|
|
||||||
|
require.Len(t, snapshot.peersByNode[4], 1)
|
||||||
|
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete odd node and verify even nodes unaffected",
|
||||||
|
action: func(store *NodeStore) {
|
||||||
|
store.DeleteNode(1)
|
||||||
|
|
||||||
|
snapshot := store.data.Load()
|
||||||
|
assert.Len(t, snapshot.nodesByID, 3)
|
||||||
|
|
||||||
|
// Node 3 (odd) should now have no peers
|
||||||
|
assert.Empty(t, snapshot.peersByNode[3])
|
||||||
|
|
||||||
|
// Even nodes should still see each other
|
||||||
|
require.Len(t, snapshot.peersByNode[2], 1)
|
||||||
|
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||||
|
require.Len(t, snapshot.peersByNode[4], 1)
|
||||||
|
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
store := tt.setupFunc(t)
|
||||||
|
store.Start()
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
for _, step := range tt.steps {
|
||||||
|
t.Run(step.name, func(t *testing.T) {
|
||||||
|
step.action(store)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStep struct {
|
||||||
|
name string
|
||||||
|
action func(store *NodeStore)
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
|
|||||||
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
@ -355,6 +356,7 @@ func (node *Node) Proto() *v1.Node {
|
|||||||
GivenName: node.GivenName,
|
GivenName: node.GivenName,
|
||||||
User: node.User.Proto(),
|
User: node.User.Proto(),
|
||||||
ForcedTags: node.ForcedTags,
|
ForcedTags: node.ForcedTags,
|
||||||
|
Online: node.IsOnline != nil && *node.IsOnline,
|
||||||
|
|
||||||
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
|
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
|
||||||
// to be populated manually with PrimaryRoute, to ensure it includes the
|
// to be populated manually with PrimaryRoute, to ensure it includes the
|
||||||
@ -419,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubnetRoutes returns the list of routes that the node announces and are approved.
|
// SubnetRoutes returns the list of routes that the node announces and are approved.
|
||||||
|
//
|
||||||
|
// IMPORTANT: This method is used for internal data structures and should NOT be used
|
||||||
|
// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
|
||||||
|
// with PrimaryRoutes to ensure it includes only routes actively served by the node.
|
||||||
|
// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
|
||||||
func (node *Node) SubnetRoutes() []netip.Prefix {
|
func (node *Node) SubnetRoutes() []netip.Prefix {
|
||||||
var routes []netip.Prefix
|
var routes []netip.Prefix
|
||||||
|
|
||||||
@ -511,11 +518,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if node.Hostname != hostInfo.Hostname {
|
if node.Hostname != hostInfo.Hostname {
|
||||||
|
log.Trace().
|
||||||
|
Str("node.id", node.ID.String()).
|
||||||
|
Str("old_hostname", node.Hostname).
|
||||||
|
Str("new_hostname", hostInfo.Hostname).
|
||||||
|
Str("old_given_name", node.GivenName).
|
||||||
|
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
|
||||||
|
Msg("Updating hostname from hostinfo")
|
||||||
|
|
||||||
if node.GivenNameHasBeenChanged() {
|
if node.GivenNameHasBeenChanged() {
|
||||||
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
|
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.Hostname = hostInfo.Hostname
|
node.Hostname = hostInfo.Hostname
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Str("node.id", node.ID.String()).
|
||||||
|
Str("new_hostname", node.Hostname).
|
||||||
|
Str("new_given_name", node.GivenName).
|
||||||
|
Msg("Hostname updated")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -759,6 +780,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix {
|
|||||||
return v.ж.ExitRoutes()
|
return v.ж.ExitRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestTags returns the ACL tags that the node is requesting.
|
||||||
|
func (v NodeView) RequestTags() []string {
|
||||||
|
if !v.Valid() || !v.Hostinfo().Valid() {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
return v.Hostinfo().RequestTags().AsSlice()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proto converts the NodeView to a protobuf representation.
|
||||||
|
func (v NodeView) Proto() *v1.Node {
|
||||||
|
if !v.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v.ж.Proto()
|
||||||
|
}
|
||||||
|
|
||||||
// HasIP reports if a node has a given IP address.
|
// HasIP reports if a node has a given IP address.
|
||||||
func (v NodeView) HasIP(i netip.Addr) bool {
|
func (v NodeView) HasIP(i netip.Addr) bool {
|
||||||
if !v.Valid() {
|
if !v.Valid() {
|
||||||
|
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse each hop line
|
// Parse each hop line
|
||||||
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
|
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
|
||||||
|
|
||||||
for i := 1; i < len(lines); i++ {
|
for i := 1; i < len(lines); i++ {
|
||||||
matches := hopRegex.FindStringSubmatch(lines[i])
|
matches := hopRegex.FindStringSubmatch(lines[i])
|
||||||
|
@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||||||
assert.NoError(ct, err)
|
assert.NoError(ct, err)
|
||||||
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
||||||
}
|
}
|
||||||
|
assertTailscaleNodesLogout(t, allClients)
|
||||||
}, shortAccessTTL+10*time.Second, 5*time.Second)
|
}, shortAccessTTL+10*time.Second, 5*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -547,6 +547,8 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
|||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
// Wait for nodestore batch processing to complete
|
||||||
|
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||||
var nodes []*v1.Node
|
var nodes []*v1.Node
|
||||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||||
err := executeAndUnmarshal(
|
err := executeAndUnmarshal(
|
||||||
@ -642,27 +644,34 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
|||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
err = executeAndUnmarshal(
|
// Wait for nodestore batch processing to complete
|
||||||
headscale,
|
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||||
[]string{
|
assert.Eventually(t, func() bool {
|
||||||
"headscale",
|
err = executeAndUnmarshal(
|
||||||
"node",
|
headscale,
|
||||||
"list",
|
[]string{
|
||||||
"--output",
|
"headscale",
|
||||||
"json",
|
"node",
|
||||||
},
|
"list",
|
||||||
&nodes,
|
"--output",
|
||||||
)
|
"json",
|
||||||
|
},
|
||||||
|
&nodes,
|
||||||
|
)
|
||||||
|
|
||||||
assertNoErr(t, err)
|
if err != nil || len(nodes) != 3 {
|
||||||
assert.Len(t, nodes, 3)
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
||||||
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
||||||
assert.Equal(t, hostname+"NEW", node.GetName())
|
if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
|
||||||
assert.Equal(t, givenName, node.GetGivenName())
|
return false
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExpireNode(t *testing.T) {
|
func TestExpireNode(t *testing.T) {
|
||||||
|
@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) {
|
|||||||
assert.Len(t, node.GetSubnetRoutes(), 1)
|
assert.Len(t, node.GetSubnetRoutes(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
// Wait for route state changes to propagate to clients
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
// Verify that the clients can see the new routes
|
||||||
|
for _, client := range allClients {
|
||||||
|
status, err := client.Status()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
|
||||||
// Verify that the clients can see the new routes
|
for _, peerKey := range status.Peers() {
|
||||||
for _, client := range allClients {
|
peerStatus := status.Peer[peerKey]
|
||||||
status, err := client.Status()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for _, peerKey := range status.Peers() {
|
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||||
peerStatus := status.Peer[peerKey]
|
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||||
|
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||||
assert.NotNil(t, peerStatus.PrimaryRoutes)
|
}
|
||||||
|
|
||||||
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3)
|
|
||||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
|
||||||
}
|
}
|
||||||
}
|
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
|
||||||
|
|
||||||
_, err = headscale.ApproveRoutes(
|
_, err = headscale.ApproveRoutes(
|
||||||
1,
|
1,
|
||||||
@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
// Wait for route state changes to propagate to nodes
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
nodes, err = headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
|
||||||
nodes, err = headscale.ListNodes()
|
for _, node := range nodes {
|
||||||
require.NoError(t, err)
|
if node.GetId() == 1 {
|
||||||
|
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
||||||
for _, node := range nodes {
|
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
||||||
if node.GetId() == 1 {
|
assert.Empty(c, node.GetSubnetRoutes())
|
||||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
} else if node.GetId() == 2 {
|
||||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
||||||
assert.Empty(t, node.GetSubnetRoutes())
|
assert.Empty(c, node.GetApprovedRoutes())
|
||||||
} else if node.GetId() == 2 {
|
assert.Empty(c, node.GetSubnetRoutes())
|
||||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
} else {
|
||||||
assert.Empty(t, node.GetApprovedRoutes())
|
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
||||||
assert.Empty(t, node.GetSubnetRoutes())
|
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
||||||
} else {
|
assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
||||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
}
|
||||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
|
||||||
assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
|
||||||
}
|
}
|
||||||
}
|
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
|
||||||
|
|
||||||
// Verify that the clients can see the new routes
|
// Verify that the clients can see the new routes
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
// Wait for route configuration changes after advertising routes
|
||||||
|
var nodes []*v1.Node
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
nodes, err = headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
assert.Len(c, nodes, 6)
|
||||||
|
|
||||||
nodes, err := headscale.ListNodes()
|
requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0)
|
||||||
require.NoError(t, err)
|
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||||
assert.Len(t, nodes, 6)
|
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||||
|
}, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved")
|
||||||
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
|
|
||||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
|
||||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
|
||||||
|
|
||||||
// Verify that no routes has been sent to the client,
|
// Verify that no routes has been sent to the client,
|
||||||
// they are not yet enabled.
|
// they are not yet enabled.
|
||||||
@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
// Wait for route approval on first subnet router
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
nodes, err = headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
assert.Len(c, nodes, 6)
|
||||||
|
|
||||||
nodes, err = headscale.ListNodes()
|
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||||
require.NoError(t, err)
|
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||||
assert.Len(t, nodes, 6)
|
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||||
|
}, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route")
|
||||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
|
||||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
|
||||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine and can access
|
// Verify that the client has routes from the primary machine and can access
|
||||||
// the webservice.
|
// the webservice.
|
||||||
@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
// Wait for route approval on second subnet router
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
nodes, err = headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
assert.Len(c, nodes, 6)
|
||||||
|
|
||||||
nodes, err = headscale.ListNodes()
|
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||||
require.NoError(t, err)
|
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||||
assert.Len(t, nodes, 6)
|
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||||
|
}, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
|
||||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
|
||||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
|
||||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine
|
// Verify that the client has routes from the primary machine
|
||||||
srs1 = subRouter1.MustStatus()
|
srs1 = subRouter1.MustStatus()
|
||||||
@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
// Wait for route approval on third subnet router
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
nodes, err = headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
assert.Len(c, nodes, 6)
|
||||||
|
|
||||||
nodes, err = headscale.ListNodes()
|
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||||
require.NoError(t, err)
|
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||||
assert.Len(t, nodes, 6)
|
requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
|
||||||
|
}, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
|
||||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
|
||||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
|
||||||
requireNodeRouteCount(t, nodes[2], 1, 1, 0)
|
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine
|
// Verify that the client has routes from the primary machine
|
||||||
srs1 = subRouter1.MustStatus()
|
srs1 = subRouter1.MustStatus()
|
||||||
@@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	require.NoError(t, err)
 	assert.Len(t, result, 13)
 
-	tr, err = client.Traceroute(webip)
-	require.NoError(t, err)
-	assertTracerouteViaIP(t, tr, subRouter1.MustIPv4())
+	// Wait for traceroute to work correctly through the expected router
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		tr, err := client.Traceroute(webip)
+		assert.NoError(c, err)
+
+		// Get the expected router IP - use a more robust approach to handle temporary disconnections
+		ips, err := subRouter1.IPs()
+		assert.NoError(c, err)
+		assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
+
+		var expectedIP netip.Addr
+		for _, ip := range ips {
+			if ip.Is4() {
+				expectedIP = ip
+				break
+			}
+		}
+		assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
+
+		assertTracerouteViaIPWithCollect(c, tr, expectedIP)
+	}, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
 
 	// Take down the current primary
 	t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
@@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	err = subRouter1.Down()
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	srs2 = subRouter2.MustStatus()
-	clientStatus = client.MustStatus()
-
-	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-	srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-	assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
-	assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
-	assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
+	// Wait for router status changes after r1 goes down
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		srs2 = subRouter2.MustStatus()
+		clientStatus = client.MustStatus()
+
+		srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+		srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+		srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
+
+		assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
+		assert.True(c, srs2PeerStatus.Online, "r2 should be online")
+		assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+	}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
 
 	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 	require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
@@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	err = subRouter2.Down()
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	// TODO(kradalby): Check client status
-	// Both are expected to be down
-
-	// Verify that the route is not presented from either router
-	clientStatus, err = client.Status()
-	require.NoError(t, err)
-
-	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-	srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-	assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
-	assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
-	assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
+	// Wait for router status changes after r2 goes down
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		clientStatus, err = client.Status()
+		assert.NoError(c, err)
+
+		srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+		srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+		srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
+
+		assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
+		assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
+		assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+	}, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
 
 	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	err = subRouter1.Up()
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	// Verify that the route is announced from subnet router 1
-	clientStatus, err = client.Status()
-	require.NoError(t, err)
-
-	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-	srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-	assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
-	assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
-	assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
+	// Wait for router status changes after r1 comes back up
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		clientStatus, err = client.Status()
+		assert.NoError(c, err)
+
+		srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+		srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+		srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
+
+		assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
+		assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
+		assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
+	}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
 
 	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	err = subRouter2.Up()
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	// Verify that the route is announced from subnet router 1
-	clientStatus, err = client.Status()
-	require.NoError(t, err)
-
-	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-	srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-	assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
-	assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
-	assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
+	// Wait for nodestore batch processing to complete and online status to be updated
+	// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		clientStatus, err = client.Status()
+		assert.NoError(c, err)
+
+		srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+		srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+		srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
+
+		assert.True(c, srs1PeerStatus.Online, "r1 should be online")
+		assert.True(c, srs2PeerStatus.Online, "r2 should be online")
+		assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+	}, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
 
 	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
 	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
 	_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 6)
-
-	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
-	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
-	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	// Wait for nodestore batch processing and route state changes to complete
+	// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 6)
+
+		// After disabling route on r3, r1 should become primary with 1 subnet route
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	}, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
@@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
 	_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 6)
-
-	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
-	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
-	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	// Wait for nodestore batch processing and route state changes to complete
+	// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 6)
+
+		// After disabling route on r1, r2 should become primary with 1 subnet route
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	}, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
@@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
 		util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
 	)
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 6)
-
-	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
-	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
-	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	// Wait for route state changes after re-enabling r1
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 6)
+
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+		requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+	}, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
@@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
 	)
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	require.Len(t, nodes, 2)
-
-	requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-	requireNodeRouteCount(t, nodes[1], 0, 0, 0)
+	// Wait for route state changes to propagate to nodes
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 2)
+
+		requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
+		requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
 
 	// Verify that the client has routes from the primary machine
 	srs1, _ := subRouter1.Status()
@@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
 	requireNodeRouteCount(t, nodes[0], 2, 2, 2)
 	requireNodeRouteCount(t, nodes[1], 2, 2, 2)
 
-	time.Sleep(5 * time.Second)
-
-	// Verify that the clients can see the new routes
-	for _, client := range allClients {
-		status, err := client.Status()
-		assertNoErr(t, err)
-
-		for _, peerKey := range status.Peers() {
-			peerStatus := status.Peer[peerKey]
-
-			require.NotNil(t, peerStatus.AllowedIPs)
-			assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4)
-			assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
-			assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
+	// Wait for route state changes to propagate to clients
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		// Verify that the clients can see the new routes
+		for _, client := range allClients {
+			status, err := client.Status()
+			assert.NoError(c, err)
+
+			for _, peerKey := range status.Peers() {
+				peerStatus := status.Peer[peerKey]
+
+				assert.NotNil(c, peerStatus.AllowedIPs)
+				assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
+				assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
+				assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
+			}
 		}
-	}
+	}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
 }
 
 // TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
@@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 	)
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
-	requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-
-	// Verify that the routes have been sent to the client.
-	status, err = user2c.Status()
-	require.NoError(t, err)
-
-	for _, peerKey := range status.Peers() {
-		peerStatus := status.Peer[peerKey]
-
-		assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
-		requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
-	}
+	// Wait for route state changes to propagate to nodes and clients
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 2)
+		requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
+
+		// Verify that the routes have been sent to the client
+		status, err = user2c.Status()
+		assert.NoError(c, err)
+
+		for _, peerKey := range status.Peers() {
+			peerStatus := status.Peer[peerKey]
+
+			assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
+			requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
+		}
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
 
 	usernet1, err := scenario.Network("usernet1")
 	require.NoError(t, err)
@@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
 	_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
-	requireNodeRouteCount(t, nodes[0], 2, 2, 2)
-
-	// Verify that the routes have been sent to the client.
-	status, err = user2c.Status()
-	require.NoError(t, err)
-
-	for _, peerKey := range status.Peers() {
-		peerStatus := status.Peer[peerKey]
-
-		requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
-	}
+	// Wait for route state changes to propagate to nodes and clients
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 2)
+		requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
+
+		// Verify that the routes have been sent to the client
+		status, err = user2c.Status()
+		assert.NoError(c, err)
+
+		for _, peerKey := range status.Peers() {
+			peerStatus := status.Peer[peerKey]
+
+			requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
+		}
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
 
 	// Tell user2c to use user1c as an exit node.
 	command = []string{
@@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
 
+	var nodes []*v1.Node
 	opts := []hsic.Option{
 		hsic.WithTestName("autoapprovemulti"),
 		hsic.WithEmbeddedDERPServerOnly(),
@@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		require.NoErrorf(t, err, "failed to advertise route: %s", err)
 	}
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err := headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err := client.Status()
@@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	err = headscale.SetPolicy(tt.pol)
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err = client.Status()
@@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	)
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err = client.Status()
@@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	err = headscale.SetPolicy(tt.pol)
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err = client.Status()
@@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	_, _, err = routerSubRoute.Execute(command)
 	require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 	requireNodeRouteCount(t, nodes[1], 1, 1, 1)
 
 	// Verify that the routes have been sent to the client.
@@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	_, _, err = routerSubRoute.Execute(command)
 	require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	// These route should auto approve, so the node is expected to have a route
 	// for all counts.
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 	requireNodeRouteCount(t, nodes[1], 1, 1, 0)
 	requireNodeRouteCount(t, nodes[2], 0, 0, 0)
 
@@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	_, _, err = routerExitNode.Execute(command)
 	require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-	time.Sleep(5 * time.Second)
-
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
-	requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-	requireNodeRouteCount(t, nodes[2], 2, 2, 2)
+	assert.NoError(c, err)
+	requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+	requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err = client.Status()
@@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
 	require.Equal(t, tr.Route[0].IP, ip)
 }
 
+// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
+func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
+	assert.NotNil(c, tr)
+	assert.True(c, tr.Success)
+	assert.NoError(c, tr.Err)
+	assert.NotEmpty(c, tr.Route)
+	assert.Equal(c, tr.Route[0].IP, ip)
+}
+
 // requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
 func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
 	t.Helper()
@@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
 	}
 }
 
+func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
+	if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
+		assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
+		return
+	}
+
+	if len(expected) == 0 {
+		expected = []netip.Prefix{}
+	}
+
+	got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
+		if tsaddr.IsExitRoute(p) {
+			return true
+		}
+		return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
+	})
+
+	if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
+		assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
+	}
+}
+
 func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
 	t.Helper()
 	require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
@@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
 	require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
 }
 
+func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
+	assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
+	assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
+	assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
+}
+
 // TestSubnetRouteACLFiltering tests that a node can only access subnet routes
 // that are explicitly allowed in the ACL.
 func TestSubnetRouteACLFiltering(t *testing.T) {
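A usage note (a sketch reusing identifiers from the hunks above, not an additional part of the diff): the *WithCollect helpers take an assert.CollectT instead of *testing.T so they can run inside assert.EventuallyWithT and be retried on each poll, whereas the require-based originals would abort the test on the first failed attempt:

assert.EventuallyWithT(t, func(c *assert.CollectT) {
	nodes, err := headscale.ListNodes()
	assert.NoError(c, err)
	// Retried every 500ms until the router reports 1 announced, 1 approved and 1 subnet route.
	requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "router should converge on one served route")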
@@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
 	)
 	require.NoError(t, err)
 
-	// Give some time for the routes to propagate
-	time.Sleep(5 * time.Second)
-
-	// List nodes and verify the router has 3 available routes
-	nodes, err = headscale.NodesByUser()
-	require.NoError(t, err)
-	require.Len(t, nodes, 2)
-
-	// Find the router node
-	routerNode = nodes[routerUser][0]
-
-	// Check that the router has 3 routes now approved and available
-	requireNodeRouteCount(t, routerNode, 3, 3, 3)
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		// List nodes and verify the router has 3 available routes
+		nodes, err = headscale.NodesByUser()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 2)
+
+		// Find the router node
+		routerNode = nodes[routerUser][0]
+
+		// Check that the router has 3 routes now approved and available
+		requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Now check the client node status
 	nodeStatus, err := nodeClient.Status()
|
@ -14,7 +14,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
 		return nil, fmt.Errorf("no network named: %s", name)
 	}
 
-	for _, ipam := range net.Network.IPAM.Config {
-		pref, err := netip.ParsePrefix(ipam.Subnet)
-		if err != nil {
-			return nil, err
-		}
-
-		return &pref, nil
+	if len(net.Network.IPAM.Config) == 0 {
+		return nil, fmt.Errorf("no IPAM config found in network: %s", name)
 	}
 
-	return nil, fmt.Errorf("no prefix found in network: %s", name)
+	pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
+	if err != nil {
+		return nil, err
+	}
+
+	return &pref, nil
 }
 
 func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
@@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
 		return err
 	}
 
-	sort.Strings(s.spec.Users)
 	for _, user := range s.spec.Users {
 		u, err := s.CreateUser(user)
 		if err != nil {