diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 2b245b58..ef28a955 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Write(pol) })) debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - filter := h.polMan.Filter() + filter, _ := h.polMan.Filter() filterJSON, err := json.MarshalIndent(filter, "", " ") if err != nil { diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index e49057e7..662e491c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -536,7 +536,7 @@ func appendPeerChanges( changed types.Nodes, cfg *types.Config, ) error { - filter := polMan.Filter() + filter, matchers := polMan.Filter() sshPolicy, err := polMan.SSHPolicy(node) if err != nil { @@ -546,7 +546,7 @@ func appendPeerChanges( // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. if len(filter) > 0 { - changed = policy.FilterNodesByACL(node, changed, filter) + changed = policy.FilterNodesByACL(node, changed, matchers) } profiles := generateUserProfiles(node, changed) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index 2b86416e..1d4f09d2 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -13,6 +13,14 @@ type Match struct { dests *netipx.IPSet } +func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { + matches := make([]Match, 0, len(rules)) + for _, rule := range rules { + matches = append(matches, MatchFromFilterRule(rule)) + } + return matches +} + func MatchFromFilterRule(rule tailcfg.FilterRule) Match { dests := []string{} for _, dest := range rule.DstPorts { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 29b55fc1..0df1bcc4 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -1,6 +1,7 @@ package policy import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" @@ -15,7 +16,8 @@ var ( ) type PolicyManager interface { - Filter() []tailcfg.FilterRule + // Filter returns the current filter rules for the entire tailnet and the associated matchers. + Filter() ([]tailcfg.FilterRule, []matcher.Match) SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index ba375beb..d86de29b 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -1,6 +1,7 @@ package policy import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "slices" @@ -15,7 +16,7 @@ import ( func FilterNodesByACL( node *types.Node, nodes types.Nodes, - filter []tailcfg.FilterRule, + matchers []matcher.Match, ) types.Nodes { var result types.Nodes @@ -24,7 +25,7 @@ func FilterNodesByACL( continue } - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { + if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) { result = append(result, peer) } } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index cfd38765..cebda65f 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -2,6 +2,7 @@ package policy import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "testing" @@ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) { var err error pm, err = pmf(users, append(tt.peers, tt.node)) require.NoError(t, err) - got := pm.Filter() + got, _ := pm.Filter() got = ReduceFilterRules(tt.node, got) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) got := FilterNodesByACL( tt.args.node, tt.args.nodes, - tt.args.rules, + matchers, ) if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) diff --git a/hscontrol/policy/v1/policy.go b/hscontrol/policy/v1/policy.go index 89625ce3..c2e9520a 100644 --- a/hscontrol/policy/v1/policy.go +++ b/hscontrol/policy/v1/policy.go @@ -2,6 +2,7 @@ package v1 import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "io" "net/netip" "os" @@ -88,10 +89,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) Filter() []tailcfg.FilterRule { +func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { pm.mu.Lock() defer pm.mu.Unlock() - return pm.filter + return pm.filter, matcher.MatchesFromFilterRules(pm.filter) } func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { diff --git a/hscontrol/policy/v1/policy_test.go b/hscontrol/policy/v1/policy_test.go index e250db2a..c9f98079 100644 --- a/hscontrol/policy/v1/policy_test.go +++ b/hscontrol/policy/v1/policy_test.go @@ -1,6 +1,7 @@ package v1 import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "testing" "github.com/google/go-cmp/cmp" @@ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) { wantNodesChange bool wantPolicyChange bool wantFilter []tailcfg.FilterRule + wantMatchers []matcher.Match }{ { name: "set-nodes", @@ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) { DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, }, }, + wantMatchers: []matcher.Match{ + matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), + }, }, { name: "set-users", @@ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) { DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, }, }, + wantMatchers: []matcher.Match{ + matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), + }, }, { name: "set-users-and-node", @@ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) { DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, }, }, + wantMatchers: []matcher.Match{ + matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}), + }, }, { name: "set-policy", @@ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) { DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, }, }, + wantMatchers: []matcher.Match{ + matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}), + }, }, } @@ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) { assert.Equal(t, tt.wantNodesChange, change) } - if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { - t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + filter, matchers := pm.Filter() + if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { + t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff) + } + if diff := cmp.Diff( + tt.wantMatchers, + matchers, + cmp.AllowUnexported(matcher.Match{}), + ); diff != "" { + t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff) } }) } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 4060b6a6..2bc04dbc 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -7,6 +7,8 @@ import ( "strings" "sync" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "slices" "github.com/juanfont/headscale/hscontrol/types" @@ -24,6 +26,7 @@ type PolicyManager struct { filterHash deephash.Sum filter []tailcfg.FilterRule + matchers []matcher.Match tagOwnerMapHash deephash.Sum tagOwnerMap map[Tag]*netipx.IPSet @@ -62,15 +65,24 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM // updateLocked updates the filter rules based on the current policy and nodes. // It must be called with the lock held. func (pm *PolicyManager) updateLocked() (bool, error) { + // Clear the SSH policy map to ensure it's recalculated with the new policy. + // TODO(kradalby): This could potentially be optimized by only clearing the + // policies for nodes that have changed. Particularly if the only difference is + // that nodes has been added or removed. + defer clear(pm.sshPolicyMap) + filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes) if err != nil { return false, fmt.Errorf("compiling filter rules: %w", err) } filterHash := deephash.Hash(&filter) - filterChanged := filterHash == pm.filterHash + filterChanged := filterHash != pm.filterHash pm.filter = filter pm.filterHash = filterHash + if filterChanged { + pm.matchers = matcher.MatchesFromFilterRules(pm.filter) + } // Order matters, tags might be used in autoapprovers, so we need to ensure // that the map for tag owners is resolved before resolving autoapprovers. @@ -100,12 +112,6 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return false, nil } - // Clear the SSH policy map to ensure it's recalculated with the new policy. - // TODO(kradalby): This could potentially be optimized by only clearing the - // policies for nodes that have changed. Particularly if the only difference is - // that nodes has been added or removed. - clear(pm.sshPolicyMap) - return true, nil } @@ -144,11 +150,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { return pm.updateLocked() } -// Filter returns the current filter rules for the entire tailnet. -func (pm *PolicyManager) Filter() []tailcfg.FilterRule { +// Filter returns the current filter rules for the entire tailnet and the associated matchers. +func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { pm.mu.Lock() defer pm.mu.Unlock() - return pm.filter + return pm.filter, pm.matchers } // SetUsers updates the users in the policy manager and updates the filter rules. diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index ee26c596..b61c5758 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -1,6 +1,7 @@ package v2 import ( + "github.com/juanfont/headscale/hscontrol/policy/matcher" "testing" "github.com/google/go-cmp/cmp" @@ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) { } tests := []struct { - name string - pol string - nodes types.Nodes - wantFilter []tailcfg.FilterRule + name string + pol string + nodes types.Nodes + wantFilter []tailcfg.FilterRule + wantMatchers []matcher.Match }{ { - name: "empty-policy", - pol: "{}", - nodes: types.Nodes{}, - wantFilter: nil, + name: "empty-policy", + pol: "{}", + nodes: types.Nodes{}, + wantFilter: nil, + wantMatchers: []matcher.Match{}, }, } @@ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) { pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) require.NoError(t, err) - filter := pm.Filter() - if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { - t.Errorf("Filter() mismatch (-want +got):\n%s", diff) + filter, matchers := pm.Filter() + if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { + t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff( + tt.wantMatchers, + matchers, + cmp.AllowUnexported(matcher.Match{}), + ); diff != "" { + t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff) } // TODO(kradalby): Test SSH Policy diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 3567c4f1..826867eb 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -270,18 +270,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) { } } -func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { +func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { src := node.IPs() allowedIPs := node2.IPs() - // TODO(kradalby): Regenerate this every time the filter change, instead of - // every time we use it. - // Part of #2416 - matchers := make([]matcher.Match, len(filter)) - for i, rule := range filter { - matchers[i] = matcher.MatchFromFilterRule(rule) - } - for _, matcher := range matchers { if !matcher.SrcsContainsIPs(src...) { continue diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 702fa251..c7261587 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "testing" @@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.node1.CanAccess(tt.rules, &tt.node2) + matchers := matcher.MatchesFromFilterRules(tt.rules) + got := tt.node1.CanAccess(matchers, &tt.node2) if got != tt.want { t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)