From 911f0e926b7fa016234da724afc7ec44bff021bc Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 16 Sep 2025 13:30:13 +0200 Subject: [PATCH] policy: only rebuild ssh policy when needed Signed-off-by: Kristoffer Dalby --- hscontrol/policy/v2/policy.go | 52 ++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 4215485a..e1f310ca 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -3,6 +3,7 @@ package v2 import ( "encoding/json" "fmt" + "maps" "net/netip" "slices" "strings" @@ -36,6 +37,7 @@ type PolicyManager struct { autoApproveMapHash deephash.Sum autoApproveMap map[netip.Prefix]*netipx.IPSet + sshPolicyHash deephash.Sum // Lazy map of SSH policies sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy } @@ -67,11 +69,9 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node // updateLocked updates the filter rules based on the current policy and nodes. // It must be called with the lock held. func (pm *PolicyManager) updateLocked() (bool, error) { - // Clear the SSH policy map to ensure it's recalculated with the new policy. - // TODO(kradalby): This could potentially be optimized by only clearing the - // policies for nodes that have changed. Particularly if the only difference is - // that nodes has been added or removed. - clear(pm.sshPolicyMap) + // Save current SSH policy map for comparison + oldSSHPolicyMap := make(map[types.NodeID]*tailcfg.SSHPolicy) + maps.Copy(oldSSHPolicyMap, pm.sshPolicyMap) filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes) if err != nil { @@ -144,8 +144,45 @@ func (pm *PolicyManager) updateLocked() (bool, error) { pm.exitSet = exitSet pm.exitSetHash = exitSetHash - // If neither of the calculated values changed, no need to update nodes - if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged { + // Check if SSH policy compilation results have changed by computing + // SSH policies for all nodes and comparing their hash + var sshPolicyChanged bool + if len(pm.pol.SSHs) > 0 { + newSSHPolicies := make(map[types.NodeID]*tailcfg.SSHPolicy) + for i := 0; i < pm.nodes.Len(); i++ { + node := pm.nodes.At(i) + sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + if err != nil { + // If compilation fails, store nil policy + newSSHPolicies[node.ID()] = nil + } else { + newSSHPolicies[node.ID()] = sshPol + } + } + + sshPolicyHash := deephash.Hash(&newSSHPolicies) + sshPolicyChanged = sshPolicyHash != pm.sshPolicyHash + if sshPolicyChanged { + log.Debug(). + Str("sshPolicy.hash.old", pm.sshPolicyHash.String()[:8]). + Str("sshPolicy.hash.new", sshPolicyHash.String()[:8]). + Msg("SSH policy hash changed") + // Clear and update the SSH policy map with the newly computed policies + clear(pm.sshPolicyMap) + pm.sshPolicyMap = newSSHPolicies + } + pm.sshPolicyHash = sshPolicyHash + } else { + // If no SSH policy is defined, clear the map if it's not already empty + if len(pm.sshPolicyMap) > 0 { + sshPolicyChanged = true + clear(pm.sshPolicyMap) + pm.sshPolicyHash = deephash.Sum{} + } + } + + // If none of the calculated values changed, no need to update nodes + if !filterChanged && !tagOwnerChanged && !autoApproveChanged && !exitSetChanged && !sshPolicyChanged { log.Trace(). Msg("Policy evaluation detected no changes - all hashes match") return false, nil @@ -156,6 +193,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Bool("tagOwners.changed", tagOwnerChanged). Bool("autoApprovers.changed", autoApproveChanged). Bool("exitNodes.changed", exitSetChanged). + Bool("sshPolicy.changed", sshPolicyChanged). Msg("Policy changes require node updates") return true, nil