diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 7a297bd3..4da94c31 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -529,7 +529,8 @@ func appendPeerChanges( // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. if len(filter) > 0 { - changed = policy.FilterNodesByACL(node, changed, filter) + matchers := polMan.Matchers() + changed = policy.FilterNodesByACL(node, changed, matchers) } profiles := generateUserProfiles(node, changed) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index 2b86416e..1d4f09d2 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -13,6 +13,14 @@ type Match struct { dests *netipx.IPSet } +func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { + matches := make([]Match, 0, len(rules)) + for _, rule := range rules { + matches = append(matches, MatchFromFilterRule(rule)) + } + return matches +} + func MatchFromFilterRule(rule tailcfg.FilterRule) Match { dests := []string{} for _, dest := range rule.DstPorts { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 24f68ca1..5df7da76 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -1,6 +1,7 @@ package policy import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" @@ -16,6 +17,8 @@ var ( type PolicyManager interface { Filter() []tailcfg.FilterRule + // Matchers returns the matchers for the current filter rules. + Matchers() []matcher.Match SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index ba375beb..d86de29b 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -1,6 +1,7 @@ package policy import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "slices" @@ -15,7 +16,7 @@ import ( func FilterNodesByACL( node *types.Node, nodes types.Nodes, - filter []tailcfg.FilterRule, + matchers []matcher.Match, ) types.Nodes { var result types.Nodes @@ -24,7 +25,7 @@ func FilterNodesByACL( continue } - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { + if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) { result = append(result, peer) } } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index cfd38765..597172fb 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -2,6 +2,7 @@ package policy import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "testing" @@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) got := FilterNodesByACL( tt.args.node, tt.args.nodes, - tt.args.rules, + matchers, ) if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) diff --git a/hscontrol/policy/v1/policy.go b/hscontrol/policy/v1/policy.go index 0ac49d04..43efba5d 100644 --- a/hscontrol/policy/v1/policy.go +++ b/hscontrol/policy/v1/policy.go @@ -2,6 +2,7 @@ package v1 import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "io" "net/netip" "os" @@ -92,6 +93,11 @@ func (pm *PolicyManager) Filter() []tailcfg.FilterRule { return pm.filter } +func (pm *PolicyManager) Matchers() []matcher.Match { + filter := pm.Filter() + return matcher.MatchesFromFilterRules(filter) +} + func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 41f51487..8fbedd06 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -3,6 +3,7 @@ package v2 import ( "encoding/json" "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "sync" @@ -22,6 +23,7 @@ type PolicyManager struct { filterHash deephash.Sum filter []tailcfg.FilterRule + matchers []matcher.Match tagOwnerMapHash deephash.Sum tagOwnerMap map[Tag]*netipx.IPSet @@ -69,6 +71,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) { filterChanged := filterHash == pm.filterHash pm.filter = filter pm.filterHash = filterHash + if filterChanged { + pm.matchers = matcher.MatchesFromFilterRules(pm.filter) + } // Order matters, tags might be used in autoapprovers, so we need to ensure // that the map for tag owners is resolved before resolving autoapprovers. @@ -149,6 +154,12 @@ func (pm *PolicyManager) Filter() []tailcfg.FilterRule { return pm.filter } +func (pm *PolicyManager) Matchers() []matcher.Match { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.matchers +} + // SetUsers updates the users in the policy manager and updates the filter rules. func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index c333a148..ebbfdb8b 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -258,18 +258,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) { } } -func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { +func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { src := node.IPs() allowedIPs := node2.IPs() - // TODO(kradalby): Regenerate this every time the filter change, instead of - // every time we use it. - // Part of #2416 - matchers := make([]matcher.Match, len(filter)) - for i, rule := range filter { - matchers[i] = matcher.MatchFromFilterRule(rule) - } - for _, matcher := range matchers { if !matcher.SrcsContainsIPs(src...) { continue diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index d439d483..9b7c23a9 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "testing" @@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.node1.CanAccess(tt.rules, &tt.node2) + matchers := matcher.MatchesFromFilterRules(tt.rules) + got := tt.node1.CanAccess(matchers, &tt.node2) if got != tt.want { t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)