Mirror of https://github.com/juanfont/headscale.git
Synced 2025-08-14 13:51:01 +02:00

Commit c24b988247 (parent 7eef3cc38c)

rest

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
@@ -136,9 +136,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {

// Initialize ephemeral garbage collector
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
node, err := app.state.GetNodeByID(ni)
if err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
node, ok := app.state.GetNodeByID(ni)
if !ok {
log.Warn().Uint64("node.id", ni.Uint64()).Msgf("ephemeral node not found for deletion")
return
}

@@ -371,7 +371,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")

authHeader := req.Header.Get("authorization")
authHeader := req.Header.Get("Authorization")

if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
@@ -487,11 +487,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp()

if profilingEnabled {
if profilingPath != "" {
err := os.MkdirAll(profilingPath, os.ModePerm)
err = os.MkdirAll(profilingPath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
@@ -543,10 +544,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes, err := h.state.ListEphemeralNodes()
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
ephmNodes := h.state.ListEphemeralNodes()
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
}
@@ -778,23 +776,14 @@ func (h *Headscale) Serve() error {
continue
}

changed, err := h.state.ReloadPolicy()
changes, err := h.state.ReloadPolicy()
if err != nil {
log.Error().Err(err).Msgf("reloading policy")
continue
}

if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
h.Change(changes...)

err = h.state.AutoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}

h.Change(change.PolicySet)
}
default:
info := func(msg string) { log.Info().Msg(msg) }
log.Info().
@@ -1004,6 +993,8 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
// Change is used to send changes to nodes.
// All change should be enqueued here and empty will be automatically
// ignored.
func (h *Headscale) Change(c change.ChangeSet) {
h.mapBatcher.AddWork(c)
func (h *Headscale) Change(cs ...change.ChangeSet) {
for _, c := range cs {
h.mapBatcher.AddWork(c)
}
}
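Editor's note: a minimal usage sketch (not part of the commit) of the new variadic Change signature, assuming it sits in the same hscontrol package; changes stands in for whatever slice of changesets the caller already holds, e.g. the result of ReloadPolicy above.

// Sketch only: forward zero or more changesets in a single call.
// An empty slice is fine - the loop in Change simply never runs,
// and empty changesets are documented above as being ignored.
func notifyPolicyReload(h *Headscale, changes []change.ChangeSet) {
	h.Change(changes...)

	// Single values still work, since a variadic parameter also
	// accepts individual arguments.
	h.Change(change.PolicySet)
}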
@@ -11,7 +11,6 @@ import (

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"

"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@@ -28,27 +27,9 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("looking up node in database: %w", err)
}

if node.Valid() {
// If an existing node is trying to register with an auth key,
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
}
return resp, nil
}
node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)

if ok {
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
@@ -69,6 +50,7 @@ func (h *Headscale) handleRegister(
if errors.As(err, &httpErr) {
return nil, httpErr
}

return nil, fmt.Errorf("handling register with auth key: %w", err)
}

@@ -88,13 +70,22 @@ func (h *Headscale) handleExistingNode(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {

if node.MachineKey != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}

expired := node.IsExpired()

// If the node is expired and this is not a re-authentication attempt,
// force the client to re-authenticate
if expired && regReq.Auth == nil {
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
AuthURL: "", // Client will need to re-authenticate
}, nil
}

if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry

@@ -117,12 +108,16 @@ func (h *Headscale) handleExistingNode(
}
}

_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}

h.Change(c)

// CRITICAL: Use the updated node view for the response
// The original node object has stale expiry information
node = updatedNode.AsStruct()
}

return nodeToRegisterResponse(node), nil
@@ -192,8 +187,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err
}

// If node is nil, it means an ephemeral node was deleted during logout
if node.Valid() {
// If node is not valid, it means an ephemeral node was deleted during logout
if !node.Valid() {
h.Change(changed)
return nil, nil
}
@@ -212,6 +207,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
// TODO(kradalby): This needs to be ran as part of the batcher maybe?
// now since we dont update the node/pol here anymore
routeChange := h.state.AutoApproveRoutes(node)

if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
@@ -229,6 +225,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
// }

user := node.User()

return &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
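Editor's note: a condensed sketch (not from the commit) of the "comma ok" lookup style this file moves to, where a missing node is a normal outcome rather than a gorm.ErrRecordNotFound error; error handling around the auth-key path is omitted.

// Sketch only: presence is a boolean, not an error to compare against.
if node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey); ok {
	// The node is already known: hand it to the existing-node path.
	resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
	if err != nil {
		return nil, fmt.Errorf("handling existing node: %w", err)
	}
	return resp, nil
}
// Unknown node: registration continues via the auth-key or interactive paths.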
@@ -936,7 +936,7 @@ AND auth_key_id NOT IN (
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
// - Never write migrations that requires foreign keys to be disabled.
},
},
)

if err := runMigrations(cfg, dbConn, migrations); err != nil {
@@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
}

// RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it. If the name is not unique, it will return an error.
// and renames it. Validation should be done in the state layer before calling this function.
func RenameNode(tx *gorm.DB,
nodeID types.NodeID, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
)
if err != nil {
return fmt.Errorf("renaming node: %w", err)
// Check if the new name is unique
var count int64
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
return fmt.Errorf("failed to check name uniqueness: %w", err)
}

uniq, err := isUniqueName(tx, newName)
if err != nil {
return fmt.Errorf("checking if name is unique: %w", err)
}

if !uniq {
return fmt.Errorf("name is not unique: %s", newName)

if count > 0 {
return fmt.Errorf("name is not unique")
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@@ -409,9 +403,16 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, err
}

// CRITICAL: Reload the node to get the updated expiry
// Without this, we return stale node data to NodeStore
updatedNode, err := GetNodeByID(tx, node.ID)
if err != nil {
return nil, fmt.Errorf("failed to reload node after expiry update: %w", err)
}

nodeChange = change.KeyExpiry(node.ID)

return node, nil
return updatedNode, nil
}
}

@@ -445,8 +446,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.ID = oldNode.ID
node.GivenName = oldNode.GivenName
node.ApprovedRoutes = oldNode.ApprovedRoutes
ipv4 = oldNode.IPv4
ipv6 = oldNode.IPv6
// Don't overwrite the provided IPs with old ones when they exist
if ipv4 == nil {
ipv4 = oldNode.IPv4
}
if ipv6 == nil {
ipv6 = oldNode.IPv6
}
}

// If the node exists and it already has IP(s), we just save it
@@ -781,19 +787,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .

node := hsdb.CreateNodeForTest(user, hostname...)

err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, *node, nil, nil)
// Allocate IPs for the test node using the database's IP allocator
// This is a simplified allocation for testing - in production this would use State.ipAlloc
ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
}

var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error
registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6)
return err
})
if err != nil {
panic(fmt.Sprintf("failed to register test node: %v", err))
}

registeredNode, err := hsdb.GetNodeByID(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to get registered test node: %v", err))
}

return registeredNode
}

@@ -842,3 +852,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int

return nodes
}

// allocateTestIPs allocates sequential test IPs for nodes during testing.
func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
if !testing.Testing() {
panic("allocateTestIPs can only be called during tests")
}

// Use simple sequential allocation for tests
// IPv4: 100.64.0.x (where x is nodeID)
// IPv6: fd7a:115c:a1e0::x (where x is nodeID)

if nodeID > 254 {
return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
}

ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})

return &ipv4, &ipv6, nil
}
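Editor's note: an illustrative standalone test sketch (not part of the commit, assuming the standard net/netip and testing imports) that pins down the sequential mapping allocateTestIPs uses for small node IDs.

// For nodeID 7 the helper above produces 100.64.0.7 and fd7a:115c:a1e0::7.
func TestAllocateTestIPsMapping(t *testing.T) {
	ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, 7})
	ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7})

	if ipv4.String() != "100.64.0.7" {
		t.Fatalf("unexpected IPv4: %s", ipv4)
	}
	if ipv6.String() != "fd7a:115c:a1e0::7" {
		t.Fatalf("unexpected IPv6: %s", ipv6)
	}
}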
@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
|
||||
func TestAutoApproveRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
expectChange bool // whether to expect route changes
|
||||
}{
|
||||
{
|
||||
name: "no-auto-approvers-empty-policy",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admins"],
|
||||
"dst": ["group:admins:*"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
want: []netip.Prefix{}, // Should be empty - no auto-approvers
|
||||
want2: []netip.Prefix{}, // Should be empty - no auto-approvers
|
||||
expectChange: false, // No changes expected
|
||||
},
|
||||
{
|
||||
name: "no-auto-approvers-explicit-empty",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admins"],
|
||||
"dst": ["group:admins:*"]
|
||||
}
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {},
|
||||
"exitNode": []
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
|
||||
want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
|
||||
want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
|
||||
expectChange: false, // No changes expected
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-kube",
|
||||
acl: `
|
||||
@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
expectChange: true, // Routes should be approved
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-exit-tag",
|
||||
@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
expectChange: true, // Routes should be approved
|
||||
},
|
||||
}
|
||||
|
||||
@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
changed1 := policy.AutoApproveRoutes(pm, &node)
|
||||
assert.True(t, changed1)
|
||||
newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
|
||||
assert.Equal(t, tt.expectChange, changed1)
|
||||
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
if changed1 {
|
||||
err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_ = policy.AutoApproveRoutes(pm, &nodeTagged)
|
||||
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
|
||||
if changed2 {
|
||||
err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
node1ByID, err := adb.GetNodeByID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
// For empty auto-approvers tests, handle nil vs empty slice comparison
|
||||
expectedRoutes1 := tt.want
|
||||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
node2ByID, err := adb.GetNodeByID(2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
expectedRoutes2 := tt.want2
|
||||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
|
||||
// No parameter means no filter, should return all peers
|
||||
nodes, err = db.ListPeers(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Empty node list should return all peers
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs, but node ID is still filtered out
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
}
|
||||
|
||||
@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return all nodes
|
||||
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
}
|
||||
|
@@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
}

// AssignNodeToUser assigns a Node to a user.
// Note: Validation should be done in the state layer before calling this function.
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
node, err := GetNodeByID(tx, nodeID)
if err != nil {
return err
// Check if the user exists
var userExists bool
if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
user, err := GetUserByID(tx, uid)
if err != nil {
return err

if !userExists {
return ErrUserNotFound
}
node.User = *user
node.UserID = user.ID
if result := tx.Save(&node); result.Error != nil {
return result.Error

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
return fmt.Errorf("failed to assign node to user: %w", err)
}

return nil
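Editor's note: a hypothetical caller sketch (not from the commit) reusing the transaction style shown in CreateRegisteredNodeForTest earlier in this diff; nodeID and uid are assumed inputs and ErrUserNotFound is the sentinel returned above.

// Sketch only: run the db-layer helper inside a transaction and map the
// sentinel back to a caller-visible condition.
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
	return AssignNodeToUser(tx, nodeID, uid)
})
if errors.Is(err, ErrUserNotFound) {
	// The target user did not exist; no node row was updated.
}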
@ -288,9 +288,9 @@ func (api headscaleV1APIServer) GetNode(
|
||||
ctx context.Context,
|
||||
request *v1.GetNodeRequest,
|
||||
) (*v1.GetNodeResponse, error) {
|
||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||
}
|
||||
|
||||
resp := node.Proto()
|
||||
@ -334,7 +334,12 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
ctx context.Context,
|
||||
request *v1.SetApprovedRoutesRequest,
|
||||
) (*v1.SetApprovedRoutesResponse, error) {
|
||||
var routes []netip.Prefix
|
||||
log.Debug().
|
||||
Uint64("node.id", request.GetNodeId()).
|
||||
Strs("requestedRoutes", request.GetRoutes()).
|
||||
Msg("gRPC SetApprovedRoutes called")
|
||||
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range request.GetRoutes() {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
@ -344,31 +349,34 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
// If the prefix is an exit route, add both. The client expect both
|
||||
// to annotate the node as an exit node.
|
||||
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
|
||||
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||
newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
|
||||
} else {
|
||||
routes = append(routes, prefix)
|
||||
newApproved = append(newApproved, prefix)
|
||||
}
|
||||
}
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
|
||||
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
// If routes changed, propagate those changes too
|
||||
if !routeChange.Empty() {
|
||||
api.h.Change(routeChange)
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
|
||||
// Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
|
||||
// routes that are actively served from the node (per architectural requirement in types/node.go)
|
||||
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
|
||||
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
|
||||
|
||||
log.Debug().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
||||
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
||||
Msg("gRPC SetApprovedRoutes completed")
|
||||
|
||||
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
||||
}
|
||||
@ -390,9 +398,9 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||
ctx context.Context,
|
||||
request *v1.DeleteNodeRequest,
|
||||
) (*v1.DeleteNodeResponse, error) {
|
||||
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "node not found")
|
||||
}
|
||||
|
||||
nodeChange, err := api.h.state.DeleteNode(node)
|
||||
@ -463,19 +471,13 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
nodes, err := api.h.state.ListNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes := api.h.state.ListNodes()
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
@ -499,6 +501,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID,
|
||||
}
|
||||
}
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
|
||||
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
|
||||
response[index] = resp
|
||||
}
|
||||
@ -674,11 +677,8 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
// a scenario where they might be allowed if the server has no nodes
|
||||
// yet, but it should help for the general case and for hot reloading
|
||||
// configurations.
|
||||
nodes, err := api.h.state.ListNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||
}
|
||||
changed, err := api.h.state.SetPolicy([]byte(p))
|
||||
nodes := api.h.state.ListNodes()
|
||||
_, err := api.h.state.SetPolicy([]byte(p))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting policy: %w", err)
|
||||
}
|
||||
@ -695,16 +695,16 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only send update if the packet filter has changed.
|
||||
if changed {
|
||||
err = api.h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.h.Change(change.PolicyChange())
|
||||
// Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
|
||||
// This ensures that routes are re-evaluated for auto-approval in cases where routes
|
||||
// were manually disabled but could now be auto-approved with the current policy.
|
||||
cs, err := api.h.state.ReloadPolicy()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reloading policy: %w", err)
|
||||
}
|
||||
|
||||
api.h.Change(cs...)
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
Policy: updated.Data,
|
||||
UpdatedAt: timestamppb.New(updated.UpdatedAt),
|
||||
|
@@ -94,10 +94,7 @@ func (h *Headscale) handleVerifyRequest(
return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
}

nodes, err := h.state.ListNodes()
if err != nil {
return fmt.Errorf("cannot list nodes: %w", err)
}
nodes := h.state.ListNodes()

// Check if any node has the requested NodeKey
var nodeKeyFound bool
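Editor's note: the loop that sets nodeKeyFound lies outside this hunk; the shape below is only an assumption (the NodeKey accessor and the derpAdmitClientRequest/NodePublic names are guesses, not confirmed by the diff), shown to illustrate iterating the view slice returned by the new ListNodes.

// Hypothetical sketch of the NodeKey check against the node views.
for _, node := range nodes.All() {
	if node.NodeKey() == derpAdmitClientRequest.NodePublic {
		nodeKeyFound = true
		break
	}
}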
@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
@ -119,7 +120,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
return errors.New("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
@ -3,7 +3,6 @@ package mapper
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -21,7 +20,6 @@ type LockFreeBatcher struct {
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
@ -32,7 +30,6 @@ type LockFreeBatcher struct {
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
@ -46,16 +43,13 @@ type LockFreeBatcher struct {
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, change.FullSelf(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
@ -73,10 +67,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
log.Info().Uint64("node.id", id.Uint64()).Msg("Node connected to batcher")
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
@ -86,9 +79,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -97,12 +87,14 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
// It reports if the node was actually closed. Returns false if the channel does not match the current connection,
// indicating that we are actually not disconnecting the node, but rather ignoring the request.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
// Check if this is the current connection and mark it as closed
if existing, ok := b.nodes.Load(id); ok {
if !existing.matchesChannel(c) {
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
return // Not the current connection, not an error
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called on a different channel, ignoring")
return false // Not the current connection, not an error
}

// Mark the connection as closed to prevent further sends
@@ -111,15 +103,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
}

log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
log.Info().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher, marking as offline")

// Remove node and mark disconnected atomically
b.nodes.Delete(id)
b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)

// Notify other nodes that this node went offline
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
return true
}

// AddWork queues a change to be processed by the batcher.
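Editor's note: a short sketch of how a caller can use the new boolean result, mirroring the test wrapper that appears later in this diff - the offline change is only emitted when the batcher confirms this channel really was the active connection.

// Sketch only: gate the offline notification on the return value.
if b.RemoveNode(id, ch) {
	// Only the owning connection triggers the offline notification.
	b.AddWork(change.NodeOffline(id))
}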
@ -214,6 +205,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@ -228,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@ -240,12 +233,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
@ -276,8 +263,10 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
return true
|
||||
}
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -285,7 +274,7 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
b.addToBatch(c)
|
||||
}
|
||||
|
||||
// queueWork safely queues work
|
||||
// queueWork safely queues work.
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
@ -298,7 +287,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
// shouldProcessImmediately determines if a change should bypass batching.
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
@ -309,11 +298,8 @@ func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds a change to the pending batch
|
||||
// addToBatch adds a change to the pending batch.
|
||||
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if c.SelfUpdateOnly {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
@ -329,15 +315,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes
|
||||
// processBatchedChanges processes all pending batched changes.
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
@@ -355,17 +339,27 @@ func (b *LockFreeBatcher) processBatchedChanges() {

// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)

return true
})
}

// IsConnected is lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if val, ok := b.connected.Load(id); ok {
// nil means connected
return val == nil
val, ok := b.connected.Load(id)
if !ok {
return false
}
return false

// nil means connected
if val == nil {
return true
}

// During grace period, always return true to allow DNS resolution
// for logout HTTP requests to complete successfully
gracePeriod := 45 * time.Second
return time.Since(*val) < gracePeriod
}

// ConnectedMap returns a lock-free map of all connected nodes.
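Editor's note: an illustrative sketch (not part of the commit) of the bookkeeping the grace period relies on, poking the internal connected map directly; a nil value means online, while a timestamp records when the node went offline and still reads as connected for up to 45 seconds.

// Sketch only: how stored values map onto IsConnected results.
b.connected.Store(id, nil)                                   // online: IsConnected returns true
b.connected.Store(id, ptr.To(time.Now()))                    // just disconnected: still true for ~45s
b.connected.Store(id, ptr.To(time.Now().Add(-time.Minute)))  // past the grace period: false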
@ -487,5 +481,6 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
// the channel is still open.
|
||||
connData.c <- data
|
||||
nc.updateCount.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -26,6 +26,43 @@ type batcherTestCase struct {
|
||||
fn batcherFunc
|
||||
}
|
||||
|
||||
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||
// that would normally be sent by poll.go in production
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||
// First add the node to the real batcher
|
||||
err := t.Batcher.AddNode(id, c, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then send the online notification that poll.go would normally send
|
||||
t.Batcher.AddWork(change.NodeOnline(id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
// First remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
// Then send the offline notification that poll.go would normally send
|
||||
t.Batcher.AddWork(change.NodeOffline(id))
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior
|
||||
func wrapBatcherForTest(b Batcher) Batcher {
|
||||
return &testBatcherWrapper{Batcher: b}
|
||||
}
|
||||
|
||||
// allBatcherFunctions contains all batcher implementations to test.
|
||||
var allBatcherFunctions = []batcherTestCase{
|
||||
{"LockFree", NewBatcherAndMapper},
|
||||
@ -176,8 +213,8 @@ func setupBatcherWithTestData(
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"users": ["*"],
|
||||
"ports": ["*:*"]
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
@ -187,8 +224,8 @@ func setupBatcherWithTestData(
|
||||
t.Fatalf("Failed to set allow-all policy: %v", err)
|
||||
}
|
||||
|
||||
// Create batcher with the state
|
||||
batcher := bf(cfg, state)
|
||||
// Create batcher with the state and wrap it for testing
|
||||
batcher := wrapBatcherForTest(bf(cfg, state))
|
||||
batcher.Start()
|
||||
|
||||
testData := &TestData{
|
||||
@ -455,7 +492,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
time.Sleep(100 * time.Millisecond) // Let connection settle
|
||||
|
||||
// Generate some work
|
||||
@ -558,7 +595,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullSet)
|
||||
@ -606,7 +643,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
// Disconnect all nodes
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for final updates to process
|
||||
@ -724,7 +761,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
tn2 := testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
}
|
||||
@ -744,14 +781,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, true)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@ -765,19 +802,19 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
len(data.Peers) >= 1 || data.Node != nil,
|
||||
"Should receive initial full map",
|
||||
)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Second node should receive its initial full map")
|
||||
}
|
||||
|
||||
// Disconnect the second node
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
|
||||
batcher.RemoveNode(tn2.n.ID, tn2.ch)
|
||||
assert.False(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get update that second has disconnected.
|
||||
select {
|
||||
case data := <-tn.ch:
|
||||
assertOnlineMapResponse(t, data, false)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Did not receive expected Online response update")
|
||||
}
|
||||
|
||||
@ -803,7 +840,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
// }
|
||||
|
||||
// Test RemoveNode
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch, false)
|
||||
batcher.RemoveNode(tn.n.ID, tn.ch)
|
||||
if batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be disconnected after RemoveNode")
|
||||
}
|
||||
@ -949,7 +986,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@ -1045,7 +1082,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Add real work during connection chaos
|
||||
@ -1059,7 +1096,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(1 * time.Microsecond)
|
||||
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
}()
|
||||
|
||||
// Remove second connection
|
||||
@ -1067,7 +1104,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch2)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@ -1142,7 +1179,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPSet)
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@ -1184,7 +1221,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
|
||||
// Rapid removal creates race between worker and removal
|
||||
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
|
||||
batcher.RemoveNode(testNode.n.ID, ch, false)
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Give workers time to process and close channels
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
@ -1254,7 +1291,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for _, node := range stableNodes {
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@ -1312,7 +1349,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
churningChannelsMutex.Lock()
|
||||
churningChannels[nodeID] = ch
|
||||
churningChannelsMutex.Unlock()
|
||||
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@ -1349,7 +1386,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch, exists := churningChannels[nodeID]
|
||||
churningChannelsMutex.Unlock()
|
||||
if exists {
|
||||
batcher.RemoveNode(nodeID, ch, false)
|
||||
batcher.RemoveNode(nodeID, ch)
|
||||
}
|
||||
}(node.n.ID)
|
||||
}
|
||||
@ -1599,7 +1636,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
var connectedNodesMutex sync.RWMutex
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[node.n.ID] = true
|
||||
connectedNodesMutex.Unlock()
|
||||
@ -1666,7 +1703,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
connectedNodesMutex.RUnlock()
|
||||
|
||||
if isConnected {
|
||||
batcher.RemoveNode(nodeID, channel, false)
|
||||
batcher.RemoveNode(nodeID, channel)
|
||||
connectedNodesMutex.Lock()
|
||||
connectedNodes[nodeID] = false
|
||||
connectedNodesMutex.Unlock()
|
||||
@ -1690,7 +1727,6 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
batcher.AddNode(
|
||||
nodeID,
|
||||
channel,
|
||||
false,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
connectedNodesMutex.Lock()
|
||||
@ -1792,7 +1828,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
// Now disconnect all nodes from batcher to stop new updates
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch, false)
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
// Give time for enhanced tracking goroutines to process any remaining data in channels
|
||||
@ -1924,7 +1960,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Connect nodes one at a time to avoid overwhelming the work queue
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
// Small delay between connections to allow NodeCameOnline processing
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
@ -1936,12 +1972,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
peers, err := testData.State.ListPeers(node.n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error listing peers for node %d: %v", i, err)
|
||||
} else {
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
|
||||
// Send a full update - this should generate full peer lists
|
||||
@ -1957,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
foundFullUpdate := false
|
||||
|
||||
// Read all available updates for each node
|
||||
for i := range len(allNodes) {
|
||||
for i := range allNodes {
|
||||
nodeUpdates := 0
|
||||
t.Logf("Reading updates for node %d:", i)
|
||||
|
||||
@ -2047,7 +2079,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
t.Logf("=== WORK QUEUE TRACING TEST ===")
|
||||
|
||||
// Connect first node
|
||||
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d", nodes[0].n.ID)
|
||||
|
||||
// Wait for initial NodeCameOnline to be processed
|
||||
@ -2102,14 +2134,10 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check if there should be peers available
|
||||
peers, err := testData.State.ListPeers(nodes[0].n.ID)
|
||||
if err != nil {
|
||||
t.Errorf("Error getting peers from state: %v", err)
|
||||
} else {
|
||||
t.Logf("State shows %d peers available for this node", peers.Len())
|
||||
if peers.Len() > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
|
||||
}
|
||||
peers := testData.State.ListPeers(nodes[0].n.ID)
|
||||
t.Logf("State shows %d peers available for this node", peers.Len())
|
||||
if peers.Len() > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Response data is nil")
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@ -8,11 +9,12 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
|
||||
type MapResponseBuilder struct {
|
||||
resp *tailcfg.MapResponse
|
||||
mapper *mapper
|
||||
@ -21,7 +23,7 @@ type MapResponseBuilder struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
return &MapResponseBuilder{
|
||||
@ -35,37 +37,44 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
// addError adds an error to the builder's error list.
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
// hasErrors returns true if the builder has accumulated any errors.
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the capability version for the response
|
||||
// WithCapabilityVersion sets the capability version for the response.
|
||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||
b.capVer = capVer
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response
|
||||
// WithSelfNode adds the requesting node to the response.
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
// Always use batcher's view of online status for self node
|
||||
// The batcher respects grace periods for logout scenarios
|
||||
node := nodeView.AsStruct()
|
||||
if b.mapper.batcher != nil {
|
||||
node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
|
||||
}
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node, b.capVer, b.mapper.state,
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
@ -74,29 +83,30 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.Node = tailnode
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response
|
||||
// WithDERPMap adds the DERP map to the response.
|
||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||
b.resp.DERPMap = b.mapper.state.DERPMap()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDomain adds the domain configuration
|
||||
// WithDomain adds the domain configuration.
|
||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||
b.resp.Domain = b.mapper.cfg.Domain()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCollectServicesDisabled sets the collect services flag to false
|
||||
// WithCollectServicesDisabled sets the collect services flag to false.
|
||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||
b.resp.CollectServices.Set(false)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
// It disables log tailing if the mapper's LogTail is not enabled.
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
@ -104,11 +114,11 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
@ -119,38 +129,41 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
}
|
||||
|
||||
b.resp.SSHPolicy = sshPolicy
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDNSConfig adds DNS configuration for the requesting node
|
||||
// WithDNSConfig adds DNS configuration for the requesting node.
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers.
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules based on policy
|
||||
// WithPacketFilters adds packet filter rules based on policy.
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
return b
|
||||
}
|
||||
|
||||
@ -167,9 +180,8 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
// WithPeers adds full peer list with policy filtering (for full map response).
|
||||
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@ -177,12 +189,12 @@ func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapRe
|
||||
}
|
||||
|
||||
b.resp.Peers = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@ -190,14 +202,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
||||
}
|
||||
|
||||
b.resp.PeersChanged = tailPeers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
|
||||
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
}
|
||||
|
||||
filter, matchers := b.mapper.state.Filter()
|
||||
@ -229,24 +242,24 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
|
||||
return tailPeers, nil
|
||||
}
|
||||
|
||||
// WithPeerChangedPatch adds peer change patches
|
||||
// WithPeerChangedPatch adds peer change patches.
|
||||
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||
b.resp.PeersChangedPatch = changes
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
// Build finalizes the response and returns marshaled bytes.
|
||||
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
|
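The builder above accumulates errors across the chained With* calls and only surfaces them in Build. A minimal, self-contained sketch of that pattern under simplified, made-up types (errors.Join stands in for multierr.New; nothing here is headscale's real API):

// Sketch: error-accumulating builder, mirroring the pattern above.
package main

import (
	"errors"
	"fmt"
)

type response struct {
	Domain string
	Peers  []string
}

type builder struct {
	resp response
	errs []error
}

func (b *builder) addError(err error) *builder {
	if err != nil {
		b.errs = append(b.errs, err)
	}
	return b
}

func (b *builder) WithDomain(d string) *builder {
	if d == "" {
		return b.addError(errors.New("empty domain"))
	}
	b.resp.Domain = d
	return b
}

func (b *builder) WithPeers(peers []string) *builder {
	b.resp.Peers = peers
	return b
}

func (b *builder) Build() (*response, error) {
	if len(b.errs) > 0 {
		// errors.Join plays the role of multierr.New in the real code.
		return nil, errors.Join(b.errs...)
	}
	return &b.resp, nil
}

func main() {
	resp, err := (&builder{}).WithDomain("example.ts.net").WithPeers([]string{"node-a"}).Build()
	fmt.Println(resp, err)
}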
@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) {
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
|
||||
|
||||
// Test basic builder creation
|
||||
assert.NotNil(t, builder)
|
||||
assert.Equal(t, nodeID, builder.nodeID)
|
||||
@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(42)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer)
|
||||
|
||||
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) {
|
||||
ServerURL: "https://test.example.com",
|
||||
BaseDomain: domain,
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDomain()
|
||||
|
||||
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
|
||||
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
logTailEnabled bool
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "LogTail enabled",
|
||||
name: "LogTail enabled",
|
||||
logTailEnabled: true,
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
},
|
||||
{
|
||||
name: "LogTail disabled",
|
||||
name: "LogTail disabled",
|
||||
logTailEnabled: false,
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
require.NotNil(t, builder.resp.Debug)
|
||||
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
changes := []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 123,
|
||||
NodeID: 123,
|
||||
DERPRegion: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 456,
|
||||
NodeID: 456,
|
||||
DERPRegion: 2,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changes)
|
||||
|
||||
|
||||
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(123)
|
||||
removedID2 := types.NodeID(456)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1, removedID2)
|
||||
|
||||
|
||||
expected := []tailcfg.NodeID{
|
||||
removedID1.NodeID(),
|
||||
removedID2.NodeID(),
|
||||
@ -197,23 +197,23 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
// Simulate an error in the builder
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
|
||||
|
||||
// All subsequent calls should continue to work and accumulate errors
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 1)
|
||||
assert.Equal(t, assert.AnError, result.errs[0])
|
||||
|
||||
|
||||
// Build should return the error
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(99)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(100)
|
||||
removedID2 := types.NodeID(200)
|
||||
|
||||
|
||||
// Test calling WithPeersRemoved multiple times
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1).
|
||||
WithPeersRemoved(removedID2)
|
||||
|
||||
|
||||
// Second call should overwrite the first
|
||||
expected := []tailcfg.NodeID{removedID2.NodeID()}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch([]*tailcfg.PeerChange{})
|
||||
|
||||
|
||||
assert.Empty(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(nil)
|
||||
|
||||
|
||||
assert.Nil(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
// Create a builder and add multiple errors
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(nil) // This should be ignored
|
||||
|
||||
|
||||
// All subsequent calls should continue to work
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 2) // nil error should be ignored
|
||||
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ import (
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)

@ -49,6 +50,37 @@ type mapper struct {
created time.Time
}

// addOnlineStatusToPeers adds fresh online status from batcher to peer nodes.
//
// We do a last-minute copy-and-write on the NodeView to inject current online status
// from the batcher's connection map. Online status is not populated upstream in NodeStore
// for consistency reasons - it's runtime connection state that should come from the
// connection manager (batcher) to ensure map responses have the freshest data.
func (m *mapper) addOnlineStatusToPeers(peers views.Slice[types.NodeView]) views.Slice[types.NodeView] {
if peers.Len() == 0 || m.batcher == nil {
return peers
}

result := make([]types.NodeView, 0, peers.Len())
for _, peer := range peers.All() {
if !peer.Valid() {
result = append(result, peer)
continue
}

// Get online status from batcher connection map
// The batcher respects grace periods for logout scenarios
isOnline := m.batcher.IsConnected(peer.ID())

// Create a mutable copy and set online status
peerCopy := peer.AsStruct()
peerCopy.IsOnline = ptr.To(isOnline)
result = append(result, peerCopy.View())
}

return views.SliceOf(result)
}
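addOnlineStatusToPeers copies each view, overlays runtime connection state, and leaves the shared data untouched. A small stand-alone sketch of that copy-then-overlay idea with invented stand-in types (no tailscale views or batcher involved):

// Sketch of the copy-then-overlay idea used above, with simplified types.
package main

import "fmt"

type Node struct {
	ID       uint64
	Hostname string
	IsOnline *bool
}

func ptrTo[T any](v T) *T { return &v }

// overlayOnlineStatus returns copies of the peers with IsOnline filled in
// from a runtime connection check; the caller's slice is never mutated.
func overlayOnlineStatus(peers []Node, isConnected func(uint64) bool) []Node {
	out := make([]Node, 0, len(peers))
	for _, p := range peers {
		p.IsOnline = ptrTo(isConnected(p.ID)) // p is a copy of the element
		out = append(out, p)
	}
	return out
}

func main() {
	peers := []Node{{ID: 1, Hostname: "a"}, {ID: 2, Hostname: "b"}}
	connected := map[uint64]bool{1: true}
	for _, p := range overlayOnlineStatus(peers, func(id uint64) bool { return connected[id] }) {
		fmt.Printf("%s online=%v\n", p.Hostname, *p.IsOnline)
	}
}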
type patch struct {
timestamp time.Time
change *tailcfg.PeerChange
@ -140,10 +172,10 @@ func (m *mapper) fullMapResponse(
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
peers, err := m.state.ListPeers(nodeID)
if err != nil {
return nil, err
}
peers := m.state.ListPeers(nodeID)

// Add fresh online status to peers from batcher connection state
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)

return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
@ -154,9 +186,9 @@ func (m *mapper) fullMapResponse(
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithUserProfiles(peersWithOnlineStatus).
WithPacketFilters().
WithPeers(peers).
WithPeers(peersWithOnlineStatus).
Build(messages...)
}

@ -185,16 +217,16 @@ func (m *mapper) peerChangeResponse(
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.state.ListPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
peers := m.state.ListPeers(nodeID, changedNodeID)

// Add fresh online status to peers from batcher connection state
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)

return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
WithPeerChanges(peers).
WithUserProfiles(peersWithOnlineStatus).
WithPeerChanges(peersWithOnlineStatus).
Build()
}
|
||||
|
||||
|
@ -133,11 +133,15 @@ func tailNode(
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}

if !node.IsOnline().Valid() || !node.IsOnline().Get() {
// LastSeen is only set when node is
// not connected to the control server.
if node.LastSeen().Valid() {
lastSeen := node.LastSeen().Get()
// Always set LastSeen if it's valid, regardless of online status
// This ensures that during logout grace periods (when IsOnline might be true
// for DNS preservation), other nodes can still see when this node disconnected
if node.LastSeen().Valid() {
lastSeen := node.LastSeen().Get()
// Only set LastSeen if the node is offline OR if LastSeen is recent
// (indicating it disconnected recently but might be in grace period)
if !node.IsOnline().Valid() || !node.IsOnline().Get() ||
time.Since(lastSeen) < 60*time.Second {
tNode.LastSeen = &lastSeen
}
}
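The new tailNode hunk exposes LastSeen whenever it is set and the node is offline, its online state is unknown, or the disconnect happened within the last minute (the logout grace period). A hedged sketch of that rule as a standalone predicate, assuming the 60-second window shown in the diff:

// Sketch of the LastSeen exposure rule described above.
package main

import (
	"fmt"
	"time"
)

// exposeLastSeen reports whether a peer's LastSeen timestamp should be
// included: always when the node is offline (or online-ness is unknown),
// and for recently disconnected nodes that may still be in a grace period.
// The caller is assumed to have already checked that LastSeen is set.
func exposeLastSeen(onlineKnown, online bool, lastSeen, now time.Time) bool {
	if !onlineKnown || !online {
		return true
	}
	return now.Sub(lastSeen) < 60*time.Second
}

func main() {
	now := time.Now()
	fmt.Println(exposeLastSeen(true, false, now.Add(-5*time.Minute), now))  // offline: true
	fmt.Println(exposeLastSeen(true, true, now.Add(-10*time.Second), now))  // grace period: true
	fmt.Println(exposeLastSeen(true, true, now.Add(-10*time.Minute), now))  // long online: false
}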
@ -13,7 +13,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"gorm.io/gorm"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/tailcfg"
@ -296,12 +295,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if !ok {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}

// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
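getAndValidateNode now uses a comma-ok lookup, so a missing node maps straight to 404 without inspecting error types. A small illustrative sketch with simplified stand-in types:

// Sketch of the comma-ok lookup to HTTP status mapping shown above.
package main

import (
	"fmt"
	"net/http"
)

type Node struct{ Hostname string }

type HTTPError struct {
	Code int
	Msg  string
}

func (e HTTPError) Error() string { return fmt.Sprintf("%d: %s", e.Code, e.Msg) }

type store struct{ byKey map[string]Node }

func (s store) GetNodeByNodeKey(key string) (Node, bool) {
	n, ok := s.byKey[key]
	return n, ok
}

func getNode(s store, nodeKey string) (Node, error) {
	n, ok := s.GetNodeByNodeKey(nodeKey)
	if !ok {
		return Node{}, HTTPError{Code: http.StatusNotFound, Msg: "node not found"}
	}
	return n, nil
}

func main() {
	s := store{byKey: map[string]Node{"nk-1": {Hostname: "a"}}}
	fmt.Println(getNode(s, "nk-1"))
	fmt.Println(getNode(s, "nk-missing"))
}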
@ -7,6 +7,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -138,39 +139,61 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
return ret
}

// AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy.
// It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
// ApproveRoutesWithPolicy checks if the node can approve the announced routes
// and returns the new list of approved routes.
// The approved routes will include:
// 1. ALL previously approved routes (regardless of whether they're still advertised)
// 2. New routes from announcedRoutes that can be auto-approved by policy
// This ensures that:
// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
// - New routes can be auto-approved according to policy
// - Routes can only be removed by explicit admin action (not by auto-approval)
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
if pm == nil {
return false
return currentApproved, false
}
nodeView := node.View()
var newApproved []netip.Prefix
for _, route := range nodeView.AnnouncedRoutes() {
if pm.NodeCanApproveRoute(nodeView, route) {

// Start with ALL currently approved routes - we never remove approved routes
newApproved := make([]netip.Prefix, len(currentApproved))
copy(newApproved, currentApproved)

// Then, check for new routes that can be auto-approved
for _, route := range announcedRoutes {
// Skip if already approved
if slices.Contains(newApproved, route) {
continue
}

// Check if this new route can be auto-approved by policy
canApprove := pm.NodeCanApproveRoute(nv, route)
if canApprove {
newApproved = append(newApproved, route)
}

log.Trace().
Uint64("node.id", nv.ID().Uint64()).
Str("node.name", nv.Hostname()).
Str("route", route.String()).
Bool("can_approve", canApprove).
Msg("Evaluating route for auto-approval")
}

// Only modify ApprovedRoutes if we have new routes to approve.
// This prevents clearing existing approved routes when nodes
// temporarily don't have announced routes during policy changes.
if len(newApproved) > 0 {
combined := append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(combined)
combined = slices.Compact(combined)
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
})

// Only update if the routes actually changed
if !slices.Equal(node.ApprovedRoutes, combined) {
node.ApprovedRoutes = combined
return true
}
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)

// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
return newApproved, true
}

return false
return newApproved, false
}
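ApproveRoutesWithPolicy keeps every previously approved route, adds newly announced routes the policy allows, sorts and deduplicates, and reports whether anything changed. A self-contained sketch of that merge rule with the policy check reduced to a callback (standard library only; not headscale's actual signature):

// Sketch of the approve-routes merge rule described above.
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func approveRoutes(current, announced []netip.Prefix, canApprove func(netip.Prefix) bool) ([]netip.Prefix, bool) {
	approved := slices.Clone(current) // never drop an existing approval
	for _, route := range announced {
		if slices.Contains(approved, route) {
			continue
		}
		if canApprove(route) {
			approved = append(approved, route)
		}
	}
	cmp := func(a, b netip.Prefix) int { return a.Addr().Compare(b.Addr()) }
	slices.SortFunc(approved, cmp)
	approved = slices.Compact(approved)

	sortedCurrent := slices.Clone(current)
	slices.SortFunc(sortedCurrent, cmp)
	changed := !slices.Equal(sortedCurrent, approved)
	return approved, changed
}

func main() {
	current := []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")}
	announced := []netip.Prefix{netip.MustParsePrefix("10.1.0.0/24")}
	inTen := netip.MustParsePrefix("10.0.0.0/8")
	got, changed := approveRoutes(current, announced, func(p netip.Prefix) bool {
		// toy policy: approve anything inside 10.0.0.0/8
		return inTen.Overlaps(p) && inTen.Bits() <= p.Bits()
	})
	fmt.Println(got, changed)
}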
339 hscontrol/policy/policy_autoapprove_test.go (new file)
@ -0,0 +1,339 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
user1 := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "testuser@",
|
||||
}
|
||||
user2 := types.User{
|
||||
Model: gorm.Model{ID: 2},
|
||||
Name: "otheruser@",
|
||||
}
|
||||
users := []types.User{user1, user2}
|
||||
|
||||
node1 := &types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test-node",
|
||||
UserID: user1.ID,
|
||||
User: user1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ForcedTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
node2 := &types.Node{
|
||||
ID: 2,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "other-node",
|
||||
UserID: user2.ID,
|
||||
User: user2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// Create a policy that auto-approves specific routes
|
||||
policyJSON := `{
|
||||
"groups": {
|
||||
"group:test": ["testuser@"]
|
||||
},
|
||||
"tagOwners": {
|
||||
"tag:test": ["testuser@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/8": ["testuser@", "tag:test"],
|
||||
"10.1.0.0/24": ["testuser@"],
|
||||
"10.2.0.0/24": ["testuser@"],
|
||||
"192.168.0.0/24": ["tag:test"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||
assert.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "previously_approved_route_no_longer_advertised_should_remain",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "Previously approved routes should never be removed even when no longer advertised",
|
||||
},
|
||||
{
|
||||
name: "add_new_auto_approved_route_keeps_old_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
|
||||
netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
|
||||
},
|
||||
wantChanged: true,
|
||||
description: "New auto-approved routes should be added while keeping old approved routes",
|
||||
},
|
||||
{
|
||||
name: "no_announced_routes_keeps_all_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "All approved routes should remain when no routes are announced",
|
||||
},
|
||||
{
|
||||
name: "no_changes_when_announced_equals_approved",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "No changes should occur when announced routes match approved routes",
|
||||
},
|
||||
{
|
||||
name: "auto_approve_multiple_new_routes",
|
||||
node: node1,
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8)
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("172.16.0.0/24"), // Original kept
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
|
||||
},
|
||||
wantChanged: true,
|
||||
description: "Multiple new routes should be auto-approved while keeping existing approved routes",
|
||||
},
|
||||
{
|
||||
name: "node_without_permission_no_auto_approval",
|
||||
node: node2, // Different node without the tag
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
|
||||
},
|
||||
wantChanged: false,
|
||||
description: "Routes should not be auto-approved for nodes without proper permissions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
||||
|
||||
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
||||
|
||||
// Verify that all previously approved routes are still present
|
||||
for _, prevRoute := range tt.currentApproved {
|
||||
assert.Contains(t, gotApproved, prevRoute,
|
||||
"previously approved route %s was removed - this should never happen", prevRoute)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
||||
// Create a basic policy for edge case testing
|
||||
aclPolicy := `
|
||||
{
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.1.0.0/24": ["test@"],
|
||||
},
|
||||
},
|
||||
}`
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
}{
|
||||
{
|
||||
name: "nil_policy_manager",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "nil_current_approved",
|
||||
currentApproved: nil,
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "nil_announced_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: nil,
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate_approved_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "empty_slices",
|
||||
currentApproved: []netip.Prefix{},
|
||||
announcedRoutes: []netip.Prefix{},
|
||||
wantApproved: []netip.Prefix{},
|
||||
wantChanged: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
// Create test node
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
pm = nil
|
||||
}
|
||||
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")
|
||||
|
||||
// Handle nil vs empty slice comparison
|
||||
if tt.wantApproved == nil {
|
||||
assert.Nil(t, gotApproved, "expected nil approved routes")
|
||||
} else {
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
361 hscontrol/policy/policy_route_approval_test.go (new file)
@ -0,0 +1,361 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
// Test policy that allows specific routes to be auto-approved
|
||||
aclPolicy := `
|
||||
{
|
||||
"groups": {
|
||||
"group:admins": ["test@"],
|
||||
},
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/24": ["test@"],
|
||||
"192.168.0.0/24": ["group:admins"],
|
||||
"172.16.0.0/16": ["tag:approved"],
|
||||
},
|
||||
},
|
||||
"tagOwners": {
|
||||
"tag:approved": ["test@"],
|
||||
},
|
||||
}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
nodeHostname string
|
||||
nodeUser string
|
||||
nodeTags []string
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
|
||||
}{
|
||||
{
|
||||
name: "previously_approved_route_no_longer_advertised_remains",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
|
||||
},
|
||||
{
|
||||
name: "add_new_auto_approved_route_keeps_existing",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Still advertised
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New route
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "no_announced_routes_keeps_all_approved",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
},
|
||||
{
|
||||
name: "manually_approved_route_not_in_policy_remains",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "tagged_node_gets_tag_approved_routes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
|
||||
},
|
||||
nodeUser: "test",
|
||||
nodeTags: []string{"tag:approved"},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "complex_scenario_multiple_changes",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag)
|
||||
netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
|
||||
},
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
}
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: tt.nodeUser,
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
// Create test node
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: tt.nodeHostname,
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
ForcedTags: tt.nodeTags,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
// Test ApproveRoutesWithPolicy
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(
|
||||
pm,
|
||||
node.View(),
|
||||
tt.currentApproved,
|
||||
tt.announcedRoutes,
|
||||
)
|
||||
|
||||
// Check change flag
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")
|
||||
|
||||
// Check approved routes match expected
|
||||
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
|
||||
t.Logf("Want: %v", tt.wantApproved)
|
||||
t.Logf("Got: %v", gotApproved)
|
||||
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// Verify all previously approved routes are still present
|
||||
for _, prevRoute := range tt.currentApproved {
|
||||
assert.Contains(t, gotApproved, prevRoute,
|
||||
"previously approved route %s was removed - this should NEVER happen", prevRoute)
|
||||
}
|
||||
|
||||
// Verify no routes were incorrectly removed
|
||||
for _, removedRoute := range tt.wantRemovedRoutes {
|
||||
assert.NotContains(t, gotApproved, removedRoute,
|
||||
"route %s should have been removed but wasn't", removedRoute)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
||||
aclPolicy := `
|
||||
{
|
||||
"acls": [
|
||||
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
|
||||
],
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.0.0.0/8": ["test@"],
|
||||
},
|
||||
},
|
||||
}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentApproved []netip.Prefix
|
||||
announcedRoutes []netip.Prefix
|
||||
wantApproved []netip.Prefix
|
||||
wantChanged bool
|
||||
}{
|
||||
{
|
||||
name: "nil_current_approved",
|
||||
currentApproved: nil,
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "empty_current_approved",
|
||||
currentApproved: []netip.Prefix{},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate_routes_handled",
|
||||
currentApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
},
|
||||
announcedRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
},
|
||||
wantChanged: true, // Duplicates are removed, so it's a change
|
||||
},
|
||||
}
|
||||
|
||||
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
|
||||
|
||||
for _, tt := range tests {
|
||||
for i, pmf := range pmfs {
|
||||
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
|
||||
// Create test user
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
users := []types.User{user}
|
||||
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(
|
||||
pm,
|
||||
node.View(),
|
||||
tt.currentApproved,
|
||||
tt.announcedRoutes,
|
||||
)
|
||||
|
||||
assert.Equal(t, tt.wantChanged, gotChanged)
|
||||
|
||||
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
currentApproved := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
}
|
||||
announcedRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
}
|
||||
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: currentApproved,
|
||||
}
|
||||
|
||||
// With nil policy manager, should return current approved unchanged
|
||||
gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)
|
||||
|
||||
assert.False(t, gotChanged)
|
||||
assert.Equal(t, currentApproved, gotApproved)
|
||||
}
|
@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
|
||||
canApprove: false,
|
||||
},
|
||||
{
|
||||
name: "policy-without-autoApprovers-section",
|
||||
node: normalNode,
|
||||
route: p("10.33.0.0/16"),
|
||||
policy: `{
|
||||
"groups": {
|
||||
"group:admin": ["user1@"]
|
||||
},
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admin"],
|
||||
"dst": ["group:admin:*"]
|
||||
},
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["group:admin"],
|
||||
"dst": ["10.33.0.0/16:*"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
canApprove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// The fast path is that a node requests to approve a prefix
// where there is an exact entry, e.g. 10.0.0.0/8, then
// check and return quickly
if _, ok := pm.autoApproveMap[route]; ok {
if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
if approvers, ok := pm.autoApproveMap[route]; ok {
canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
if canApprove {
return true
}
}
@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
if canApprove {
return true
}
}
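NodeCanApproveRoute first tries an exact auto-approver entry for the route, then falls back to any larger approver prefix that contains it. A simplified sketch of that two-step check, with the per-prefix approver set reduced to a plain address list (an assumption; the real code builds IP sets from the policy):

// Sketch of the exact-entry fast path plus containing-prefix fallback.
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

type autoApprovers map[netip.Prefix][]netip.Addr

func canApproveRoute(approvers autoApprovers, nodeIPs []netip.Addr, route netip.Prefix) bool {
	ipAllowed := func(allowed []netip.Addr) bool {
		return slices.ContainsFunc(nodeIPs, func(ip netip.Addr) bool {
			return slices.Contains(allowed, ip)
		})
	}
	// Fast path: an exact entry for the requested route.
	if allowed, ok := approvers[route]; ok && ipAllowed(allowed) {
		return true
	}
	// Slow path: a larger (containing) approver prefix that overlaps the route.
	for prefix, allowed := range approvers {
		if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) && ipAllowed(allowed) {
			return true
		}
	}
	return false
}

func main() {
	node := []netip.Addr{netip.MustParseAddr("100.64.0.1")}
	approvers := autoApprovers{
		netip.MustParsePrefix("10.0.0.0/8"): {netip.MustParseAddr("100.64.0.1")},
	}
	fmt.Println(canApproveRoute(approvers, node, netip.MustParsePrefix("10.1.0.0/24")))    // true
	fmt.Println(canApproveRoute(approvers, node, netip.MustParsePrefix("192.168.0.0/24"))) // false
}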
@ -10,7 +10,6 @@ import (
"time"

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
@ -112,6 +111,15 @@ func (m *mapSession) serve() {
// This is the mechanism where the node gives us information about its
// current configuration.
//
// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
httpError(m.w, err)
return
}

m.h.Change(c)

// If OmitPeers is true and Stream is false
// then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections.
@ -122,14 +130,6 @@ func (m *mapSession) serve() {
// the response and just wants a 200.
// !req.stream && req.OmitPeers
if m.isEndpointUpdate() {
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
if err != nil {
httpError(m.w, err)
return
}

m.h.Change(c)

m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
}
@ -142,6 +142,8 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()

log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("starting long poll session chan(%p)", m.ch)

// Clean up the session when the client disconnects
defer func() {
m.cancelChMu.Lock()
@ -149,18 +151,26 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh)
m.cancelChMu.Unlock()

// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
disconnectChange, err := m.h.state.Disconnect(m.node)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removing session from batcher chan(%p)", m.ch)

// Validate if we are actually closing the current session or
// if the connection has been replaced. If the connection has been replaced,
// do not run the rest of the disconnect logic.
if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removed from batcher chan(%p)", m.ch)
// First update NodeStore to mark the node as offline
// This ensures the state is consistent before notifying the batcher
disconnectChange, err := m.h.state.Disconnect(m.node.ID)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}

// Send the disconnect change notification
m.h.Change(disconnectChange)

m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}
m.h.Change(disconnectChange)

m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())

m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}()

// Set up the client stream
@ -172,25 +182,37 @@ func (m *mapSession) serveLongPoll() {

m.keepAliveTicker = time.NewTicker(m.keepAlive)

// Add node to batcher BEFORE sending Connect change to prevent race condition
// where the change is sent before the node is in the batcher's node map
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
// Add node to batcher so it can receive updates,
// adding this before connecting it to the state ensures that
// it does not miss any updates that might be sent in the split
// time between the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
// Send empty response to client to fail fast for invalid/non-existent nodes
select {
case m.ch <- &tailcfg.MapResponse{}:
default:
// Channel might be closed
}

return
}

// Now send the Connect change - the batcher handles NodeCameOnline internally
// but we still need to update routes and other state-level changes
connectChange := m.h.state.Connect(m.node)
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
m.h.Change(connectChange)
// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
// synchronized. When nodes reconnect, they send their hostinfo with announced routes
// in the MapRequest. We need this data in NodeStore before Connect() sets up the
// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
// from AvailableRoutes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
m.errf(err, "failed to update node from initial MapRequest")
return
}
m.h.Change(mapReqChange)

// Connect the node after its state has been updated.
// We send two separate change notifications because these are distinct operations:
// 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
// 2. Connect: marks the node online and recalculates primary routes based on the updated state
// While this results in two notifications, it ensures route data is synchronized before
// primary route selection occurs, which is critical for proper HA subnet router failover.
connectChange := m.h.state.Connect(m.node.ID)
m.h.Change(connectChange)

m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
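The deferred cleanup above only runs the disconnect logic when RemoveNode confirms this session still owns the registered channel, so a session that has been replaced by a newer connection does not mark the node offline. A minimal sketch of that guard with invented registry types:

// Sketch of the "only tear down if we still own the connection" guard.
package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mu    sync.Mutex
	conns map[uint64]chan string
}

func (r *registry) AddNode(id uint64, ch chan string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.conns[id] = ch // a reconnect overwrites the previous channel
}

// RemoveNode reports whether ch was still the registered connection.
func (r *registry) RemoveNode(id uint64, ch chan string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.conns[id] != ch {
		return false // replaced by a newer session; leave it alone
	}
	delete(r.conns, id)
	return true
}

func main() {
	r := &registry{conns: map[uint64]chan string{}}
	oldCh, newCh := make(chan string, 1), make(chan string, 1)

	r.AddNode(1, oldCh)
	r.AddNode(1, newCh) // node reconnected before the old session cleaned up

	fmt.Println(r.RemoveNode(1, oldCh)) // false: skip disconnect logic
	fmt.Println(r.RemoveNode(1, newCh)) // true: mark offline, notify peers
}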
||||
|
@ -1,11 +1,15 @@
package state

import (
    "fmt"
    "maps"
    "strings"
    "sync/atomic"
    "time"

    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
    "tailscale.com/types/key"
    "tailscale.com/types/views"
)

@ -21,6 +25,56 @@ const (
    update = 3
)

const prometheusNamespace = "headscale"

var (
    nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_operations_total",
        Help:      "Total number of NodeStore operations",
    }, []string{"operation"})
    nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_operation_duration_seconds",
        Help:      "Duration of NodeStore operations",
        Buckets:   prometheus.DefBuckets,
    }, []string{"operation"})
    nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_batch_size",
        Help:      "Size of NodeStore write batches",
        Buckets:   []float64{1, 2, 5, 10, 20, 50, 100},
    })
    nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_batch_duration_seconds",
        Help:      "Duration of NodeStore batch processing",
        Buckets:   prometheus.DefBuckets,
    })
    nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_snapshot_build_duration_seconds",
        Help:      "Duration of NodeStore snapshot building from nodes",
        Buckets:   prometheus.DefBuckets,
    })
    nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_nodes_total",
        Help:      "Total number of nodes in the NodeStore",
    })
    nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_peers_calculation_duration_seconds",
        Help:      "Duration of peers calculation in NodeStore",
        Buckets:   prometheus.DefBuckets,
    })
    nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
        Namespace: prometheusNamespace,
        Name:      "nodestore_queue_depth",
        Help:      "Current depth of NodeStore write queue",
    })
)
// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
@ -29,13 +83,14 @@ const (
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked.
// blocked. This means that the caller of a write operation
// is responsible for ensuring an update depending on a write
// is not issued before the write is complete.
type NodeStore struct {
    data atomic.Pointer[Snapshot]

    peersFunc  PeersFunc
    writeQueue chan work
    // TODO: metrics
}

func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
@ -50,9 +105,17 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
    }
    store.data.Store(&snap)

    // Initialize node count gauge
    nodeStoreNodesCount.Set(float64(len(nodes)))

    return store
}

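A minimal caller-side sketch of the contract described above (illustrative only, not part of the commit): PutNode blocks until the new snapshot has been published, so a read issued afterwards observes the write, while reads themselves never block. The emptyPeersFunc and node values are assumptions made for the example.

    // Sketch: write-then-read against a NodeStore, assuming `node` is a
    // populated types.Node and no peer relationships are needed.
    emptyPeersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
        return make(map[types.NodeID][]types.NodeView)
    }
    store := NewNodeStore(types.Nodes{}, emptyPeersFunc)
    store.Start()
    defer store.Stop()

    store.PutNode(node)              // blocks until the snapshot swap is complete
    nv, ok := store.GetNode(node.ID) // lock-free read of the freshly published snapshot
    if ok && nv.Valid() {
        // it is now safe to emit changes that depend on the write above
    }
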
// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
    // nodesByID is the main source of truth for nodes.
    nodesByID map[types.NodeID]types.Node
@ -64,15 +127,19 @@ type Snapshot struct {
    allNodes []types.NodeView
}

// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView

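A PeersFunc can be as simple as an allow-all relationship, mirroring the allowAllPeersFunc helper used in the tests further down; real callers would derive the mapping from the ACL policy. Illustrative sketch only:

    // Sketch: an allow-all PeersFunc where every node peers with every other node.
    allowAll := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
        ret := make(map[types.NodeID][]types.NodeView, len(nodes))
        for _, node := range nodes {
            peers := make([]types.NodeView, 0, len(nodes)-1)
            for _, other := range nodes {
                if other.ID() != node.ID() {
                    peers = append(peers, other)
                }
            }
            ret[node.ID()] = peers
        }
        return ret
    }
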
// work represents a single operation to be performed on the NodeStore.
type work struct {
    op        int
    nodeID    types.NodeID
    node      types.Node
    updateFn  UpdateNodeFunc
    result    chan struct{}
    immediate bool // For operations that need immediate processing
    op       int
    nodeID   types.NodeID
    node     types.Node
    updateFn UpdateNodeFunc
    result   chan struct{}
}

// PutNode adds or updates a node in the store.
@ -80,6 +147,9 @@ type work struct {
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
    defer timer.ObserveDuration()

    work := work{
        op:     put,
        nodeID: n.ID,
@ -87,8 +157,12 @@ func (s *NodeStore) PutNode(n types.Node) {
        result: make(chan struct{}),
    }

    nodeStoreQueueDepth.Inc()
    s.writeQueue <- work
    <-work.result
    nodeStoreQueueDepth.Dec()

    nodeStoreOperations.WithLabelValues("put").Inc()
}

// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
@ -96,7 +170,21 @@ type UpdateNodeFunc func(n *types.Node)

// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
// This is analogous to a database "transaction": the caller should collect
// all the data they want to change and then call this function once.
// Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node. Which means we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
    defer timer.ObserveDuration()

    work := work{
        op:     update,
        nodeID: nodeID,
@ -104,48 +192,47 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
        result: make(chan struct{}),
    }

    nodeStoreQueueDepth.Inc()
    s.writeQueue <- work
    <-work.result
}
    nodeStoreQueueDepth.Dec()

// UpdateNodeImmediate applies a function to modify a specific node in the store
// with immediate processing (bypassing normal batching delays).
// Use this for time-sensitive updates like online status changes.
func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
    work := work{
        op:        update,
        nodeID:    nodeID,
        updateFn:  updateFn,
        result:    make(chan struct{}),
        immediate: true,
    }

    s.writeQueue <- work
    <-work.result
    nodeStoreOperations.WithLabelValues("update").Inc()
}

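To illustrate the "collect everything, then call once" guidance above, a hedged usage sketch; store, id, and the field values are assumptions made for the example:

    // Sketch: batch several field changes into a single blocking UpdateNode call
    // instead of issuing one call per field.
    store.UpdateNode(id, func(n *types.Node) {
        now := time.Now()
        n.LastSeen = &now
        n.Hostname = "renamed-node" // hypothetical new value
    })
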
// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
    defer timer.ObserveDuration()

    work := work{
        op:     del,
        nodeID: id,
        result: make(chan struct{}),
    }

    nodeStoreQueueDepth.Inc()
    s.writeQueue <- work
    <-work.result
    nodeStoreQueueDepth.Dec()

    nodeStoreOperations.WithLabelValues("delete").Inc()
}

// Start initializes the NodeStore and starts processing the write queue.
func (s *NodeStore) Start() {
    s.writeQueue = make(chan work)
    go s.processWrite()
}

// Stop stops the NodeStore and closes the write queue.
func (s *NodeStore) Stop() {
    close(s.writeQueue)
}

// processWrite processes the write queue in batches.
// It collects writes into batches and applies them periodically.
func (s *NodeStore) processWrite() {
    c := time.NewTicker(batchTimeout)
    batch := make([]work, 0, batchSize)
@ -157,13 +244,7 @@ func (s *NodeStore) processWrite() {
            c.Stop()
            return
        }

        // Handle immediate operations right away
        if w.immediate {
            s.applyBatch([]work{w})
            continue
        }

        batch = append(batch, w)
        if len(batch) >= batchSize {
            s.applyBatch(batch)
@ -181,7 +262,22 @@ func (s *NodeStore) processWrite() {
        }
    }

// applyBatch applies a batch of work to the node store.
// This means that it takes a copy of the current nodes,
// then applies the batch of operations to that copy,
// runs any precomputation needed (like calculating peers),
// and finally replaces the snapshot in the store with the new one.
// The replacement of the snapshot is atomic, ensuring that reads
// are never blocked by writes.
// Each write item is blocked until the batch is applied, ensuring
// the caller knows the operation is complete and does not send any
// updates that depend on a write that is not yet visible.
func (s *NodeStore) applyBatch(batch []work) {
    timer := prometheus.NewTimer(nodeStoreBatchDuration)
    defer timer.ObserveDuration()

    nodeStoreBatchSize.Observe(float64(len(batch)))

    nodes := make(map[types.NodeID]types.Node)
    maps.Copy(nodes, s.data.Load().nodesByID)

@ -201,15 +297,25 @@ func (s *NodeStore) applyBatch(batch []work) {
    }

    newSnap := snapshotFromNodes(nodes, s.peersFunc)

    s.data.Store(&newSnap)

    // Update node count gauge
    nodeStoreNodesCount.Set(float64(len(nodes)))

    for _, w := range batch {
        close(w.result)
    }
}

// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds a lot of "indexes" to make lookups fast for data that is
// used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
// This is not a fast operation; it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
    timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
    defer timer.ObserveDuration()

    allNodes := make([]types.NodeView, 0, len(nodes))
    for _, n := range nodes {
        allNodes = append(allNodes, n.View())
@ -219,8 +325,17 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
        nodesByID:      nodes,
        allNodes:       allNodes,
        nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
        peersByNode:    peersFunc(allNodes),
        nodesByUser:    make(map[types.UserID][]types.NodeView),

        // peersByNode is most likely the most expensive operation,
        // it will use the list of all nodes, combined with the
        // current policy to precalculate which nodes are peers and
        // can see each other.
        peersByNode: func() map[types.NodeID][]types.NodeView {
            peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
            defer peersTimer.ObserveDuration()
            return peersFunc(allNodes)
        }(),
        nodesByUser: make(map[types.UserID][]types.NodeView),
    }

    // Build nodesByUser and nodesByNodeKey maps
@ -234,9 +349,21 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
}

// GetNode retrieves a node by its ID.
func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
    n := s.data.Load().nodesByID[id]
    return n.View()
// The returned bool reports whether the node exists in the store (the equivalent
// of a "not found" error). The NodeView itself may still be invalid, so callers
// must also check it with .Valid() to make sure the node is not broken.
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
    defer timer.ObserveDuration()

    nodeStoreOperations.WithLabelValues("get").Inc()

    n, exists := s.data.Load().nodesByID[id]
    if !exists {
        return types.NodeView{}, false
    }

    return n.View(), true
}

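A hedged sketch of the two-level check described above; store and id are assumptions made for the example:

    // Sketch: distinguish "node not in store" (ok == false) from a broken view.
    nv, ok := store.GetNode(id)
    if !ok || !nv.Valid() {
        // treat as not found / unusable and bail out
        return
    }
    // nv is safe to use from here on
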
// GetNodeByNodeKey retrieves a node by its NodeKey.
@ -306,15 +433,30 @@ func (s *NodeStore) DebugString() string {

// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
    defer timer.ObserveDuration()

    nodeStoreOperations.WithLabelValues("list").Inc()

    return views.SliceOf(s.data.Load().allNodes)
}

// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
    defer timer.ObserveDuration()

    nodeStoreOperations.WithLabelValues("list_peers").Inc()

    return views.SliceOf(s.data.Load().peersByNode[id])
}

// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
    timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
    defer timer.ObserveDuration()

    nodeStoreOperations.WithLabelValues("list_by_user").Inc()

    return views.SliceOf(s.data.Load().nodesByUser[uid])
}

@ -24,6 +24,7 @@ func TestSnapshotFromNodes(t *testing.T) {
    peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
        return make(map[types.NodeID][]types.NodeView)
    }

    return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -61,6 +62,7 @@ func TestSnapshotFromNodes(t *testing.T) {
    1: createTestNode(1, 1, "user1", "node1"),
    2: createTestNode(2, 1, "user1", "node2"),
}

return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -85,6 +87,7 @@ func TestSnapshotFromNodes(t *testing.T) {
    2: createTestNode(2, 2, "user2", "node2"),
    3: createTestNode(3, 1, "user1", "node3"),
}

return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -113,6 +116,7 @@ func TestSnapshotFromNodes(t *testing.T) {
    4: createTestNode(4, 4, "user4", "node4"),
}
peersFunc := oddEvenPeersFunc

return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -191,6 +195,7 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
        }
        ret[node.ID()] = peers
    }

    return ret
}

@ -214,6 +219,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
        }
        ret[node.ID()] = peers
    }

    return ret
}

@ -329,6 +335,7 @@ func TestNodeStoreOperations(t *testing.T) {
    node2 := createTestNode(2, 1, "user1", "node2")
    node3 := createTestNode(3, 2, "user2", "node3")
    initialNodes := types.Nodes{&node1, &node2, &node3}

    return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
File diff suppressed because it is too large
@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
    case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
        return true
    }

    return false
}

@ -104,6 +104,7 @@ type Node struct {
    // headscale. It is best effort and not persisted.
    LastSeen *time.Time `gorm:"column:last_seen"`

    // ApprovedRoutes is a list of routes that the node is allowed to announce
    // as a subnet router. They are not necessarily the routes that the node
    // announces at the moment.
@ -420,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
}

// SubnetRoutes returns the list of routes that the node announces and are approved.
//
// IMPORTANT: This method is used for internal data structures and should NOT be used
// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
// with PrimaryRoutes to ensure it includes only routes actively served by the node.
// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
func (node *Node) SubnetRoutes() []netip.Prefix {
    var routes []netip.Prefix

@ -525,7 +531,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
    }

    node.Hostname = hostInfo.Hostname

    log.Trace().
        Str("node_id", node.ID.String()).
        Str("new_hostname", node.Hostname).