mirror of https://github.com/juanfont/headscale.git (synced 2025-08-14 13:51:01 +02:00)

rest

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

This commit is contained in:
parent 7eef3cc38c
commit c24b988247
@@ -136,9 +136,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {

     // Initialize ephemeral garbage collector
     ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
-        node, err := app.state.GetNodeByID(ni)
-        if err != nil {
-            log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
+        node, ok := app.state.GetNodeByID(ni)
+        if !ok {
+            log.Warn().Uint64("node.id", ni.Uint64()).Msgf("ephemeral node not found for deletion")
             return
         }

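For illustration (not part of the commit): the hunk above switches the ephemeral-GC lookup from an error return to a comma-ok boolean. A minimal, self-contained sketch of that lookup style, using stand-in types rather than headscale's real state API:

package main

import "fmt"

// NodeID and nodeStore are stand-ins for the real types.
type NodeID uint64

type nodeStore map[NodeID]string

// GetNodeByID reports presence with a boolean instead of an error.
func (s nodeStore) GetNodeByID(id NodeID) (string, bool) {
	n, ok := s[id]
	return n, ok
}

func main() {
	s := nodeStore{1: "ephemeral-node"}
	if node, ok := s.GetNodeByID(2); !ok {
		fmt.Println("node not found for deletion, skipping")
	} else {
		fmt.Println("deleting", node)
	}
}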
@@ -371,7 +371,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
             Str("client_address", req.RemoteAddr).
             Msg("HTTP authentication invoked")

-        authHeader := req.Header.Get("authorization")
+        authHeader := req.Header.Get("Authorization")

         if !strings.HasPrefix(authHeader, AuthPrefix) {
             log.Error().

@@ -487,11 +487,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

 // Serve launches the HTTP and gRPC server service Headscale and the API.
 func (h *Headscale) Serve() error {
+    var err error
     capver.CanOldCodeBeCleanedUp()

     if profilingEnabled {
         if profilingPath != "" {
-            err := os.MkdirAll(profilingPath, os.ModePerm)
+            err = os.MkdirAll(profilingPath, os.ModePerm)
             if err != nil {
                 log.Fatal().Err(err).Msg("failed to create profiling directory")
             }

@@ -543,10 +544,7 @@ func (h *Headscale) Serve() error {
     // around between restarts, they will reconnect and the GC will
     // be cancelled.
     go h.ephemeralGC.Start()
-    ephmNodes, err := h.state.ListEphemeralNodes()
-    if err != nil {
-        return fmt.Errorf("failed to list ephemeral nodes: %w", err)
-    }
+    ephmNodes := h.state.ListEphemeralNodes()
     for _, node := range ephmNodes.All() {
         h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
     }

@@ -778,23 +776,14 @@ func (h *Headscale) Serve() error {
                     continue
                 }

-                changed, err := h.state.ReloadPolicy()
+                changes, err := h.state.ReloadPolicy()
                 if err != nil {
                     log.Error().Err(err).Msgf("reloading policy")
                     continue
                 }

-                if changed {
-                    log.Info().
-                        Msg("ACL policy successfully reloaded, notifying nodes of change")
-
-                    err = h.state.AutoApproveNodes()
-                    if err != nil {
-                        log.Error().Err(err).Msg("failed to approve routes after new policy")
-                    }
-
-                    h.Change(change.PolicySet)
-                }
+                h.Change(changes...)

             default:
                 info := func(msg string) { log.Info().Msg(msg) }
                 log.Info().

@@ -1004,6 +993,8 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 // Change is used to send changes to nodes.
 // All change should be enqueued here and empty will be automatically
 // ignored.
-func (h *Headscale) Change(c change.ChangeSet) {
-    h.mapBatcher.AddWork(c)
+func (h *Headscale) Change(cs ...change.ChangeSet) {
+    for _, c := range cs {
+        h.mapBatcher.AddWork(c)
+    }
 }
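For illustration (not part of the commit): Change becomes variadic so callers can pass zero, one, or many change sets in a single call. A minimal sketch of that pattern with stand-in types:

package main

import "fmt"

// ChangeSet and enqueue are stand-ins for headscale's change.ChangeSet and the batcher.
type ChangeSet struct{ Name string }

func enqueue(cs ...ChangeSet) {
	for _, c := range cs {
		fmt.Println("queued:", c.Name)
	}
}

func main() {
	enqueue()                                         // zero arguments: nothing is enqueued
	enqueue(ChangeSet{"policy"}, ChangeSet{"routes"}) // several changes in one call
}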
@@ -11,7 +11,6 @@ import (

     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/types/change"
-
     "gorm.io/gorm"
     "tailscale.com/tailcfg"
     "tailscale.com/types/key"

@@ -28,27 +27,9 @@ func (h *Headscale) handleRegister(
     regReq tailcfg.RegisterRequest,
     machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-    node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
-    if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
-        return nil, fmt.Errorf("looking up node in database: %w", err)
-    }
-
-    if node.Valid() {
-        // If an existing node is trying to register with an auth key,
-        // we need to validate the auth key even for existing nodes
-        if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
-            resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
-            if err != nil {
-                // Preserve HTTPError types so they can be handled properly by the HTTP layer
-                var httpErr HTTPError
-                if errors.As(err, &httpErr) {
-                    return nil, httpErr
-                }
-                return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
-            }
-            return resp, nil
-        }
-
+    node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
+
+    if ok {
         resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
         if err != nil {
             return nil, fmt.Errorf("handling existing node: %w", err)

@@ -69,6 +50,7 @@ func (h *Headscale) handleRegister(
         if errors.As(err, &httpErr) {
             return nil, httpErr
         }
+
         return nil, fmt.Errorf("handling register with auth key: %w", err)
     }

@@ -88,13 +70,22 @@ func (h *Headscale) handleExistingNode(
     regReq tailcfg.RegisterRequest,
     machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {

     if node.MachineKey != machineKey {
         return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
     }

     expired := node.IsExpired()

+    // If the node is expired and this is not a re-authentication attempt,
+    // force the client to re-authenticate
+    if expired && regReq.Auth == nil {
+        return &tailcfg.RegisterResponse{
+            NodeKeyExpired:    true,
+            MachineAuthorized: false,
+            AuthURL:           "", // Client will need to re-authenticate
+        }, nil
+    }
+
     if !expired && !regReq.Expiry.IsZero() {
         requestExpiry := regReq.Expiry

@@ -117,12 +108,16 @@ func (h *Headscale) handleExistingNode(
             }
         }

-        _, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
+        updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
         if err != nil {
             return nil, fmt.Errorf("setting node expiry: %w", err)
         }

         h.Change(c)
+
+        // CRITICAL: Use the updated node view for the response
+        // The original node object has stale expiry information
+        node = updatedNode.AsStruct()
     }

     return nodeToRegisterResponse(node), nil
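For illustration (not part of the commit): the hunk above fixes a stale-copy bug — after the store mutates the node, the response must be built from the value the store returns rather than the copy read earlier. A minimal sketch of the pitfall, with stand-in types:

package main

import (
	"fmt"
	"time"
)

// Node and store are stand-ins for the real node view and state layer.
type Node struct{ Expiry time.Time }

type store struct{ n Node }

// SetExpiry mutates the stored node and returns the updated copy.
func (s *store) SetExpiry(t time.Time) Node {
	s.n.Expiry = t
	return s.n
}

func main() {
	s := &store{}
	stale := s.n                       // snapshot taken before the update
	updated := s.SetExpiry(time.Now()) // mutate and capture the fresh value

	fmt.Println(stale.Expiry)   // still the zero expiry
	fmt.Println(updated.Expiry) // the value that should go into the response
}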
@@ -192,8 +187,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
         return nil, err
     }

-    // If node is nil, it means an ephemeral node was deleted during logout
-    if node.Valid() {
+    // If node is not valid, it means an ephemeral node was deleted during logout
+    if !node.Valid() {
         h.Change(changed)
         return nil, nil
     }

@@ -212,6 +207,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
     // TODO(kradalby): This needs to be ran as part of the batcher maybe?
     // now since we dont update the node/pol here anymore
     routeChange := h.state.AutoApproveRoutes(node)
+
     if _, _, err := h.state.SaveNode(node); err != nil {
         return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
     }

@@ -229,6 +225,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
     // }

     user := node.User()
+
     return &tailcfg.RegisterResponse{
         MachineAuthorized: true,
         NodeKeyExpired:    node.IsExpired(),

@@ -936,7 +936,7 @@ AND auth_key_id NOT IN (
             // - NEVER use gorm.AutoMigrate, write the exact migration steps needed
             // - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
             // - Never write migrations that requires foreign keys to be disabled.
         },
     )

     if err := runMigrations(cfg, dbConn, migrations); err != nil {

@@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
 }

 // RenameNode takes a Node struct and a new GivenName for the nodes
-// and renames it. If the name is not unique, it will return an error.
+// and renames it. Validation should be done in the state layer before calling this function.
 func RenameNode(tx *gorm.DB,
     nodeID types.NodeID, newName string,
 ) error {
-    err := util.CheckForFQDNRules(
-        newName,
-    )
-    if err != nil {
-        return fmt.Errorf("renaming node: %w", err)
-    }
-
-    uniq, err := isUniqueName(tx, newName)
-    if err != nil {
-        return fmt.Errorf("checking if name is unique: %w", err)
-    }
-
-    if !uniq {
-        return fmt.Errorf("name is not unique: %s", newName)
+    // Check if the new name is unique
+    var count int64
+    if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
+        return fmt.Errorf("failed to check name uniqueness: %w", err)
+    }
+
+    if count > 0 {
+        return fmt.Errorf("name is not unique")
     }

     if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
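For illustration (not part of the commit): the rewritten RenameNode checks uniqueness with a single COUNT query that excludes the node being renamed. A minimal sketch of that query shape, assuming GORM with an in-memory SQLite database and a stand-in Node model:

package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Node struct {
	ID        uint
	GivenName string
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	db.Exec("CREATE TABLE nodes (id INTEGER PRIMARY KEY, given_name TEXT)")
	db.Create(&Node{ID: 1, GivenName: "alpha"})

	// Count other nodes already using the requested name (excluding node 2 itself).
	var count int64
	db.Model(&Node{}).Where("given_name = ? AND id != ?", "alpha", 2).Count(&count)
	fmt.Println("name taken:", count > 0)
}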
@@ -409,9 +403,16 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
                 return nil, err
             }

+            // CRITICAL: Reload the node to get the updated expiry
+            // Without this, we return stale node data to NodeStore
+            updatedNode, err := GetNodeByID(tx, node.ID)
+            if err != nil {
+                return nil, fmt.Errorf("failed to reload node after expiry update: %w", err)
+            }
+
             nodeChange = change.KeyExpiry(node.ID)

-            return node, nil
+            return updatedNode, nil
         }
     }

@@ -445,8 +446,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
         node.ID = oldNode.ID
         node.GivenName = oldNode.GivenName
         node.ApprovedRoutes = oldNode.ApprovedRoutes
-        ipv4 = oldNode.IPv4
-        ipv6 = oldNode.IPv6
+        // Don't overwrite the provided IPs with old ones when they exist
+        if ipv4 == nil {
+            ipv4 = oldNode.IPv4
+        }
+        if ipv6 == nil {
+            ipv6 = oldNode.IPv6
+        }
     }

     // If the node exists and it already has IP(s), we just save it
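For illustration (not part of the commit): the RegisterNode hunk above keeps caller-provided addresses and only falls back to the previously stored ones when none were supplied. A small self-contained sketch of that fallback rule:

package main

import (
	"fmt"
	"net/netip"
)

// pickAddr returns the provided address when present, otherwise the old one.
func pickAddr(provided, old *netip.Addr) *netip.Addr {
	if provided == nil {
		return old
	}
	return provided
}

func main() {
	oldV4 := netip.MustParseAddr("100.64.0.1")
	newV4 := netip.MustParseAddr("100.64.0.2")

	fmt.Println(*pickAddr(&newV4, &oldV4)) // provided wins: 100.64.0.2
	fmt.Println(*pickAddr(nil, &oldV4))    // nothing provided: falls back to 100.64.0.1
}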
@@ -781,19 +787,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .

     node := hsdb.CreateNodeForTest(user, hostname...)

-    err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
-        _, err := RegisterNode(tx, *node, nil, nil)
+    // Allocate IPs for the test node using the database's IP allocator
+    // This is a simplified allocation for testing - in production this would use State.ipAlloc
+    ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
+    if err != nil {
+        panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
+    }
+
+    var registeredNode *types.Node
+    err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
+        var err error
+        registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6)
         return err
     })
     if err != nil {
         panic(fmt.Sprintf("failed to register test node: %v", err))
     }

-    registeredNode, err := hsdb.GetNodeByID(node.ID)
-    if err != nil {
-        panic(fmt.Sprintf("failed to get registered test node: %v", err))
-    }
-
     return registeredNode
 }

@@ -842,3 +852,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int

     return nodes
 }
+
+// allocateTestIPs allocates sequential test IPs for nodes during testing.
+func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
+    if !testing.Testing() {
+        panic("allocateTestIPs can only be called during tests")
+    }
+
+    // Use simple sequential allocation for tests
+    // IPv4: 100.64.0.x (where x is nodeID)
+    // IPv6: fd7a:115c:a1e0::x (where x is nodeID)
+
+    if nodeID > 254 {
+        return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
+    }
+
+    ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
+    ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})
+
+    return &ipv4, &ipv6, nil
+}

@@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {

 func TestAutoApproveRoutes(t *testing.T) {
     tests := []struct {
         name   string
         acl    string
         routes []netip.Prefix
         want   []netip.Prefix
         want2  []netip.Prefix
+        expectChange bool // whether to expect route changes
     }{
+        {
+            name: "no-auto-approvers-empty-policy",
+            acl: `
+{
+    "groups": {
+        "group:admins": ["test@"]
+    },
+    "acls": [
+        {
+            "action": "accept",
+            "src": ["group:admins"],
+            "dst": ["group:admins:*"]
+        }
+    ]
+}`,
+            routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
+            want:   []netip.Prefix{}, // Should be empty - no auto-approvers
+            want2:  []netip.Prefix{}, // Should be empty - no auto-approvers
+            expectChange: false, // No changes expected
+        },
+        {
+            name: "no-auto-approvers-explicit-empty",
+            acl: `
+{
+    "groups": {
+        "group:admins": ["test@"]
+    },
+    "acls": [
+        {
+            "action": "accept",
+            "src": ["group:admins"],
+            "dst": ["group:admins:*"]
+        }
+    ],
+    "autoApprovers": {
+        "routes": {},
+        "exitNode": []
+    }
+}`,
+            routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
+            want:   []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
+            want2:  []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
+            expectChange: false, // No changes expected
+        },
         {
             name: "2068-approve-issue-sub-kube",
             acl: `

@@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
         }
     }
 }`,
             routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
             want:   []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
+            expectChange: true, // Routes should be approved
         },
         {
             name: "2068-approve-issue-sub-exit-tag",

@@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
                 tsaddr.AllIPv4(),
                 tsaddr.AllIPv6(),
             },
+            expectChange: true, // Routes should be approved
         },
     }

@@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
             require.NoError(t, err)
             require.NotNil(t, pm)

-            changed1 := policy.AutoApproveRoutes(pm, &node)
-            assert.True(t, changed1)
+            newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
+            assert.Equal(t, tt.expectChange, changed1)

-            err = adb.DB.Save(&node).Error
-            require.NoError(t, err)
+            if changed1 {
+                err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
+                require.NoError(t, err)
+            }

-            _ = policy.AutoApproveRoutes(pm, &nodeTagged)
-            err = adb.DB.Save(&nodeTagged).Error
-            require.NoError(t, err)
+            newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
+            if changed2 {
+                err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
+                require.NoError(t, err)
+            }

             node1ByID, err := adb.GetNodeByID(1)
             require.NoError(t, err)

-            if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
+            // For empty auto-approvers tests, handle nil vs empty slice comparison
+            expectedRoutes1 := tt.want
+            if len(expectedRoutes1) == 0 {
+                expectedRoutes1 = nil
+            }
+            if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
                 t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
             }

             node2ByID, err := adb.GetNodeByID(2)
             require.NoError(t, err)

-            if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
+            expectedRoutes2 := tt.want2
+            if len(expectedRoutes2) == 0 {
+                expectedRoutes2 = nil
+            }
+            if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
                 t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
             }
         })

@@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
     // No parameter means no filter, should return all peers
     nodes, err = db.ListPeers(1)
     require.NoError(t, err)
-    assert.Equal(t, 1, len(nodes))
+    assert.Len(t, nodes, 1)
     assert.Equal(t, "test2", nodes[0].Hostname)

     // Empty node list should return all peers
     nodes, err = db.ListPeers(1, types.NodeIDs{}...)
     require.NoError(t, err)
-    assert.Equal(t, 1, len(nodes))
+    assert.Len(t, nodes, 1)
     assert.Equal(t, "test2", nodes[0].Hostname)

     // No match in IDs should return empty list and no error

@@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
     // Partial match in IDs
     nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
     require.NoError(t, err)
-    assert.Equal(t, 1, len(nodes))
+    assert.Len(t, nodes, 1)
     assert.Equal(t, "test2", nodes[0].Hostname)

     // Several matched IDs, but node ID is still filtered out
     nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
     require.NoError(t, err)
-    assert.Equal(t, 1, len(nodes))
+    assert.Len(t, nodes, 1)
     assert.Equal(t, "test2", nodes[0].Hostname)
 }

@@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
     // No parameter means no filter, should return all nodes
     nodes, err = db.ListNodes()
     require.NoError(t, err)
-    assert.Equal(t, 2, len(nodes))
+    assert.Len(t, nodes, 2)
     assert.Equal(t, "test1", nodes[0].Hostname)
     assert.Equal(t, "test2", nodes[1].Hostname)

     // Empty node list should return all nodes
     nodes, err = db.ListNodes(types.NodeIDs{}...)
     require.NoError(t, err)
-    assert.Equal(t, 2, len(nodes))
+    assert.Len(t, nodes, 2)
     assert.Equal(t, "test1", nodes[0].Hostname)
     assert.Equal(t, "test2", nodes[1].Hostname)

@@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
     // Partial match in IDs
     nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
     require.NoError(t, err)
-    assert.Equal(t, 1, len(nodes))
+    assert.Len(t, nodes, 1)
     assert.Equal(t, "test2", nodes[0].Hostname)

     // Several matched IDs
     nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
     require.NoError(t, err)
-    assert.Equal(t, 2, len(nodes))
+    assert.Len(t, nodes, 2)
     assert.Equal(t, "test1", nodes[0].Hostname)
     assert.Equal(t, "test2", nodes[1].Hostname)
 }

@@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
 }

 // AssignNodeToUser assigns a Node to a user.
+// Note: Validation should be done in the state layer before calling this function.
 func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
-    node, err := GetNodeByID(tx, nodeID)
-    if err != nil {
-        return err
-    }
-    user, err := GetUserByID(tx, uid)
-    if err != nil {
-        return err
-    }
-    node.User = *user
-    node.UserID = user.ID
-    if result := tx.Save(&node); result.Error != nil {
-        return result.Error
+    // Check if the user exists
+    var userExists bool
+    if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
+        return fmt.Errorf("failed to check if user exists: %w", err)
+    }
+
+    if !userExists {
+        return ErrUserNotFound
+    }
+
+    if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
+        return fmt.Errorf("failed to assign node to user: %w", err)
     }

     return nil
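For illustration (not part of the commit): the rewritten AssignNodeToUser avoids loading whole records — it checks existence with a boolean projection and then updates a single column. A minimal sketch of both steps, assuming GORM with an in-memory SQLite database and stand-in models:

package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct{ ID uint }

type Node struct {
	ID     uint
	UserID uint
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY)")
	db.Exec("CREATE TABLE nodes (id INTEGER PRIMARY KEY, user_id INTEGER)")
	db.Create(&User{ID: 7})
	db.Create(&Node{ID: 1, UserID: 3})

	// SELECT count(*) > 0 FROM users WHERE id = 7
	var userExists bool
	db.Model(&User{}).Select("count(*) > 0").Where("id = ?", 7).Find(&userExists)

	if userExists {
		// UPDATE nodes SET user_id = 7 WHERE id = 1
		db.Model(&Node{}).Where("id = ?", 1).Update("user_id", 7)
	}

	var n Node
	db.First(&n, 1)
	fmt.Println("node now belongs to user:", n.UserID)
}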
@@ -288,9 +288,9 @@ func (api headscaleV1APIServer) GetNode(
     ctx context.Context,
     request *v1.GetNodeRequest,
 ) (*v1.GetNodeResponse, error) {
-    node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
-    if err != nil {
-        return nil, err
+    node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
+    if !ok {
+        return nil, status.Errorf(codes.NotFound, "node not found")
     }

     resp := node.Proto()

@@ -334,7 +334,12 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
     ctx context.Context,
     request *v1.SetApprovedRoutesRequest,
 ) (*v1.SetApprovedRoutesResponse, error) {
-    var routes []netip.Prefix
+    log.Debug().
+        Uint64("node.id", request.GetNodeId()).
+        Strs("requestedRoutes", request.GetRoutes()).
+        Msg("gRPC SetApprovedRoutes called")
+
+    var newApproved []netip.Prefix
     for _, route := range request.GetRoutes() {
         prefix, err := netip.ParsePrefix(route)
         if err != nil {

@@ -344,31 +349,34 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
         // If the prefix is an exit route, add both. The client expect both
         // to annotate the node as an exit node.
         if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
-            routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
+            newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
         } else {
-            routes = append(routes, prefix)
+            newApproved = append(newApproved, prefix)
         }
     }
-    tsaddr.SortPrefixes(routes)
-    routes = slices.Compact(routes)
+    tsaddr.SortPrefixes(newApproved)
+    newApproved = slices.Compact(newApproved)

-    node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
+    node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
     if err != nil {
         return nil, status.Error(codes.InvalidArgument, err.Error())
     }

-    routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
-
     // Always propagate node changes from SetApprovedRoutes
     api.h.Change(nodeChange)

-    // If routes changed, propagate those changes too
-    if !routeChange.Empty() {
-        api.h.Change(routeChange)
-    }
-
     proto := node.Proto()
-    proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
+    // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
+    // routes that are actively served from the node (per architectural requirement in types/node.go)
+    primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
+    proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
+
+    log.Debug().
+        Uint64("node.id", node.ID().Uint64()).
+        Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
+        Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
+        Strs("finalSubnetRoutes", proto.SubnetRoutes).
+        Msg("gRPC SetApprovedRoutes completed")

     return &v1.SetApprovedRoutesResponse{Node: proto}, nil
 }
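For illustration (not part of the commit): SetApprovedRoutes normalises the requested prefixes — an exit-node request expands to both wildcard prefixes, then the list is sorted and de-duplicated before being stored. A standard-library-only sketch of that normalisation (the real code uses tailscale.com/net/tsaddr helpers):

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	allV4 := netip.MustParsePrefix("0.0.0.0/0")
	allV6 := netip.MustParsePrefix("::/0")

	requested := []string{"0.0.0.0/0", "10.0.0.0/8", "0.0.0.0/0"}

	var approved []netip.Prefix
	for _, r := range requested {
		p := netip.MustParsePrefix(r)
		if p == allV4 || p == allV6 {
			approved = append(approved, allV4, allV6) // exit node implies both address families
		} else {
			approved = append(approved, p)
		}
	}

	// Sort so duplicates become adjacent, then drop them.
	slices.SortFunc(approved, func(a, b netip.Prefix) int {
		if c := a.Addr().Compare(b.Addr()); c != 0 {
			return c
		}
		return a.Bits() - b.Bits()
	})
	approved = slices.Compact(approved)

	fmt.Println(approved) // [0.0.0.0/0 10.0.0.0/8 ::/0]
}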
@@ -390,9 +398,9 @@ func (api headscaleV1APIServer) DeleteNode(
     ctx context.Context,
     request *v1.DeleteNodeRequest,
 ) (*v1.DeleteNodeResponse, error) {
-    node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
-    if err != nil {
-        return nil, err
+    node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
+    if !ok {
+        return nil, status.Errorf(codes.NotFound, "node not found")
     }

     nodeChange, err := api.h.state.DeleteNode(node)

@@ -463,19 +471,13 @@ func (api headscaleV1APIServer) ListNodes(
             return nil, err
         }

-        nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
-        if err != nil {
-            return nil, err
-        }
+        nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))

         response := nodesToProto(api.h.state, IsConnected, nodes)
         return &v1.ListNodesResponse{Nodes: response}, nil
     }

-    nodes, err := api.h.state.ListNodes()
-    if err != nil {
-        return nil, err
-    }
+    nodes := api.h.state.ListNodes()

     response := nodesToProto(api.h.state, IsConnected, nodes)
     return &v1.ListNodesResponse{Nodes: response}, nil

@@ -499,6 +501,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID,
             }
         }
         resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
+
         resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
         response[index] = resp
     }

@@ -674,11 +677,8 @@ func (api headscaleV1APIServer) SetPolicy(
     // a scenario where they might be allowed if the server has no nodes
     // yet, but it should help for the general case and for hot reloading
     // configurations.
-    nodes, err := api.h.state.ListNodes()
-    if err != nil {
-        return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
-    }
-    changed, err := api.h.state.SetPolicy([]byte(p))
+    nodes := api.h.state.ListNodes()
+    _, err := api.h.state.SetPolicy([]byte(p))
     if err != nil {
         return nil, fmt.Errorf("setting policy: %w", err)
     }

@@ -695,16 +695,16 @@ func (api headscaleV1APIServer) SetPolicy(
         return nil, err
     }

-    // Only send update if the packet filter has changed.
-    if changed {
-        err = api.h.state.AutoApproveNodes()
-        if err != nil {
-            return nil, err
-        }
-
-        api.h.Change(change.PolicyChange())
+    // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
+    // This ensures that routes are re-evaluated for auto-approval in cases where routes
+    // were manually disabled but could now be auto-approved with the current policy.
+    cs, err := api.h.state.ReloadPolicy()
+    if err != nil {
+        return nil, fmt.Errorf("reloading policy: %w", err)
     }

+    api.h.Change(cs...)
+
     response := &v1.SetPolicyResponse{
         Policy:    updated.Data,
         UpdatedAt: timestamppb.New(updated.UpdatedAt),

@@ -94,10 +94,7 @@ func (h *Headscale) handleVerifyRequest(
         return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
     }

-    nodes, err := h.state.ListNodes()
-    if err != nil {
-        return fmt.Errorf("cannot list nodes: %w", err)
-    }
+    nodes := h.state.ListNodes()

     // Check if any node has the requested NodeKey
     var nodeKeyFound bool

@@ -1,6 +1,7 @@
 package mapper

 import (
+    "errors"
     "fmt"
     "time"

@@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
 type Batcher interface {
     Start()
     Close()
-    AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
-    RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
+    AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
+    RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
     IsConnected(id types.NodeID) bool
     ConnectedMap() *xsync.Map[types.NodeID, bool]
     AddWork(c change.ChangeSet)

@@ -119,7 +120,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
 // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
 func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
     if nc == nil {
-        return fmt.Errorf("nodeConnection is nil")
+        return errors.New("nodeConnection is nil")
     }

     nodeID := nc.nodeID()

@@ -3,7 +3,6 @@ package mapper
 import (
     "context"
     "fmt"
-    "sync"
     "sync/atomic"
     "time"

@@ -21,7 +20,6 @@ type LockFreeBatcher struct {
     mapper  *mapper
     workers int

-    // Lock-free concurrent maps
     nodes     *xsync.Map[types.NodeID, *nodeConn]
     connected *xsync.Map[types.NodeID, *time.Time]

@@ -32,7 +30,6 @@ type LockFreeBatcher struct {

     // Batching state
     pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
-    batchMutex     sync.RWMutex

     // Metrics
     totalNodes atomic.Int64

@@ -46,16 +43,13 @@ type LockFreeBatcher struct {
 // It creates or updates the node's connection data, validates the initial map generation,
 // and notifies other nodes that this node has come online.
 // TODO(kradalby): See if we can move the isRouter argument somewhere else.
-func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
-    // First validate that we can generate initial map before doing anything else
-    fullSelfChange := change.FullSelf(id)
-
+func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
     // TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
     // This currently means that the goroutine for the node connection will do the processing
     // which means that we might have uncontrolled concurrency.
     // When we use MapResponseFromChange, it will be processed by the same worker pool, causing
     // it to be processed in a more controlled manner.
-    initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
+    initialMap, err := generateMapResponse(id, version, b.mapper, change.FullSelf(id))
     if err != nil {
         return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
     }

@@ -73,10 +67,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
         conn = newConn
     }

-    // Mark as connected only after validation succeeds
     b.connected.Store(id, nil) // nil = connected

-    log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
+    log.Info().Uint64("node.id", id.Uint64()).Msg("Node connected to batcher")

     // Send the validated initial map
     if initialMap != nil {

@@ -86,9 +79,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
             b.connected.Delete(id)
             return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
         }
-
-        // Notify other nodes that this node came online
-        b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
     }

     return nil

@@ -97,12 +87,14 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
 // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
 // It validates the connection channel matches the current one, closes the connection,
 // and notifies other nodes that this node has gone offline.
-func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
+// It reports if the node was actually closed. Returns false if the channel does not match the current connection,
+// indicating that we are actually not disconnecting the node, but rather ignoring the request.
+func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
     // Check if this is the current connection and mark it as closed
     if existing, ok := b.nodes.Load(id); ok {
         if !existing.matchesChannel(c) {
-            log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
-            return // Not the current connection, not an error
+            log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called on a different channel, ignoring")
+            return false // Not the current connection, not an error
         }

         // Mark the connection as closed to prevent further sends
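For illustration (not part of the commit): RemoveNode now only tears down the connection when the caller's channel is still the registered one, and reports whether anything was removed — so a late disconnect from an old connection cannot kill a fresh reconnect. A minimal sketch of that channel-identity check with stand-in types:

package main

import "fmt"

type registry struct {
	current map[int]chan string
}

// Remove deletes the entry only if c is still the registered channel.
func (r *registry) Remove(id int, c chan string) bool {
	if cur, ok := r.current[id]; !ok || cur != c {
		return false // stale channel: ignore the request
	}
	delete(r.current, id)
	return true
}

func main() {
	oldCh, newCh := make(chan string), make(chan string)
	r := &registry{current: map[int]chan string{1: oldCh}}

	r.current[1] = newCh            // the node reconnected on a new channel
	fmt.Println(r.Remove(1, oldCh)) // false: stale disconnect, connection kept
	fmt.Println(r.Remove(1, newCh)) // true: current connection removed
}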
@@ -111,15 +103,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
         }
     }

-    log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
+    log.Info().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher, marking as offline")

     // Remove node and mark disconnected atomically
     b.nodes.Delete(id)
     b.connected.Store(id, ptr.To(time.Now()))
     b.totalNodes.Add(-1)

-    // Notify other nodes that this node went offline
-    b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
+    return true
 }

 // AddWork queues a change to be processed by the batcher.

@@ -214,6 +205,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
                     Dur("duration", duration).
                     Msg("slow synchronous work processing")
             }
+
             continue
         }

@@ -228,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
                 Uint64("node.id", w.nodeID.Uint64()).
                 Str("change", w.c.Change.String()).
                 Msg("skipping work for closed connection")
+
             continue
         }

@@ -240,12 +233,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
                     Str("change", w.c.Change.String()).
                     Msg("failed to apply change")
             }
-        } else {
-            log.Debug().
-                Int("workerID", workerID).
-                Uint64("node.id", w.nodeID.Uint64()).
-                Str("change", w.c.Change.String()).
-                Msg("node not found for asynchronous work - node may have disconnected")
         }

         duration := time.Since(startTime)

@@ -276,8 +263,10 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
                 return true
             }
             b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
+
             return true
         })
+
         return
     }

@@ -285,7 +274,7 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
     b.addToBatch(c)
 }

-// queueWork safely queues work
+// queueWork safely queues work.
 func (b *LockFreeBatcher) queueWork(w work) {
     b.workQueuedCount.Add(1)

@@ -298,7 +287,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
     }
 }

-// shouldProcessImmediately determines if a change should bypass batching
+// shouldProcessImmediately determines if a change should bypass batching.
 func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
     // Process these changes immediately to avoid delaying critical functionality
     switch c.Change {

@@ -309,11 +298,8 @@ func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
     }
 }

-// addToBatch adds a change to the pending batch
+// addToBatch adds a change to the pending batch.
 func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
-    b.batchMutex.Lock()
-    defer b.batchMutex.Unlock()
-
     if c.SelfUpdateOnly {
         changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
         changes = append(changes, c)

@@ -329,15 +315,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
         changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
         changes = append(changes, c)
         b.pendingChanges.Store(nodeID, changes)
+
         return true
     })
 }

-// processBatchedChanges processes all pending batched changes
+// processBatchedChanges processes all pending batched changes.
 func (b *LockFreeBatcher) processBatchedChanges() {
-    b.batchMutex.Lock()
-    defer b.batchMutex.Unlock()
-
     if b.pendingChanges == nil {
         return
     }

@@ -355,17 +339,27 @@ func (b *LockFreeBatcher) processBatchedChanges() {

         // Clear the pending changes for this node
         b.pendingChanges.Delete(nodeID)

         return true
     })
 }

 // IsConnected is lock-free read.
 func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
-    if val, ok := b.connected.Load(id); ok {
-        // nil means connected
-        return val == nil
+    val, ok := b.connected.Load(id)
+    if !ok {
+        return false
     }
-    return false
+
+    // nil means connected
+    if val == nil {
+        return true
+    }
+
+    // During grace period, always return true to allow DNS resolution
+    // for logout HTTP requests to complete successfully
+    gracePeriod := 45 * time.Second
+    return time.Since(*val) < gracePeriod
 }

 // ConnectedMap returns a lock-free map of all connected nodes.
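For illustration (not part of the commit): the IsConnected hunk above treats a nil timestamp as "connected", a recent disconnect as still connected for a short grace window, and anything older as offline. A minimal standalone sketch of that rule (the 45-second value mirrors the diff; the rest is a stand-in):

package main

import (
	"fmt"
	"time"
)

func isConnected(disconnectedAt *time.Time, now time.Time) bool {
	if disconnectedAt == nil {
		return true // still connected
	}
	const gracePeriod = 45 * time.Second
	return now.Sub(*disconnectedAt) < gracePeriod
}

func main() {
	now := time.Now()
	recent := now.Add(-10 * time.Second)
	old := now.Add(-2 * time.Minute)

	fmt.Println(isConnected(nil, now))     // true: never disconnected
	fmt.Println(isConnected(&recent, now)) // true: within the grace period
	fmt.Println(isConnected(&old, now))    // false: grace period elapsed
}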
@ -487,5 +481,6 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
|||||||
// the channel is still open.
|
// the channel is still open.
|
||||||
connData.c <- data
|
connData.c <- data
|
||||||
nc.updateCount.Add(1)
|
nc.updateCount.Add(1)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,43 @@ type batcherTestCase struct {
|
|||||||
fn batcherFunc
|
fn batcherFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||||
|
// that would normally be sent by poll.go in production
|
||||||
|
type testBatcherWrapper struct {
|
||||||
|
Batcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||||
|
// First add the node to the real batcher
|
||||||
|
err := t.Batcher.AddNode(id, c, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then send the online notification that poll.go would normally send
|
||||||
|
t.Batcher.AddWork(change.NodeOnline(id))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||||
|
// First remove from the real batcher
|
||||||
|
removed := t.Batcher.RemoveNode(id, c)
|
||||||
|
if !removed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then send the offline notification that poll.go would normally send
|
||||||
|
t.Batcher.AddWork(change.NodeOffline(id))
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapBatcherForTest wraps a batcher with test-specific behavior
|
||||||
|
func wrapBatcherForTest(b Batcher) Batcher {
|
||||||
|
return &testBatcherWrapper{Batcher: b}
|
||||||
|
}
|
||||||
|
|
||||||
 // allBatcherFunctions contains all batcher implementations to test.
 var allBatcherFunctions = []batcherTestCase{
 	{"LockFree", NewBatcherAndMapper},

@@ -176,8 +213,8 @@ func setupBatcherWithTestData(
 	  "acls": [
 	    {
 	      "action": "accept",
-	      "users": ["*"],
-	      "ports": ["*:*"]
+	      "src": ["*"],
+	      "dst": ["*:*"]
 	    }
 	  ]
 	}`

@@ -187,8 +224,8 @@ func setupBatcherWithTestData(
 		t.Fatalf("Failed to set allow-all policy: %v", err)
 	}

-	// Create batcher with the state
-	batcher := bf(cfg, state)
+	// Create batcher with the state and wrap it for testing
+	batcher := wrapBatcherForTest(bf(cfg, state))
 	batcher.Start()

 	testData := &TestData{

@@ -455,7 +492,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
 	testNode.start()

 	// Connect the node to the batcher
-	batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
+	batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
 	time.Sleep(100 * time.Millisecond) // Let connection settle

 	// Generate some work

@@ -558,7 +595,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
 	t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
 	for i := range allNodes {
 		node := &allNodes[i]
-		batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))

 		// Issue full update after each join to ensure connectivity
 		batcher.AddWork(change.FullSet)

@@ -606,7 +643,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
 	// Disconnect all nodes
 	for i := range allNodes {
 		node := &allNodes[i]
-		batcher.RemoveNode(node.n.ID, node.ch, false)
+		batcher.RemoveNode(node.n.ID, node.ch)
 	}

 	// Give time for final updates to process

@@ -724,7 +761,7 @@ func TestBatcherBasicOperations(t *testing.T) {
 	tn2 := testData.Nodes[1]

 	// Test AddNode with real node ID
-	batcher.AddNode(tn.n.ID, tn.ch, false, 100)
+	batcher.AddNode(tn.n.ID, tn.ch, 100)
 	if !batcher.IsConnected(tn.n.ID) {
 		t.Error("Node should be connected after AddNode")
 	}

@@ -744,14 +781,14 @@ func TestBatcherBasicOperations(t *testing.T) {
 	drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)

 	// Add the second node and verify update message
-	batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
+	batcher.AddNode(tn2.n.ID, tn2.ch, 100)
 	assert.True(t, batcher.IsConnected(tn2.n.ID))

 	// First node should get an update that second node has connected.
 	select {
 	case data := <-tn.ch:
 		assertOnlineMapResponse(t, data, true)
-	case <-time.After(200 * time.Millisecond):
+	case <-time.After(500 * time.Millisecond):
 		t.Error("Did not receive expected Online response update")
 	}

@@ -765,19 +802,19 @@ func TestBatcherBasicOperations(t *testing.T) {
 			len(data.Peers) >= 1 || data.Node != nil,
 			"Should receive initial full map",
 		)
-	case <-time.After(200 * time.Millisecond):
+	case <-time.After(500 * time.Millisecond):
 		t.Error("Second node should receive its initial full map")
 	}

 	// Disconnect the second node
-	batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
+	batcher.RemoveNode(tn2.n.ID, tn2.ch)
 	assert.False(t, batcher.IsConnected(tn2.n.ID))

 	// First node should get update that second has disconnected.
 	select {
 	case data := <-tn.ch:
 		assertOnlineMapResponse(t, data, false)
-	case <-time.After(200 * time.Millisecond):
+	case <-time.After(500 * time.Millisecond):
 		t.Error("Did not receive expected Online response update")
 	}

@@ -803,7 +840,7 @@ func TestBatcherBasicOperations(t *testing.T) {
 	// }

 	// Test RemoveNode
-	batcher.RemoveNode(tn.n.ID, tn.ch, false)
+	batcher.RemoveNode(tn.n.ID, tn.ch)
 	if batcher.IsConnected(tn.n.ID) {
 		t.Error("Node should be disconnected after RemoveNode")
 	}

@@ -949,7 +986,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
 	testNodes := testData.Nodes

 	ch := make(chan *tailcfg.MapResponse, 10)
-	batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
+	batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))

 	// Track update content for validation
 	var receivedUpdates []*tailcfg.MapResponse

@@ -1045,7 +1082,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
 	}()

 	// Add real work during connection chaos

@@ -1059,7 +1096,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
 	go func() {
 		defer wg.Done()
 		time.Sleep(1 * time.Microsecond)
-		batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
 	}()

 	// Remove second connection

@@ -1067,7 +1104,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
 	go func() {
 		defer wg.Done()
 		time.Sleep(2 * time.Microsecond)
-		batcher.RemoveNode(testNode.n.ID, ch2, false)
+		batcher.RemoveNode(testNode.n.ID, ch2)
 	}()

 	wg.Wait()

@@ -1142,7 +1179,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
 	ch := make(chan *tailcfg.MapResponse, 5)

 	// Add node and immediately queue real work
-	batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
+	batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
 	batcher.AddWork(change.DERPSet)

 	// Consumer goroutine to validate data and detect channel issues

@@ -1184,7 +1221,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {

 	// Rapid removal creates race between worker and removal
 	time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
-	batcher.RemoveNode(testNode.n.ID, ch, false)
+	batcher.RemoveNode(testNode.n.ID, ch)

 	// Give workers time to process and close channels
 	time.Sleep(5 * time.Millisecond)

@@ -1254,7 +1291,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
 	for _, node := range stableNodes {
 		ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
 		stableChannels[node.n.ID] = ch
-		batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))

 		// Monitor updates for each stable client
 		go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {

@@ -1312,7 +1349,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
 	churningChannelsMutex.Lock()
 	churningChannels[nodeID] = ch
 	churningChannelsMutex.Unlock()
-	batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
+	batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))

 	// Consume updates to prevent blocking
 	go func() {

@@ -1349,7 +1386,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
 	ch, exists := churningChannels[nodeID]
 	churningChannelsMutex.Unlock()
 	if exists {
-		batcher.RemoveNode(nodeID, ch, false)
+		batcher.RemoveNode(nodeID, ch)
 	}
 	}(node.n.ID)
 	}

@@ -1599,7 +1636,7 @@ func XTestBatcherScalability(t *testing.T) {
 	var connectedNodesMutex sync.RWMutex
 	for i := range testNodes {
 		node := &testNodes[i]
-		batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
 		connectedNodesMutex.Lock()
 		connectedNodes[node.n.ID] = true
 		connectedNodesMutex.Unlock()

@@ -1666,7 +1703,7 @@ func XTestBatcherScalability(t *testing.T) {
 	connectedNodesMutex.RUnlock()

 	if isConnected {
-		batcher.RemoveNode(nodeID, channel, false)
+		batcher.RemoveNode(nodeID, channel)
 		connectedNodesMutex.Lock()
 		connectedNodes[nodeID] = false
 		connectedNodesMutex.Unlock()

@@ -1690,7 +1727,6 @@ func XTestBatcherScalability(t *testing.T) {
 	batcher.AddNode(
 		nodeID,
 		channel,
-		false,
 		tailcfg.CapabilityVersion(100),
 	)
 	connectedNodesMutex.Lock()

@@ -1792,7 +1828,7 @@ func XTestBatcherScalability(t *testing.T) {
 	// Now disconnect all nodes from batcher to stop new updates
 	for i := range testNodes {
 		node := &testNodes[i]
-		batcher.RemoveNode(node.n.ID, node.ch, false)
+		batcher.RemoveNode(node.n.ID, node.ch)
 	}

 	// Give time for enhanced tracking goroutines to process any remaining data in channels

@@ -1924,7 +1960,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {

 	// Connect nodes one at a time to avoid overwhelming the work queue
 	for i, node := range allNodes {
-		batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
+		batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
 		t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
 		// Small delay between connections to allow NodeCameOnline processing
 		time.Sleep(50 * time.Millisecond)

@@ -1936,12 +1972,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {

 	// Check how many peers each node should see
 	for i, node := range allNodes {
-		peers, err := testData.State.ListPeers(node.n.ID)
-		if err != nil {
-			t.Errorf("Error listing peers for node %d: %v", i, err)
-		} else {
-			t.Logf("Node %d should see %d peers from state", i, peers.Len())
-		}
+		peers := testData.State.ListPeers(node.n.ID)
+		t.Logf("Node %d should see %d peers from state", i, peers.Len())
 	}

 	// Send a full update - this should generate full peer lists

@@ -1957,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
 	foundFullUpdate := false

 	// Read all available updates for each node
-	for i := range len(allNodes) {
+	for i := range allNodes {
 		nodeUpdates := 0
 		t.Logf("Reading updates for node %d:", i)

@@ -2047,7 +2079,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
 	t.Logf("=== WORK QUEUE TRACING TEST ===")

 	// Connect first node
-	batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
+	batcher.AddNode(nodes[0].n.ID, nodes[0].ch, tailcfg.CapabilityVersion(100))
 	t.Logf("Connected node %d", nodes[0].n.ID)

 	// Wait for initial NodeCameOnline to be processed

@@ -2102,14 +2134,10 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
 	}

 	// Check if there should be peers available
-	peers, err := testData.State.ListPeers(nodes[0].n.ID)
-	if err != nil {
-		t.Errorf("Error getting peers from state: %v", err)
-	} else {
-		t.Logf("State shows %d peers available for this node", peers.Len())
-		if peers.Len() > 0 && len(data.Peers) == 0 {
-			t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
-		}
+	peers := testData.State.ListPeers(nodes[0].n.ID)
+	t.Logf("State shows %d peers available for this node", peers.Len())
+	if peers.Len() > 0 && len(data.Peers) == 0 {
+		t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
 	}
 } else {
 	t.Errorf("Response data is nil")
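All of the hunks above make the same mechanical change: the extra bool argument is dropped from AddNode and RemoveNode. Distilled from the tests, the connect-and-wait pattern they exercise now looks like this (buffer size and timeout are the values the tests above use):

	// Sketch: connect a node to the batcher and wait for one map response.
	ch := make(chan *tailcfg.MapResponse, 10)
	batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))

	select {
	case data := <-ch:
		_ = data // assert on data.Peers, data.Node, online status, ...
	case <-time.After(500 * time.Millisecond):
		t.Error("did not receive expected map response")
	}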
@@ -1,6 +1,7 @@
 package mapper

 import (
+	"errors"
 	"net/netip"
 	"sort"
 	"time"

@@ -8,11 +9,12 @@ import (
 	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"tailscale.com/tailcfg"
+	"tailscale.com/types/ptr"
 	"tailscale.com/types/views"
 	"tailscale.com/util/multierr"
 )

-// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
+// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
 type MapResponseBuilder struct {
 	resp   *tailcfg.MapResponse
 	mapper *mapper

@@ -21,7 +23,7 @@ type MapResponseBuilder struct {
 	errs []error
 }

-// NewMapResponseBuilder creates a new builder with basic fields set
+// NewMapResponseBuilder creates a new builder with basic fields set.
 func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
 	now := time.Now()
 	return &MapResponseBuilder{

@@ -35,37 +37,44 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
 	}
 }

-// addError adds an error to the builder's error list
+// addError adds an error to the builder's error list.
 func (b *MapResponseBuilder) addError(err error) {
 	if err != nil {
 		b.errs = append(b.errs, err)
 	}
 }

-// hasErrors returns true if the builder has accumulated any errors
+// hasErrors returns true if the builder has accumulated any errors.
 func (b *MapResponseBuilder) hasErrors() bool {
 	return len(b.errs) > 0
 }

-// WithCapabilityVersion sets the capability version for the response
+// WithCapabilityVersion sets the capability version for the response.
 func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
 	b.capVer = capVer
 	return b
 }

-// WithSelfNode adds the requesting node to the response
+// WithSelfNode adds the requesting node to the response.
 func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

+	// Always use batcher's view of online status for self node
+	// The batcher respects grace periods for logout scenarios
+	node := nodeView.AsStruct()
+	if b.mapper.batcher != nil {
+		node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
+	}
+
 	_, matchers := b.mapper.state.Filter()
 	tailnode, err := tailNode(
-		node, b.capVer, b.mapper.state,
+		node.View(), b.capVer, b.mapper.state,
 		func(id types.NodeID) []netip.Prefix {
-			return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
+			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
 		},
 		b.mapper.cfg)
 	if err != nil {

@@ -74,29 +83,30 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
 	}

 	b.resp.Node = tailnode

 	return b
 }

-// WithDERPMap adds the DERP map to the response
+// WithDERPMap adds the DERP map to the response.
 func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
 	b.resp.DERPMap = b.mapper.state.DERPMap()
 	return b
 }

-// WithDomain adds the domain configuration
+// WithDomain adds the domain configuration.
 func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
 	b.resp.Domain = b.mapper.cfg.Domain()
 	return b
 }

-// WithCollectServicesDisabled sets the collect services flag to false
+// WithCollectServicesDisabled sets the collect services flag to false.
 func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
 	b.resp.CollectServices.Set(false)
 	return b
 }

 // WithDebugConfig adds debug configuration
-// It disables log tailing if the mapper's LogTail is not enabled
+// It disables log tailing if the mapper's LogTail is not enabled.
 func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
 	b.resp.Debug = &tailcfg.Debug{
 		DisableLogTail: !b.mapper.cfg.LogTail.Enabled,

@@ -104,11 +114,11 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
 	return b
 }

-// WithSSHPolicy adds SSH policy configuration for the requesting node
+// WithSSHPolicy adds SSH policy configuration for the requesting node.
 func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

@@ -119,38 +129,41 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
 	}

 	b.resp.SSHPolicy = sshPolicy

 	return b
 }

-// WithDNSConfig adds DNS configuration for the requesting node
+// WithDNSConfig adds DNS configuration for the requesting node.
 func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

 	b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)

 	return b
 }

-// WithUserProfiles adds user profiles for the requesting node and given peers
+// WithUserProfiles adds user profiles for the requesting node and given peers.
 func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

 	b.resp.UserProfiles = generateUserProfiles(node, peers)

 	return b
 }

-// WithPacketFilters adds packet filter rules based on policy
+// WithPacketFilters adds packet filter rules based on policy.
 func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		b.addError(err)
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		b.addError(errors.New("node not found"))
 		return b
 	}

@@ -167,9 +180,8 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
 	return b
 }

-// WithPeers adds full peer list with policy filtering (for full map response)
+// WithPeers adds full peer list with policy filtering (for full map response).
 func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
 		b.addError(err)

@@ -177,12 +189,12 @@ func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapRe
 	}

 	b.resp.Peers = tailPeers

 	return b
 }

-// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
+// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
 func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
 		b.addError(err)

@@ -190,14 +202,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
 	}

 	b.resp.PeersChanged = tailPeers

 	return b
 }

-// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
+// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
 func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
-	node, err := b.mapper.state.GetNodeByID(b.nodeID)
-	if err != nil {
-		return nil, err
+	node, ok := b.mapper.state.GetNodeByID(b.nodeID)
+	if !ok {
+		return nil, errors.New("node not found")
 	}

 	filter, matchers := b.mapper.state.Filter()

@@ -229,24 +242,24 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
 	return tailPeers, nil
 }

-// WithPeerChangedPatch adds peer change patches
+// WithPeerChangedPatch adds peer change patches.
 func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
 	b.resp.PeersChangedPatch = changes
 	return b
 }

-// WithPeersRemoved adds removed peer IDs
+// WithPeersRemoved adds removed peer IDs.
 func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
 	var tailscaleIDs []tailcfg.NodeID
 	for _, id := range removedIDs {
 		tailscaleIDs = append(tailscaleIDs, id.NodeID())
 	}
 	b.resp.PeersRemoved = tailscaleIDs

 	return b
 }

-// Build finalizes the response and returns marshaled bytes
+// Build finalizes the response and returns marshaled bytes.
 func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
 	if len(b.errs) > 0 {
 		return nil, multierr.New(b.errs...)
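The builder funnels failures from each step into its errs slice instead of returning them immediately, so callers can chain every With* step and handle errors once. A minimal sketch of the intended call pattern (it mirrors how fullMapResponse uses the builder later in this change):

	// Sketch: any failed step is surfaced as a single multierr from Build.
	resp, err := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithSelfNode().
		WithDERPMap().
		WithDomain().
		WithPeers(peers).
		Build()
	if err != nil {
		return nil, err // combined error from every failed With* call
	}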
@@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) {
 			Enabled: true,
 		},
 	}

 	mockState := &state.State{}
 	m := &mapper{
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	builder := m.NewMapResponseBuilder(nodeID)

 	// Test basic builder creation
 	assert.NotNil(t, builder)
 	assert.Equal(t, nodeID, builder.nodeID)

@@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)
 	capVer := tailcfg.CapabilityVersion(42)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer)

 	assert.Equal(t, capVer, builder.capVer)
 	assert.False(t, builder.hasErrors())
 }

@@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) {
 		ServerURL:  "https://test.example.com",
 		BaseDomain: domain,
 	}

 	mockState := &state.State{}
 	m := &mapper{
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithDomain()

 	assert.Equal(t, domain, builder.resp.Domain)
 	assert.False(t, builder.hasErrors())
 }

@@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCollectServicesDisabled()

 	value, isSet := builder.resp.CollectServices.Get()
 	assert.True(t, isSet)
 	assert.False(t, value)

@@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {

 func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
 	tests := []struct {
 		name           string
 		logTailEnabled bool
 		expected       bool
 	}{
 		{
 			name:           "LogTail enabled",
 			logTailEnabled: true,
 			expected:       false, // DisableLogTail should be false when LogTail is enabled
 		},
 		{
 			name:           "LogTail disabled",
 			logTailEnabled: false,
 			expected:       true, // DisableLogTail should be true when LogTail is disabled
 		},
 	}

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			cfg := &types.Config{

@@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
 			cfg:   cfg,
 			state: mockState,
 			}

 			nodeID := types.NodeID(1)

 			builder := m.NewMapResponseBuilder(nodeID).
 				WithDebugConfig()

 			require.NotNil(t, builder.resp.Debug)
 			assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
 			assert.False(t, builder.hasErrors())

@@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)
 	changes := []*tailcfg.PeerChange{
 		{
 			NodeID:     123,
 			DERPRegion: 1,
 		},
 		{
 			NodeID:     456,
 			DERPRegion: 2,
 		},
 	}

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch(changes)

 	assert.Equal(t, changes, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }

@@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)
 	removedID1 := types.NodeID(123)
 	removedID2 := types.NodeID(456)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeersRemoved(removedID1, removedID2)

 	expected := []tailcfg.NodeID{
 		removedID1.NodeID(),
 		removedID2.NodeID(),

@@ -197,23 +197,23 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	// Simulate an error in the builder
 	builder := m.NewMapResponseBuilder(nodeID)
 	builder.addError(assert.AnError)

 	// All subsequent calls should continue to work and accumulate errors
 	result := builder.
 		WithDomain().
 		WithCollectServicesDisabled().
 		WithDebugConfig()

 	assert.True(t, result.hasErrors())
 	assert.Len(t, result.errs, 1)
 	assert.Equal(t, assert.AnError, result.errs[0])

 	// Build should return the error
 	data, err := result.Build("none")
 	assert.Nil(t, data)

@@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
 			Enabled: false,
 		},
 	}

 	mockState := &state.State{}
 	m := &mapper{
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)
 	capVer := tailcfg.CapabilityVersion(99)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).
 		WithDomain().
 		WithCollectServicesDisabled().
 		WithDebugConfig()

 	// Verify all fields are set correctly
 	assert.Equal(t, capVer, builder.capVer)
 	assert.Equal(t, domain, builder.resp.Domain)

@@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)
 	removedID1 := types.NodeID(100)
 	removedID2 := types.NodeID(200)

 	// Test calling WithPeersRemoved multiple times
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeersRemoved(removedID1).
 		WithPeersRemoved(removedID2)

 	// Second call should overwrite the first
 	expected := []tailcfg.NodeID{removedID2.NodeID()}
 	assert.Equal(t, expected, builder.resp.PeersRemoved)

@@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch([]*tailcfg.PeerChange{})

 	assert.Empty(t, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }

@@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch(nil)

 	assert.Nil(t, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }

@@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
 		cfg:   cfg,
 		state: mockState,
 	}

 	nodeID := types.NodeID(1)

 	// Create a builder and add multiple errors
 	builder := m.NewMapResponseBuilder(nodeID)
 	builder.addError(assert.AnError)
 	builder.addError(assert.AnError)
 	builder.addError(nil) // This should be ignored

 	// All subsequent calls should continue to work
 	result := builder.
 		WithDomain().
 		WithCollectServicesDisabled()

 	assert.True(t, result.hasErrors())
 	assert.Len(t, result.errs, 2) // nil error should be ignored

 	// Build should return a multierr
 	data, err := result.Build("none")
 	assert.Nil(t, data)
 	assert.Error(t, err)

 	// The error should contain information about multiple errors
 	assert.Contains(t, err.Error(), "multiple errors")
 }
@@ -18,6 +18,7 @@ import (
 	"tailscale.com/envknob"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/dnstype"
+	"tailscale.com/types/ptr"
 	"tailscale.com/types/views"
 )

@@ -49,6 +50,37 @@ type mapper struct {
 	created time.Time
 }

+// addOnlineStatusToPeers adds fresh online status from batcher to peer nodes.
+//
+// We do a last-minute copy-and-write on the NodeView to inject current online status
+// from the batcher's connection map. Online status is not populated upstream in NodeStore
+// for consistency reasons - it's runtime connection state that should come from the
+// connection manager (batcher) to ensure map responses have the freshest data.
+func (m *mapper) addOnlineStatusToPeers(peers views.Slice[types.NodeView]) views.Slice[types.NodeView] {
+	if peers.Len() == 0 || m.batcher == nil {
+		return peers
+	}
+
+	result := make([]types.NodeView, 0, peers.Len())
+	for _, peer := range peers.All() {
+		if !peer.Valid() {
+			result = append(result, peer)
+			continue
+		}
+
+		// Get online status from batcher connection map
+		// The batcher respects grace periods for logout scenarios
+		isOnline := m.batcher.IsConnected(peer.ID())
+
+		// Create a mutable copy and set online status
+		peerCopy := peer.AsStruct()
+		peerCopy.IsOnline = ptr.To(isOnline)
+		result = append(result, peerCopy.View())
+	}
+
+	return views.SliceOf(result)
+}
+
 type patch struct {
 	timestamp time.Time
 	change    *tailcfg.PeerChange

@@ -140,10 +172,10 @@ func (m *mapper) fullMapResponse(
 	capVer tailcfg.CapabilityVersion,
 	messages ...string,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.state.ListPeers(nodeID)
-	if err != nil {
-		return nil, err
-	}
+	peers := m.state.ListPeers(nodeID)
+
+	// Add fresh online status to peers from batcher connection state
+	peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)

 	return m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).

@@ -154,9 +186,9 @@ func (m *mapper) fullMapResponse(
 		WithDebugConfig().
 		WithSSHPolicy().
 		WithDNSConfig().
-		WithUserProfiles(peers).
+		WithUserProfiles(peersWithOnlineStatus).
 		WithPacketFilters().
-		WithPeers(peers).
+		WithPeers(peersWithOnlineStatus).
 		Build(messages...)
 }

@@ -185,16 +217,16 @@ func (m *mapper) peerChangeResponse(
 	capVer tailcfg.CapabilityVersion,
 	changedNodeID types.NodeID,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.state.ListPeers(nodeID, changedNodeID)
-	if err != nil {
-		return nil, err
-	}
+	peers := m.state.ListPeers(nodeID, changedNodeID)
+
+	// Add fresh online status to peers from batcher connection state
+	peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)

 	return m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).
 		WithSelfNode().
-		WithUserProfiles(peers).
-		WithPeerChanges(peers).
+		WithUserProfiles(peersWithOnlineStatus).
+		WithPeerChanges(peersWithOnlineStatus).
 		Build()
 }
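The copy-and-write step in addOnlineStatusToPeers is the only place a peer view is modified on the map-response path. The same pattern in isolation, for reference (the hard-coded status stands in for the batcher lookup):

	// Sketch: convert an immutable view to a struct, patch one field,
	// and hand back a fresh view; the original NodeView is untouched.
	patched := make([]types.NodeView, 0, peers.Len())
	for _, peer := range peers.All() {
		n := peer.AsStruct()
		n.IsOnline = ptr.To(true) // real code: m.batcher.IsConnected(peer.ID())
		patched = append(patched, n.View())
	}
	peersWithOnlineStatus := views.SliceOf(patched)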
@@ -133,11 +133,15 @@ func tailNode(
 		tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
 	}

-	if !node.IsOnline().Valid() || !node.IsOnline().Get() {
-		// LastSeen is only set when node is
-		// not connected to the control server.
+	// Always set LastSeen if it's valid, regardless of online status
+	// This ensures that during logout grace periods (when IsOnline might be true
+	// for DNS preservation), other nodes can still see when this node disconnected
 	if node.LastSeen().Valid() {
 		lastSeen := node.LastSeen().Get()
+		// Only set LastSeen if the node is offline OR if LastSeen is recent
+		// (indicating it disconnected recently but might be in grace period)
+		if !node.IsOnline().Valid() || !node.IsOnline().Get() ||
+			time.Since(lastSeen) < 60*time.Second {
 			tNode.LastSeen = &lastSeen
 		}
 	}
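Written out as a standalone predicate, the new LastSeen rule reads roughly as follows (the helper is purely illustrative; 60 seconds is the grace window used in the hunk above):

	// Sketch: expose LastSeen when the node is offline, or when it went
	// offline recently enough that it may still be in a logout grace period.
	func shouldExposeLastSeen(isOnlineValid, isOnline bool, lastSeen time.Time) bool {
		offline := !isOnlineValid || !isOnline
		recentlySeen := time.Since(lastSeen) < 60*time.Second
		return offline || recentlySeen
	}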
@@ -13,7 +13,6 @@ import (
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/rs/zerolog/log"
 	"golang.org/x/net/http2"
-	"gorm.io/gorm"
 	"tailscale.com/control/controlbase"
 	"tailscale.com/control/controlhttp/controlhttpserver"
 	"tailscale.com/tailcfg"

@@ -296,12 +295,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 // getAndValidateNode retrieves the node from the database using the NodeKey
 // and validates that it matches the MachineKey from the Noise session.
 func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
-	nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
-	if err != nil {
-		if errors.Is(err, gorm.ErrRecordNotFound) {
-			return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
-		}
-		return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
+	nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
+	if !ok {
+		return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
 	}

 	// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
@@ -7,6 +7,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/rs/zerolog/log"
 	"github.com/samber/lo"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"

@@ -138,39 +139,61 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 	return ret
 }

-// AutoApproveRoutes approves any route that can be autoapproved from
-// the nodes perspective according to the given policy.
-// It reports true if any routes were approved.
-// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
-func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
-	if pm == nil {
-		return false
-	}
-	nodeView := node.View()
-	var newApproved []netip.Prefix
-	for _, route := range nodeView.AnnouncedRoutes() {
-		if pm.NodeCanApproveRoute(nodeView, route) {
-			newApproved = append(newApproved, route)
-		}
-	}
-
-	// Only modify ApprovedRoutes if we have new routes to approve.
-	// This prevents clearing existing approved routes when nodes
-	// temporarily don't have announced routes during policy changes.
-	if len(newApproved) > 0 {
-		combined := append(newApproved, node.ApprovedRoutes...)
-		tsaddr.SortPrefixes(combined)
-		combined = slices.Compact(combined)
-		combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
-			return route.IsValid()
-		})
-
-		// Only update if the routes actually changed
-		if !slices.Equal(node.ApprovedRoutes, combined) {
-			node.ApprovedRoutes = combined
-			return true
-		}
-	}
-
-	return false
-}
+// ApproveRoutesWithPolicy checks if the node can approve the announced routes
+// and returns the new list of approved routes.
+// The approved routes will include:
+// 1. ALL previously approved routes (regardless of whether they're still advertised)
+// 2. New routes from announcedRoutes that can be auto-approved by policy
+// This ensures that:
+// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
+// - New routes can be auto-approved according to policy
+// - Routes can only be removed by explicit admin action (not by auto-approval)
+func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
+	if pm == nil {
+		return currentApproved, false
+	}
+
+	// Start with ALL currently approved routes - we never remove approved routes
+	newApproved := make([]netip.Prefix, len(currentApproved))
+	copy(newApproved, currentApproved)
+
+	// Then, check for new routes that can be auto-approved
+	for _, route := range announcedRoutes {
+		// Skip if already approved
+		if slices.Contains(newApproved, route) {
+			continue
+		}
+
+		// Check if this new route can be auto-approved by policy
+		canApprove := pm.NodeCanApproveRoute(nv, route)
+		if canApprove {
+			newApproved = append(newApproved, route)
+		}
+
+		log.Trace().
+			Uint64("node.id", nv.ID().Uint64()).
+			Str("node.name", nv.Hostname()).
+			Str("route", route.String()).
+			Bool("can_approve", canApprove).
+			Msg("Evaluating route for auto-approval")
+	}
+
+	// Sort and deduplicate
+	tsaddr.SortPrefixes(newApproved)
+	newApproved = slices.Compact(newApproved)
+	newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
+		return route.IsValid()
+	})
+
+	// Sort the current approved for comparison
+	sortedCurrent := make([]netip.Prefix, len(currentApproved))
+	copy(sortedCurrent, currentApproved)
+	tsaddr.SortPrefixes(sortedCurrent)
+
+	// Only update if the routes actually changed
+	if !slices.Equal(sortedCurrent, newApproved) {
+		return newApproved, true
+	}
+
+	return newApproved, false
+}
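ApproveRoutesWithPolicy no longer mutates the node, so the caller decides what to do with the result. A sketch of the caller-side wiring (the persistence step is illustrative, not part of this diff):

	// Sketch: compute the new approved set and only act when it changed.
	nv := node.View()
	approved, changed := ApproveRoutesWithPolicy(pm, nv, node.ApprovedRoutes, nv.AnnouncedRoutes())
	if changed {
		node.ApprovedRoutes = approved // persist / notify peers here
	}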
hscontrol/policy/policy_autoapprove_test.go (new file, 339 lines)
@ -0,0 +1,339 @@
package policy

import (
	"fmt"
	"net/netip"
	"testing"

	policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
	"gorm.io/gorm"
	"tailscale.com/net/tsaddr"
	"tailscale.com/types/key"
	"tailscale.com/types/ptr"
	"tailscale.com/types/views"
)

func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
	user1 := types.User{
		Model: gorm.Model{ID: 1},
		Name:  "testuser@",
	}
	user2 := types.User{
		Model: gorm.Model{ID: 2},
		Name:  "otheruser@",
	}
	users := []types.User{user1, user2}

	node1 := &types.Node{
		ID:             1,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "test-node",
		UserID:         user1.ID,
		User:           user1,
		RegisterMethod: util.RegisterMethodAuthKey,
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
		ForcedTags:     []string{"tag:test"},
	}

	node2 := &types.Node{
		ID:             2,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "other-node",
		UserID:         user2.ID,
		User:           user2,
		RegisterMethod: util.RegisterMethodAuthKey,
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.2")),
	}

	// Create a policy that auto-approves specific routes
	policyJSON := `{
		"groups": {
			"group:test": ["testuser@"]
		},
		"tagOwners": {
			"tag:test": ["testuser@"]
		},
		"acls": [
			{
				"action": "accept",
				"src": ["*"],
				"dst": ["*:*"]
			}
		],
		"autoApprovers": {
			"routes": {
				"10.0.0.0/8": ["testuser@", "tag:test"],
				"10.1.0.0/24": ["testuser@"],
				"10.2.0.0/24": ["testuser@"],
				"192.168.0.0/24": ["tag:test"]
			}
		}
	}`

	pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
	assert.NoError(t, err)

	tests := []struct {
		name            string
		node            *types.Node
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
		description     string
	}{
		{
			name: "previously_approved_route_no_longer_advertised_should_remain",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
			},
			wantChanged: false,
			description: "Previously approved routes should never be removed even when no longer advertised",
		},
		{
			name: "add_new_auto_approved_route_keeps_old_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
				netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
			},
			wantChanged: true,
			description: "New auto-approved routes should be added while keeping old approved routes",
		},
		{
			name: "no_announced_routes_keeps_all_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
			},
			announcedRoutes: []netip.Prefix{}, // No routes announced
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged: false,
			description: "All approved routes should remain when no routes are announced",
		},
		{
			name: "no_changes_when_announced_equals_approved",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
			description: "No changes should occur when announced routes match approved routes",
		},
		{
			name: "auto_approve_multiple_new_routes",
			node: node1,
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.2.0.0/24"),    // Should be auto-approved (subset of 10.0.0.0/8)
				netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.2.0.0/24"),    // New auto-approved
				netip.MustParsePrefix("172.16.0.0/24"),  // Original kept
				netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
			},
			wantChanged: true,
			description: "Multiple new routes should be auto-approved while keeping existing approved routes",
		},
		{
			name: "node_without_permission_no_auto_approval",
			node: node2, // Different node without the tag
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
			},
			wantChanged: false,
			description: "Routes should not be auto-approved for nodes without proper permissions",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)

			assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)

			// Sort for comparison since ApproveRoutesWithPolicy sorts the results
			tsaddr.SortPrefixes(tt.wantApproved)
			assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)

			// Verify that all previously approved routes are still present
			for _, prevRoute := range tt.currentApproved {
				assert.Contains(t, gotApproved, prevRoute,
					"previously approved route %s was removed - this should never happen", prevRoute)
			}
		})
	}
}

func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
	// Create a basic policy for edge case testing
	aclPolicy := `
{
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.1.0.0/24": ["test@"],
		},
	},
}`

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	tests := []struct {
		name            string
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
	}{
		{
			name: "nil_policy_manager",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name:            "nil_current_approved",
			currentApproved: nil,
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name: "nil_announced_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: nil,
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name: "duplicate_approved_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.1.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name:            "empty_slices",
			currentApproved: []netip.Prefix{},
			announcedRoutes: []netip.Prefix{},
			wantApproved:    []netip.Prefix{},
			wantChanged:     false,
		},
	}

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  "test",
				}
				users := []types.User{user}

				// Create test node
				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       "testnode",
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
				}
				nodes := types.Nodes{&node}

				// Create policy manager or use nil if specified
				var pm PolicyManager
				var err error
				if tt.name != "nil_policy_manager" {
					pm, err = pmf(users, nodes.ViewSlice())
					assert.NoError(t, err)
				} else {
					pm = nil
				}

				gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)

				assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")

				// Handle nil vs empty slice comparison
				if tt.wantApproved == nil {
					assert.Nil(t, gotApproved, "expected nil approved routes")
				} else {
					tsaddr.SortPrefixes(tt.wantApproved)
					assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
				}
			})
		}
	}
}
hscontrol/policy/policy_route_approval_test.go (new file, 361 lines)
@ -0,0 +1,361 @@
package policy

import (
	"fmt"
	"net/netip"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
	"tailscale.com/types/ptr"
)

func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
	// Test policy that allows specific routes to be auto-approved
	aclPolicy := `
{
	"groups": {
		"group:admins": ["test@"],
	},
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.0.0.0/24": ["test@"],
			"192.168.0.0/24": ["group:admins"],
			"172.16.0.0/16": ["tag:approved"],
		},
	},
	"tagOwners": {
		"tag:approved": ["test@"],
	},
}`

	tests := []struct {
		name              string
		currentApproved   []netip.Prefix
		announcedRoutes   []netip.Prefix
		nodeHostname      string
		nodeUser          string
		nodeTags          []string
		wantApproved      []netip.Prefix
		wantChanged       bool
		wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
	}{
		{
			name: "previously_approved_route_no_longer_advertised_remains",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged:       false,
			wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
		},
		{
			name: "add_new_auto_approved_route_keeps_existing",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Still advertised
				netip.MustParsePrefix("192.168.0.0/24"), // New route
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
			},
			wantChanged: true,
		},
		{
			name: "no_announced_routes_keeps_all_approved",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
				netip.MustParsePrefix("172.16.0.0/16"),
			},
			announcedRoutes: []netip.Prefix{}, // No routes announced anymore
			nodeUser:        "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"),
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("192.168.0.0/24"),
			},
			wantChanged: false,
		},
		{
			name: "manually_approved_route_not_in_policy_remains",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // New auto-approved
				netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
			},
			wantChanged: true,
		},
		{
			name: "tagged_node_gets_tag_approved_routes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
			},
			nodeUser: "test",
			nodeTags: []string{"tag:approved"},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
				netip.MustParsePrefix("10.0.0.0/24"),   // Previous approval preserved
			},
			wantChanged: true,
		},
		{
			name: "complex_scenario_multiple_changes",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Will not be advertised
				netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("192.168.0.0/24"),  // New, auto-approvable
				netip.MustParsePrefix("172.16.0.0/16"),   // New, not approvable (no tag)
				netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
			},
			nodeUser: "test",
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),    // Kept despite not advertised
				netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
				netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
			},
			wantChanged: true,
		},
	}

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  tt.nodeUser,
				}
				users := []types.User{user}

				// Create test node
				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       tt.nodeHostname,
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					Hostinfo: &tailcfg.Hostinfo{
						RoutableIPs: tt.announcedRoutes,
					},
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
					ForcedTags:     tt.nodeTags,
				}
				nodes := types.Nodes{&node}

				// Create policy manager
				pm, err := pmf(users, nodes.ViewSlice())
				require.NoError(t, err)
				require.NotNil(t, pm)

				// Test ApproveRoutesWithPolicy
				gotApproved, gotChanged := ApproveRoutesWithPolicy(
					pm,
					node.View(),
					tt.currentApproved,
					tt.announcedRoutes,
				)

				// Check change flag
				assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")

				// Check approved routes match expected
				if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
					t.Logf("Want: %v", tt.wantApproved)
					t.Logf("Got: %v", gotApproved)
					t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
				}

				// Verify all previously approved routes are still present
				for _, prevRoute := range tt.currentApproved {
					assert.Contains(t, gotApproved, prevRoute,
						"previously approved route %s was removed - this should NEVER happen", prevRoute)
				}

				// Verify no routes were incorrectly removed
				for _, removedRoute := range tt.wantRemovedRoutes {
					assert.NotContains(t, gotApproved, removedRoute,
						"route %s should have been removed but wasn't", removedRoute)
				}
			})
		}
	}
}

func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
	aclPolicy := `
{
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
	"autoApprovers": {
		"routes": {
			"10.0.0.0/8": ["test@"],
		},
	},
}`

	tests := []struct {
		name            string
		currentApproved []netip.Prefix
		announcedRoutes []netip.Prefix
		wantApproved    []netip.Prefix
		wantChanged     bool
	}{
		{
			name:            "nil_current_approved",
			currentApproved: nil,
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name:            "empty_current_approved",
			currentApproved: []netip.Prefix{},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true,
		},
		{
			name: "duplicate_routes_handled",
			currentApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
				netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
			},
			announcedRoutes: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantApproved: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/24"),
			},
			wantChanged: true, // Duplicates are removed, so it's a change
		},
	}

	pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))

	for _, tt := range tests {
		for i, pmf := range pmfs {
			t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
				// Create test user
				user := types.User{
					Model: gorm.Model{ID: 1},
					Name:  "test",
				}
				users := []types.User{user}

				node := types.Node{
					ID:             1,
					MachineKey:     key.NewMachine().Public(),
					NodeKey:        key.NewNode().Public(),
					Hostname:       "testnode",
					UserID:         user.ID,
					User:           user,
					RegisterMethod: util.RegisterMethodAuthKey,
					Hostinfo: &tailcfg.Hostinfo{
						RoutableIPs: tt.announcedRoutes,
					},
					IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
					ApprovedRoutes: tt.currentApproved,
				}
				nodes := types.Nodes{&node}

				pm, err := pmf(users, nodes.ViewSlice())
				require.NoError(t, err)

				gotApproved, gotChanged := ApproveRoutesWithPolicy(
					pm,
					node.View(),
					tt.currentApproved,
					tt.announcedRoutes,
				)

				assert.Equal(t, tt.wantChanged, gotChanged)

				if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
					t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
				}
			})
		}
	}
}

func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
	user := types.User{
		Model: gorm.Model{ID: 1},
		Name:  "test",
	}

	currentApproved := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
	}
	announcedRoutes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
	}

	node := types.Node{
		ID:             1,
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		Hostname:       "testnode",
		UserID:         user.ID,
		User:           user,
		RegisterMethod: util.RegisterMethodAuthKey,
		Hostinfo: &tailcfg.Hostinfo{
			RoutableIPs: announcedRoutes,
		},
		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
		ApprovedRoutes: currentApproved,
	}

	// With nil policy manager, should return current approved unchanged
	gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)

	assert.False(t, gotChanged)
	assert.Equal(t, currentApproved, gotApproved)
}
@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
 			policy:     `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
 			canApprove: false,
 		},
+		{
+			name:  "policy-without-autoApprovers-section",
+			node:  normalNode,
+			route: p("10.33.0.0/16"),
+			policy: `{
+				"groups": {
+					"group:admin": ["user1@"]
+				},
+				"acls": [
+					{
+						"action": "accept",
+						"src": ["group:admin"],
+						"dst": ["group:admin:*"]
+					},
+					{
+						"action": "accept",
+						"src": ["group:admin"],
+						"dst": ["10.33.0.0/16:*"]
+					}
+				]
+			}`,
+			canApprove: false,
+		},
 	}

 	for _, tt := range tests {
@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
 	// The fast path is that a node requests to approve a prefix
 	// where there is an exact entry, e.g. 10.0.0.0/8, then
 	// check and return quickly
-	if _, ok := pm.autoApproveMap[route]; ok {
-		if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
+	if approvers, ok := pm.autoApproveMap[route]; ok {
+		canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
+		if canApprove {
 			return true
 		}
 	}
@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
 		// Check if prefix is larger (so containing) and then overlaps
 		// the route to see if the node can approve a subset of an autoapprover
 		if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
-			if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
+			canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
+			if canApprove {
 				return true
 			}
 		}
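The two hunks above restructure NodeCanApproveRoute's containment checks: an exact map hit is the fast path, and otherwise a broader auto-approver prefix (fewer or equal bits) that overlaps the requested route allows approval of a subset. A simplified sketch of that check, using a plain map instead of the project's autoApproveMap types:

package main

import (
	"fmt"
	"net/netip"
)

// canApproveRoute reports whether a requested route is covered by an
// auto-approver entry: either an exact match, or an approver prefix that is
// at least as broad (fewer or equal bits) and overlaps the request.
func canApproveRoute(autoApprovers map[netip.Prefix]bool, route netip.Prefix) bool {
	// Fast path: exact entry, e.g. the node asks for 10.0.0.0/8 and the
	// policy lists 10.0.0.0/8 directly.
	if allowed, ok := autoApprovers[route]; ok && allowed {
		return true
	}

	// Slow path: a broader approver (e.g. 10.0.0.0/8) covers a narrower
	// request (e.g. 10.1.0.0/24).
	for prefix, allowed := range autoApprovers {
		if allowed && prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
			return true
		}
	}

	return false
}

func main() {
	approvers := map[netip.Prefix]bool{netip.MustParsePrefix("10.0.0.0/8"): true}
	fmt.Println(canApproveRoute(approvers, netip.MustParsePrefix("10.1.0.0/24")))    // true
	fmt.Println(canApproveRoute(approvers, netip.MustParsePrefix("192.168.0.0/24"))) // false
}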
@ -10,7 +10,6 @@ import (
 	"time"

 	"github.com/juanfont/headscale/hscontrol/types"
-	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"github.com/sasha-s/go-deadlock"
@ -112,6 +111,15 @@ func (m *mapSession) serve() {
 	// This is the mechanism where the node gives us information about its
 	// current configuration.
 	//
+	// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
+	c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
+	if err != nil {
+		httpError(m.w, err)
+		return
+	}
+
+	m.h.Change(c)
+
 	// If OmitPeers is true and Stream is false
 	// then the server will let clients update their endpoints without
 	// breaking existing long-polling (Stream == true) connections.
@ -122,14 +130,6 @@ func (m *mapSession) serve() {
 	// the response and just wants a 200.
 	// !req.stream && req.OmitPeers
 	if m.isEndpointUpdate() {
-		c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
-		if err != nil {
-			httpError(m.w, err)
-			return
-		}
-
-		m.h.Change(c)
-
 		m.w.WriteHeader(http.StatusOK)
 		mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
 	}
@ -142,6 +142,8 @@ func (m *mapSession) serve() {
 func (m *mapSession) serveLongPoll() {
 	m.beforeServeLongPoll()

+	log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("starting long poll session chan(%p)", m.ch)
+
 	// Clean up the session when the client disconnects
 	defer func() {
 		m.cancelChMu.Lock()
@ -149,18 +151,26 @@ func (m *mapSession) serveLongPoll() {
 		close(m.cancelCh)
 		m.cancelChMu.Unlock()

-		// TODO(kradalby): This can likely be made more effective, but likely most
-		// nodes has access to the same routes, so it might not be a big deal.
-		disconnectChange, err := m.h.state.Disconnect(m.node)
-		if err != nil {
-			m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
-		}
-		m.h.Change(disconnectChange)
-
-		m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
-
-		m.afterServeLongPoll()
-		m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
+		log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removing session from batcher chan(%p)", m.ch)
+
+		// Validate if we are actually closing the current session or
+		// if the connection has been replaced. If the connection has been replaced,
+		// do not run the rest of the disconnect logic.
+		if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
+			log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removed from batcher chan(%p)", m.ch)
+			// First update NodeStore to mark the node as offline
+			// This ensures the state is consistent before notifying the batcher
+			disconnectChange, err := m.h.state.Disconnect(m.node.ID)
+			if err != nil {
+				m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
+			}

+			// Send the disconnect change notification
+			m.h.Change(disconnectChange)
+
+			m.afterServeLongPoll()
+			m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
+		}
 	}()

 	// Set up the client stream
@ -172,25 +182,37 @@ func (m *mapSession) serveLongPoll() {

 	m.keepAliveTicker = time.NewTicker(m.keepAlive)

-	// Add node to batcher BEFORE sending Connect change to prevent race condition
-	// where the change is sent before the node is in the batcher's node map
-	if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
+	// Add node to batcher so it can receive updates,
+	// adding this before connecting it to the state ensure that
+	// it does not miss any updates that might be sent in the split
+	// time between the node connecting and the batcher being ready.
+	if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
 		m.errf(err, "failed to add node to batcher")
-		// Send empty response to client to fail fast for invalid/non-existent nodes
-		select {
-		case m.ch <- &tailcfg.MapResponse{}:
-		default:
-			// Channel might be closed
-		}
 		return
 	}

-	// Now send the Connect change - the batcher handles NodeCameOnline internally
-	// but we still need to update routes and other state-level changes
-	connectChange := m.h.state.Connect(m.node)
-	if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
-		m.h.Change(connectChange)
-	}
+	// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
+	// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
+	// synchronized. When nodes reconnect, they send their hostinfo with announced routes
+	// in the MapRequest. We need this data in NodeStore before Connect() sets up the
+	// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
+	// from AvailableRoutes.
+	mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
+	if err != nil {
+		m.errf(err, "failed to update node from initial MapRequest")
+		return
+	}
+	m.h.Change(mapReqChange)
+
+	// Connect the node after its state has been updated.
+	// We send two separate change notifications because these are distinct operations:
+	// 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
+	// 2. Connect: marks the node online and recalculates primary routes based on the updated state
+	// While this results in two notifications, it ensures route data is synchronized before
+	// primary route selection occurs, which is critical for proper HA subnet router failover.
+	connectChange := m.h.state.Connect(m.node.ID)
+	m.h.Change(connectChange)

 	m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
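The rewritten long-poll setup above is explicit about ordering: the initial MapRequest is applied to state before Connect() runs, so primary-route selection sees the freshly announced routes. A schematic of that two-step flow follows; the interface and method names are illustrative only, not headscale's real types.

package sketch

import "fmt"

type change struct{ reason string }

// state is a hypothetical stand-in for the pieces of headscale state used here.
type state interface {
	UpdateFromMapRequest(nodeID uint64) (change, error) // apply endpoints, hostinfo, announced routes
	Connect(nodeID uint64) change                       // mark online, recalculate primary routes
}

// connectSession applies the client's reported state first, then connects,
// so primary-route selection sees the freshly announced routes.
func connectSession(s state, nodeID uint64, notify func(change)) error {
	c, err := s.UpdateFromMapRequest(nodeID)
	if err != nil {
		return fmt.Errorf("updating node from map request: %w", err)
	}
	notify(c)

	// Only after the node's state is current do we mark it online.
	notify(s.Connect(nodeID))
	return nil
}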
@ -1,11 +1,15 @@
 package state

 import (
+	"fmt"
 	"maps"
+	"strings"
 	"sync/atomic"
 	"time"

 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
 	"tailscale.com/types/key"
 	"tailscale.com/types/views"
 )
@ -21,6 +25,56 @@
 	update = 3
 )

+const prometheusNamespace = "headscale"
+
+var (
+	nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_operations_total",
+		Help:      "Total number of NodeStore operations",
+	}, []string{"operation"})
+	nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_operation_duration_seconds",
+		Help:      "Duration of NodeStore operations",
+		Buckets:   prometheus.DefBuckets,
+	}, []string{"operation"})
+	nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_batch_size",
+		Help:      "Size of NodeStore write batches",
+		Buckets:   []float64{1, 2, 5, 10, 20, 50, 100},
+	})
+	nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_batch_duration_seconds",
+		Help:      "Duration of NodeStore batch processing",
+		Buckets:   prometheus.DefBuckets,
+	})
+	nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_snapshot_build_duration_seconds",
+		Help:      "Duration of NodeStore snapshot building from nodes",
+		Buckets:   prometheus.DefBuckets,
+	})
+	nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_nodes_total",
+		Help:      "Total number of nodes in the NodeStore",
+	})
+	nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_peers_calculation_duration_seconds",
+		Help:      "Duration of peers calculation in NodeStore",
+		Buckets:   prometheus.DefBuckets,
+	})
+	nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
+		Namespace: prometheusNamespace,
+		Name:      "nodestore_queue_depth",
+		Help:      "Current depth of NodeStore write queue",
+	})
+)
+
 // NodeStore is a thread-safe store for nodes.
 // It is a copy-on-write structure, replacing the "snapshot"
 // when a change to the structure occurs. It is optimised for reads,
@ -29,13 +83,14 @@
 // changes rapidly.
 //
 // Writes will block until committed, while reads are never
-// blocked.
+// blocked. This means that the caller of a write operation
+// is responsible for ensuring an update depending on a write
+// is not issued before the write is complete.
 type NodeStore struct {
 	data atomic.Pointer[Snapshot]

 	peersFunc  PeersFunc
 	writeQueue chan work
-	// TODO: metrics
 }

 func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
@ -50,9 +105,17 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
 	}
 	store.data.Store(&snap)

+	// Initialize node count gauge
+	nodeStoreNodesCount.Set(float64(len(nodes)))
+
 	return store
 }

+// Snapshot is the representation of the current state of the NodeStore.
+// It contains all nodes and their relationships.
+// It is a copy-on-write structure, meaning that when a write occurs,
+// a new Snapshot is created with the updated state,
+// and replaces the old one atomically.
 type Snapshot struct {
 	// nodesByID is the main source of truth for nodes.
 	nodesByID map[types.NodeID]types.Node
@ -64,15 +127,19 @@ type Snapshot struct {
 	allNodes []types.NodeView
 }

+// PeersFunc is a function that takes a list of nodes and returns a map
+// with the relationships between nodes and their peers.
+// This will typically be used to calculate which nodes can see each other
+// based on the current policy.
 type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView

+// work represents a single operation to be performed on the NodeStore.
 type work struct {
 	op       int
 	nodeID   types.NodeID
 	node     types.Node
 	updateFn UpdateNodeFunc
 	result   chan struct{}
-	immediate bool // For operations that need immediate processing
 }

 // PutNode adds or updates a node in the store.
@ -80,6 +147,9 @@ type work struct {
 // If the node does not exist, it will be added.
 // This is a blocking operation that waits for the write to complete.
 func (s *NodeStore) PutNode(n types.Node) {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
+	defer timer.ObserveDuration()
+
 	work := work{
 		op:     put,
 		nodeID: n.ID,
@ -87,8 +157,12 @@ func (s *NodeStore) PutNode(n types.Node) {
 		result: make(chan struct{}),
 	}

+	nodeStoreQueueDepth.Inc()
 	s.writeQueue <- work
 	<-work.result
+	nodeStoreQueueDepth.Dec()
+
+	nodeStoreOperations.WithLabelValues("put").Inc()
 }

 // UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
@ -96,7 +170,21 @@ type UpdateNodeFunc func(n *types.Node)

 // UpdateNode applies a function to modify a specific node in the store.
 // This is a blocking operation that waits for the write to complete.
+// This is analogous to a database "transaction", or, the caller should
+// rather collect all data they want to change, and then call this function.
+// Fewer calls are better.
+//
+// TODO(kradalby): Technically we could have a version of this that modifies the node
+// in the current snapshot if _we know_ that the change will not affect the peer relationships.
+// This is because the main nodesByID map contains the struct, and every other map is using a
+// pointer to the underlying struct. The gotcha with this is that we will need to introduce
+// a lock around the nodesByID map to ensure that no other writes are happening
+// while we are modifying the node. Which mean we would need to implement read-write locks
+// on all read operations.
 func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
+	defer timer.ObserveDuration()
+
 	work := work{
 		op:       update,
 		nodeID:   nodeID,
@ -104,48 +192,47 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
 		result:   make(chan struct{}),
 	}

+	nodeStoreQueueDepth.Inc()
 	s.writeQueue <- work
 	<-work.result
-}
+	nodeStoreQueueDepth.Dec()

-// UpdateNodeImmediate applies a function to modify a specific node in the store
-// with immediate processing (bypassing normal batching delays).
-// Use this for time-sensitive updates like online status changes.
-func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
-	work := work{
-		op:       update,
-		nodeID:   nodeID,
-		updateFn: updateFn,
-		result:   make(chan struct{}),
-		immediate: true,
-	}
-
-	s.writeQueue <- work
-	<-work.result
+	nodeStoreOperations.WithLabelValues("update").Inc()
 }

 // DeleteNode removes a node from the store by its ID.
 // This is a blocking operation that waits for the write to complete.
 func (s *NodeStore) DeleteNode(id types.NodeID) {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
+	defer timer.ObserveDuration()
+
 	work := work{
 		op:     del,
 		nodeID: id,
 		result: make(chan struct{}),
 	}

+	nodeStoreQueueDepth.Inc()
 	s.writeQueue <- work
 	<-work.result
+	nodeStoreQueueDepth.Dec()
+
+	nodeStoreOperations.WithLabelValues("delete").Inc()
 }

+// Start initializes the NodeStore and starts processing the write queue.
 func (s *NodeStore) Start() {
 	s.writeQueue = make(chan work)
 	go s.processWrite()
 }

+// Stop stops the NodeStore and closes the write queue.
 func (s *NodeStore) Stop() {
 	close(s.writeQueue)
 }

+// processWrite processes the write queue in batches.
+// It collects writes into batches and applies them periodically.
 func (s *NodeStore) processWrite() {
 	c := time.NewTicker(batchTimeout)
 	batch := make([]work, 0, batchSize)
@ -157,13 +244,7 @@ func (s *NodeStore) processWrite() {
 				c.Stop()
 				return
 			}
-
-			// Handle immediate operations right away
-			if w.immediate {
-				s.applyBatch([]work{w})
-				continue
-			}
-
 			batch = append(batch, w)
 			if len(batch) >= batchSize {
 				s.applyBatch(batch)
@ -181,7 +262,22 @@ func (s *NodeStore) processWrite() {
 	}
 }

+// applyBatch applies a batch of work to the node store.
+// This means that it takes a copy of the current nodes,
+// then applies the batch of operations to that copy,
+// runs any precomputation needed (like calculating peers),
+// and finally replaces the snapshot in the store with the new one.
+// The replacement of the snapshot is atomic, ensuring that reads
+// are never blocked by writes.
+// Each write item is blocked until the batch is applied to ensure
+// the caller knows the operation is complete and do not send any
+// updates that are dependent on a read that is yet to be written.
 func (s *NodeStore) applyBatch(batch []work) {
+	timer := prometheus.NewTimer(nodeStoreBatchDuration)
+	defer timer.ObserveDuration()
+
+	nodeStoreBatchSize.Observe(float64(len(batch)))
+
 	nodes := make(map[types.NodeID]types.Node)
 	maps.Copy(nodes, s.data.Load().nodesByID)

@ -201,15 +297,25 @@ func (s *NodeStore) applyBatch(batch []work) {
 	}

 	newSnap := snapshotFromNodes(nodes, s.peersFunc)
 	s.data.Store(&newSnap)

+	// Update node count gauge
+	nodeStoreNodesCount.Set(float64(len(nodes)))
+
 	for _, w := range batch {
 		close(w.result)
 	}
 }

+// snapshotFromNodes creates a new Snapshot from the provided nodes.
+// It builds a lot of "indexes" to make lookups fast for datasets we
+// that is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
+// This is not a fast operation, it is the "slow" part of our copy-on-write
+// structure, but it allows us to have fast reads and efficient lookups.
 func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
+	timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
+	defer timer.ObserveDuration()
+
 	allNodes := make([]types.NodeView, 0, len(nodes))
 	for _, n := range nodes {
 		allNodes = append(allNodes, n.View())
@ -219,8 +325,17 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
 		nodesByID:      nodes,
 		allNodes:       allNodes,
 		nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
-		peersByNode:    peersFunc(allNodes),
-		nodesByUser:    make(map[types.UserID][]types.NodeView),
+
+		// peersByNode is most likely the most expensive operation,
+		// it will use the list of all nodes, combined with the
+		// current policy to precalculate which nodes are peers and
+		// can see each other.
+		peersByNode: func() map[types.NodeID][]types.NodeView {
+			peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
+			defer peersTimer.ObserveDuration()
+			return peersFunc(allNodes)
+		}(),
+		nodesByUser: make(map[types.UserID][]types.NodeView),
 	}

 	// Build nodesByUser and nodesByNodeKey maps
@ -234,9 +349,21 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
 }

 // GetNode retrieves a node by its ID.
-func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
-	n := s.data.Load().nodesByID[id]
-	return n.View()
+// The bool indicates if the node exists or is available (like "err not found").
+// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure
+// it isn't an invalid node (this is more of a node error or node is broken).
+func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
+	defer timer.ObserveDuration()
+
+	nodeStoreOperations.WithLabelValues("get").Inc()
+
+	n, exists := s.data.Load().nodesByID[id]
+	if !exists {
+		return types.NodeView{}, false
+	}
+
+	return n.View(), true
 }

 // GetNodeByNodeKey retrieves a node by its NodeKey.
@ -306,15 +433,30 @@ func (s *NodeStore) DebugString() string {

 // ListNodes returns a slice of all nodes in the store.
 func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
+	defer timer.ObserveDuration()
+
+	nodeStoreOperations.WithLabelValues("list").Inc()
+
 	return views.SliceOf(s.data.Load().allNodes)
 }

 // ListPeers returns a slice of all peers for a given node ID.
 func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
+	defer timer.ObserveDuration()
+
+	nodeStoreOperations.WithLabelValues("list_peers").Inc()
+
 	return views.SliceOf(s.data.Load().peersByNode[id])
 }

 // ListNodesByUser returns a slice of all nodes for a given user ID.
 func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
+	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
+	defer timer.ObserveDuration()
+
+	nodeStoreOperations.WithLabelValues("list_by_user").Inc()
+
 	return views.SliceOf(s.data.Load().nodesByUser[uid])
 }
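The nodestore.go hunks above describe a copy-on-write store: readers dereference an atomic.Pointer to an immutable snapshot and never block, while writers build a fresh snapshot and swap it in. A condensed sketch of that pattern is below; unlike the real NodeStore, which serializes writes through a batched queue, this sketch assumes a single writer.

package main

import (
	"fmt"
	"sync/atomic"
)

// snapshot is an immutable view; writers build a new one and swap it in.
type snapshot struct {
	nodes map[uint64]string
}

type store struct {
	data atomic.Pointer[snapshot]
}

// get never blocks: it only dereferences the current snapshot.
func (s *store) get(id uint64) (string, bool) {
	n, ok := s.data.Load().nodes[id]
	return n, ok
}

// put copies the current map, applies the change, and atomically replaces
// the snapshot; concurrent readers keep using the old snapshot until the swap.
func (s *store) put(id uint64, name string) {
	old := s.data.Load()
	next := &snapshot{nodes: make(map[uint64]string, len(old.nodes)+1)}
	for k, v := range old.nodes {
		next.nodes[k] = v
	}
	next.nodes[id] = name
	s.data.Store(next)
}

func main() {
	s := &store{}
	s.data.Store(&snapshot{nodes: map[uint64]string{}})
	s.put(1, "node1")
	fmt.Println(s.get(1)) // node1 true
}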
@ -24,6 +24,7 @@ func TestSnapshotFromNodes(t *testing.T) {
 				peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
 					return make(map[types.NodeID][]types.NodeView)
 				}
+
 				return nodes, peersFunc
 			},
 			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -61,6 +62,7 @@ func TestSnapshotFromNodes(t *testing.T) {
 					1: createTestNode(1, 1, "user1", "node1"),
 					2: createTestNode(2, 1, "user1", "node2"),
 				}
+
 				return nodes, allowAllPeersFunc
 			},
 			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -85,6 +87,7 @@ func TestSnapshotFromNodes(t *testing.T) {
 					2: createTestNode(2, 2, "user2", "node2"),
 					3: createTestNode(3, 1, "user1", "node3"),
 				}
+
 				return nodes, allowAllPeersFunc
 			},
 			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -113,6 +116,7 @@ func TestSnapshotFromNodes(t *testing.T) {
 					4: createTestNode(4, 4, "user4", "node4"),
 				}
 				peersFunc := oddEvenPeersFunc
+
 				return nodes, peersFunc
 			},
 			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -191,6 +195,7 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
 		}
 		ret[node.ID()] = peers
 	}
+
 	return ret
 }

@ -214,6 +219,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
 		}
 		ret[node.ID()] = peers
 	}
+
 	return ret
 }

@ -329,6 +335,7 @@ func TestNodeStoreOperations(t *testing.T) {
 				node2 := createTestNode(2, 1, "user1", "node2")
 				node3 := createTestNode(3, 2, "user2", "node3")
 				initialNodes := types.Nodes{&node1, &node2, &node3}
+
 				return NewNodeStore(initialNodes, allowAllPeersFunc)
 			},
 			steps: []testStep{
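processWrite, shown earlier in the nodestore.go diff, drains a channel into batches and flushes either when the batch is full or when a ticker fires. A generic version of that batching loop, as a sketch rather than the actual headscale code:

package main

import (
	"fmt"
	"time"
)

// batchLoop collects items from ch and flushes them with apply, either when
// batchSize items have accumulated or when batchTimeout elapses.
func batchLoop[T any](ch <-chan T, batchSize int, batchTimeout time.Duration, apply func([]T)) {
	ticker := time.NewTicker(batchTimeout)
	defer ticker.Stop()

	batch := make([]T, 0, batchSize)
	for {
		select {
		case item, ok := <-ch:
			if !ok {
				// Channel closed: flush what is left and stop.
				if len(batch) > 0 {
					apply(batch)
				}
				return
			}
			batch = append(batch, item)
			if len(batch) >= batchSize {
				apply(batch)
				batch = batch[:0]
			}
		case <-ticker.C:
			if len(batch) > 0 {
				apply(batch)
				batch = batch[:0]
			}
		}
	}
}

func main() {
	ch := make(chan int)
	go func() {
		for i := 0; i < 5; i++ {
			ch <- i
		}
		close(ch)
	}()
	batchLoop(ch, 2, 10*time.Millisecond, func(b []int) { fmt.Println("flush", b) })
}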
(File diff suppressed because it is too large)
@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
 	case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
 		return true
 	}
+
 	return false
 }
@ -104,6 +104,7 @@ type Node struct {
 	// headscale. It is best effort and not persisted.
 	LastSeen *time.Time `gorm:"column:last_seen"`

+
 	// ApprovedRoutes is a list of routes that the node is allowed to announce
 	// as a subnet router. They are not necessarily the routes that the node
 	// announces at the moment.
@ -420,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
 }

 // SubnetRoutes returns the list of routes that the node announces and are approved.
+//
+// IMPORTANT: This method is used for internal data structures and should NOT be used
+// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
+// with PrimaryRoutes to ensure it includes only routes actively served by the node.
+// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
 func (node *Node) SubnetRoutes() []netip.Prefix {
 	var routes []netip.Prefix

@ -525,7 +531,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
 	}

 	node.Hostname = hostInfo.Hostname

 	log.Trace().
 		Str("node_id", node.ID.String()).
 		Str("new_hostname", node.Hostname).
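The comment added to SubnetRoutes above stresses that it returns only routes that are both announced by the node and approved for it. A small sketch of that intersection (hypothetical helper, not the method itself):

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// subnetRoutes returns the routes that are both announced by the node and
// present in its approved list - the set a subnet router actually serves.
func subnetRoutes(announced, approved []netip.Prefix) []netip.Prefix {
	var routes []netip.Prefix
	for _, route := range announced {
		if slices.Contains(approved, route) {
			routes = append(routes, route)
		}
	}
	return routes
}

func main() {
	announced := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.0.0/24"),
	}
	approved := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
	fmt.Println(subnetRoutes(announced, approved)) // [10.0.0.0/24]
}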