mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Make matchers part of the Policy interface (#2514)
* Make matchers part of the Policy interface * Prevent race condition between rules and matchers * Test also matchers in tests for Policy.Filter * Compute `filterChanged` in v2 policy correctly * Fix nil vs. empty list issue in v2 policy test * policy/v2: always clear ssh map Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> Co-authored-by: Aras Ergus <aras.ergus@tngtech.com> Co-authored-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									eb1ecefd9e
								
							
						
					
					
						commit
						4651d06fa8
					
				@ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 | 
				
			|||||||
		w.Write(pol)
 | 
							w.Write(pol)
 | 
				
			||||||
	}))
 | 
						}))
 | 
				
			||||||
	debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
		filter := h.polMan.Filter()
 | 
							filter, _ := h.polMan.Filter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		filterJSON, err := json.MarshalIndent(filter, "", "  ")
 | 
							filterJSON, err := json.MarshalIndent(filter, "", "  ")
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -536,7 +536,7 @@ func appendPeerChanges(
 | 
				
			|||||||
	changed types.Nodes,
 | 
						changed types.Nodes,
 | 
				
			||||||
	cfg *types.Config,
 | 
						cfg *types.Config,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
	filter := polMan.Filter()
 | 
						filter, matchers := polMan.Filter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sshPolicy, err := polMan.SSHPolicy(node)
 | 
						sshPolicy, err := polMan.SSHPolicy(node)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -546,7 +546,7 @@ func appendPeerChanges(
 | 
				
			|||||||
	// If there are filter rules present, see if there are any nodes that cannot
 | 
						// If there are filter rules present, see if there are any nodes that cannot
 | 
				
			||||||
	// access each-other at all and remove them from the peers.
 | 
						// access each-other at all and remove them from the peers.
 | 
				
			||||||
	if len(filter) > 0 {
 | 
						if len(filter) > 0 {
 | 
				
			||||||
		changed = policy.FilterNodesByACL(node, changed, filter)
 | 
							changed = policy.FilterNodesByACL(node, changed, matchers)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	profiles := generateUserProfiles(node, changed)
 | 
						profiles := generateUserProfiles(node, changed)
 | 
				
			||||||
 | 
				
			|||||||
@ -13,6 +13,14 @@ type Match struct {
 | 
				
			|||||||
	dests *netipx.IPSet
 | 
						dests *netipx.IPSet
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
 | 
				
			||||||
 | 
						matches := make([]Match, 0, len(rules))
 | 
				
			||||||
 | 
						for _, rule := range rules {
 | 
				
			||||||
 | 
							matches = append(matches, MatchFromFilterRule(rule))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return matches
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
 | 
					func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
 | 
				
			||||||
	dests := []string{}
 | 
						dests := []string{}
 | 
				
			||||||
	for _, dest := range rule.DstPorts {
 | 
						for _, dest := range rule.DstPorts {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package policy
 | 
					package policy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 | 
						policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 | 
				
			||||||
@ -15,7 +16,8 @@ var (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PolicyManager interface {
 | 
					type PolicyManager interface {
 | 
				
			||||||
	Filter() []tailcfg.FilterRule
 | 
						// Filter returns the current filter rules for the entire tailnet and the associated matchers.
 | 
				
			||||||
 | 
						Filter() ([]tailcfg.FilterRule, []matcher.Match)
 | 
				
			||||||
	SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
 | 
						SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
 | 
				
			||||||
	SetPolicy([]byte) (bool, error)
 | 
						SetPolicy([]byte) (bool, error)
 | 
				
			||||||
	SetUsers(users []types.User) (bool, error)
 | 
						SetUsers(users []types.User) (bool, error)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package policy
 | 
					package policy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"slices"
 | 
						"slices"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,7 +16,7 @@ import (
 | 
				
			|||||||
func FilterNodesByACL(
 | 
					func FilterNodesByACL(
 | 
				
			||||||
	node *types.Node,
 | 
						node *types.Node,
 | 
				
			||||||
	nodes types.Nodes,
 | 
						nodes types.Nodes,
 | 
				
			||||||
	filter []tailcfg.FilterRule,
 | 
						matchers []matcher.Match,
 | 
				
			||||||
) types.Nodes {
 | 
					) types.Nodes {
 | 
				
			||||||
	var result types.Nodes
 | 
						var result types.Nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -24,7 +25,7 @@ func FilterNodesByACL(
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
 | 
							if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
 | 
				
			||||||
			result = append(result, peer)
 | 
								result = append(result, peer)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package policy
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) {
 | 
				
			|||||||
				var err error
 | 
									var err error
 | 
				
			||||||
				pm, err = pmf(users, append(tt.peers, tt.node))
 | 
									pm, err = pmf(users, append(tt.peers, tt.node))
 | 
				
			||||||
				require.NoError(t, err)
 | 
									require.NoError(t, err)
 | 
				
			||||||
				got := pm.Filter()
 | 
									got, _ := pm.Filter()
 | 
				
			||||||
				got = ReduceFilterRules(tt.node, got)
 | 
									got = ReduceFilterRules(tt.node, got)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if diff := cmp.Diff(tt.want, got); diff != "" {
 | 
									if diff := cmp.Diff(tt.want, got); diff != "" {
 | 
				
			||||||
@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								matchers := matcher.MatchesFromFilterRules(tt.args.rules)
 | 
				
			||||||
			got := FilterNodesByACL(
 | 
								got := FilterNodesByACL(
 | 
				
			||||||
				tt.args.node,
 | 
									tt.args.node,
 | 
				
			||||||
				tt.args.nodes,
 | 
									tt.args.nodes,
 | 
				
			||||||
				tt.args.rules,
 | 
									matchers,
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 | 
								if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 | 
				
			||||||
				t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
 | 
									t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package v1
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
@ -88,10 +89,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
 | 
				
			|||||||
	return true, nil
 | 
						return true, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
 | 
					func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
 | 
				
			||||||
	pm.mu.Lock()
 | 
						pm.mu.Lock()
 | 
				
			||||||
	defer pm.mu.Unlock()
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
	return pm.filter
 | 
						return pm.filter, matcher.MatchesFromFilterRules(pm.filter)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
 | 
					func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package v1
 | 
					package v1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/google/go-cmp/cmp"
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
@ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
		wantNodesChange  bool
 | 
							wantNodesChange  bool
 | 
				
			||||||
		wantPolicyChange bool
 | 
							wantPolicyChange bool
 | 
				
			||||||
		wantFilter       []tailcfg.FilterRule
 | 
							wantFilter       []tailcfg.FilterRule
 | 
				
			||||||
 | 
							wantMatchers     []matcher.Match
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "set-nodes",
 | 
								name: "set-nodes",
 | 
				
			||||||
@ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
								wantMatchers: []matcher.Match{
 | 
				
			||||||
 | 
									matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:            "set-users",
 | 
								name:            "set-users",
 | 
				
			||||||
@ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
								wantMatchers: []matcher.Match{
 | 
				
			||||||
 | 
									matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:  "set-users-and-node",
 | 
								name:  "set-users-and-node",
 | 
				
			||||||
@ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
								wantMatchers: []matcher.Match{
 | 
				
			||||||
 | 
									matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "set-policy",
 | 
								name: "set-policy",
 | 
				
			||||||
@ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
 | 
										DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
								wantMatchers: []matcher.Match{
 | 
				
			||||||
 | 
									matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) {
 | 
				
			|||||||
				assert.Equal(t, tt.wantNodesChange, change)
 | 
									assert.Equal(t, tt.wantNodesChange, change)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
 | 
								filter, matchers := pm.Filter()
 | 
				
			||||||
				t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
 | 
								if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
 | 
				
			||||||
 | 
									t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if diff := cmp.Diff(
 | 
				
			||||||
 | 
									tt.wantMatchers,
 | 
				
			||||||
 | 
									matchers,
 | 
				
			||||||
 | 
									cmp.AllowUnexported(matcher.Match{}),
 | 
				
			||||||
 | 
								); diff != "" {
 | 
				
			||||||
 | 
									t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,8 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"slices"
 | 
						"slices"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						"github.com/juanfont/headscale/hscontrol/types"
 | 
				
			||||||
@ -24,6 +26,7 @@ type PolicyManager struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	filterHash deephash.Sum
 | 
						filterHash deephash.Sum
 | 
				
			||||||
	filter     []tailcfg.FilterRule
 | 
						filter     []tailcfg.FilterRule
 | 
				
			||||||
 | 
						matchers   []matcher.Match
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tagOwnerMapHash deephash.Sum
 | 
						tagOwnerMapHash deephash.Sum
 | 
				
			||||||
	tagOwnerMap     map[Tag]*netipx.IPSet
 | 
						tagOwnerMap     map[Tag]*netipx.IPSet
 | 
				
			||||||
@ -62,15 +65,24 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM
 | 
				
			|||||||
// updateLocked updates the filter rules based on the current policy and nodes.
 | 
					// updateLocked updates the filter rules based on the current policy and nodes.
 | 
				
			||||||
// It must be called with the lock held.
 | 
					// It must be called with the lock held.
 | 
				
			||||||
func (pm *PolicyManager) updateLocked() (bool, error) {
 | 
					func (pm *PolicyManager) updateLocked() (bool, error) {
 | 
				
			||||||
 | 
						// Clear the SSH policy map to ensure it's recalculated with the new policy.
 | 
				
			||||||
 | 
						// TODO(kradalby): This could potentially be optimized by only clearing the
 | 
				
			||||||
 | 
						// policies for nodes that have changed. Particularly if the only difference is
 | 
				
			||||||
 | 
						// that nodes has been added or removed.
 | 
				
			||||||
 | 
						defer clear(pm.sshPolicyMap)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
 | 
						filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return false, fmt.Errorf("compiling filter rules: %w", err)
 | 
							return false, fmt.Errorf("compiling filter rules: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	filterHash := deephash.Hash(&filter)
 | 
						filterHash := deephash.Hash(&filter)
 | 
				
			||||||
	filterChanged := filterHash == pm.filterHash
 | 
						filterChanged := filterHash != pm.filterHash
 | 
				
			||||||
	pm.filter = filter
 | 
						pm.filter = filter
 | 
				
			||||||
	pm.filterHash = filterHash
 | 
						pm.filterHash = filterHash
 | 
				
			||||||
 | 
						if filterChanged {
 | 
				
			||||||
 | 
							pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Order matters, tags might be used in autoapprovers, so we need to ensure
 | 
						// Order matters, tags might be used in autoapprovers, so we need to ensure
 | 
				
			||||||
	// that the map for tag owners is resolved before resolving autoapprovers.
 | 
						// that the map for tag owners is resolved before resolving autoapprovers.
 | 
				
			||||||
@ -100,12 +112,6 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
 | 
				
			|||||||
		return false, nil
 | 
							return false, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Clear the SSH policy map to ensure it's recalculated with the new policy.
 | 
					 | 
				
			||||||
	// TODO(kradalby): This could potentially be optimized by only clearing the
 | 
					 | 
				
			||||||
	// policies for nodes that have changed. Particularly if the only difference is
 | 
					 | 
				
			||||||
	// that nodes has been added or removed.
 | 
					 | 
				
			||||||
	clear(pm.sshPolicyMap)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return true, nil
 | 
						return true, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -144,11 +150,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
 | 
				
			|||||||
	return pm.updateLocked()
 | 
						return pm.updateLocked()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Filter returns the current filter rules for the entire tailnet.
 | 
					// Filter returns the current filter rules for the entire tailnet and the associated matchers.
 | 
				
			||||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
 | 
					func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
 | 
				
			||||||
	pm.mu.Lock()
 | 
						pm.mu.Lock()
 | 
				
			||||||
	defer pm.mu.Unlock()
 | 
						defer pm.mu.Unlock()
 | 
				
			||||||
	return pm.filter
 | 
						return pm.filter, pm.matchers
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetUsers updates the users in the policy manager and updates the filter rules.
 | 
					// SetUsers updates the users in the policy manager and updates the filter rules.
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package v2
 | 
					package v2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/google/go-cmp/cmp"
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
@ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tests := []struct {
 | 
						tests := []struct {
 | 
				
			||||||
		name       string
 | 
							name         string
 | 
				
			||||||
		pol        string
 | 
							pol          string
 | 
				
			||||||
		nodes      types.Nodes
 | 
							nodes        types.Nodes
 | 
				
			||||||
		wantFilter []tailcfg.FilterRule
 | 
							wantFilter   []tailcfg.FilterRule
 | 
				
			||||||
 | 
							wantMatchers []matcher.Match
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:       "empty-policy",
 | 
								name:         "empty-policy",
 | 
				
			||||||
			pol:        "{}",
 | 
								pol:          "{}",
 | 
				
			||||||
			nodes:      types.Nodes{},
 | 
								nodes:        types.Nodes{},
 | 
				
			||||||
			wantFilter: nil,
 | 
								wantFilter:   nil,
 | 
				
			||||||
 | 
								wantMatchers: []matcher.Match{},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) {
 | 
				
			|||||||
			pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
 | 
								pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
 | 
				
			||||||
			require.NoError(t, err)
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			filter := pm.Filter()
 | 
								filter, matchers := pm.Filter()
 | 
				
			||||||
			if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
 | 
								if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
 | 
				
			||||||
				t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
 | 
									t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if diff := cmp.Diff(
 | 
				
			||||||
 | 
									tt.wantMatchers,
 | 
				
			||||||
 | 
									matchers,
 | 
				
			||||||
 | 
									cmp.AllowUnexported(matcher.Match{}),
 | 
				
			||||||
 | 
								); diff != "" {
 | 
				
			||||||
 | 
									t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// TODO(kradalby): Test SSH Policy
 | 
								// TODO(kradalby): Test SSH Policy
 | 
				
			||||||
 | 
				
			|||||||
@ -270,18 +270,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
 | 
					func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool {
 | 
				
			||||||
	src := node.IPs()
 | 
						src := node.IPs()
 | 
				
			||||||
	allowedIPs := node2.IPs()
 | 
						allowedIPs := node2.IPs()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): Regenerate this every time the filter change, instead of
 | 
					 | 
				
			||||||
	// every time we use it.
 | 
					 | 
				
			||||||
	// Part of #2416
 | 
					 | 
				
			||||||
	matchers := make([]matcher.Match, len(filter))
 | 
					 | 
				
			||||||
	for i, rule := range filter {
 | 
					 | 
				
			||||||
		matchers[i] = matcher.MatchFromFilterRule(rule)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, matcher := range matchers {
 | 
						for _, matcher := range matchers {
 | 
				
			||||||
		if !matcher.SrcsContainsIPs(src...) {
 | 
							if !matcher.SrcsContainsIPs(src...) {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package types
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
				
			||||||
	"net/netip"
 | 
						"net/netip"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			got := tt.node1.CanAccess(tt.rules, &tt.node2)
 | 
								matchers := matcher.MatchesFromFilterRules(tt.rules)
 | 
				
			||||||
 | 
								got := tt.node1.CanAccess(matchers, &tt.node2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if got != tt.want {
 | 
								if got != tt.want {
 | 
				
			||||||
				t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)
 | 
									t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user