mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Make matchers part of the Policy interface
This commit is contained in:
		
							parent
							
								
									5a18e91317
								
							
						
					
					
						commit
						42f71b9c06
					
				@ -529,7 +529,8 @@ func appendPeerChanges(
 | 
			
		||||
	// If there are filter rules present, see if there are any nodes that cannot
 | 
			
		||||
	// access each-other at all and remove them from the peers.
 | 
			
		||||
	if len(filter) > 0 {
 | 
			
		||||
		changed = policy.FilterNodesByACL(node, changed, filter)
 | 
			
		||||
		matchers := polMan.Matchers()
 | 
			
		||||
		changed = policy.FilterNodesByACL(node, changed, matchers)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	profiles := generateUserProfiles(node, changed)
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,14 @@ type Match struct {
 | 
			
		||||
	dests *netipx.IPSet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
 | 
			
		||||
	matches := make([]Match, 0, len(rules))
 | 
			
		||||
	for _, rule := range rules {
 | 
			
		||||
		matches = append(matches, MatchFromFilterRule(rule))
 | 
			
		||||
	}
 | 
			
		||||
	return matches
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
 | 
			
		||||
	dests := []string{}
 | 
			
		||||
	for _, dest := range rule.DstPorts {
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
 | 
			
		||||
	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 | 
			
		||||
@ -16,6 +17,8 @@ var (
 | 
			
		||||
 | 
			
		||||
type PolicyManager interface {
 | 
			
		||||
	Filter() []tailcfg.FilterRule
 | 
			
		||||
	// Matchers returns the matchers for the current filter rules.
 | 
			
		||||
	Matchers() []matcher.Match
 | 
			
		||||
	SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
 | 
			
		||||
	SetPolicy([]byte) (bool, error)
 | 
			
		||||
	SetUsers(users []types.User) (bool, error)
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"slices"
 | 
			
		||||
 | 
			
		||||
@ -15,7 +16,7 @@ import (
 | 
			
		||||
func FilterNodesByACL(
 | 
			
		||||
	node *types.Node,
 | 
			
		||||
	nodes types.Nodes,
 | 
			
		||||
	filter []tailcfg.FilterRule,
 | 
			
		||||
	matchers []matcher.Match,
 | 
			
		||||
) types.Nodes {
 | 
			
		||||
	var result types.Nodes
 | 
			
		||||
 | 
			
		||||
@ -24,7 +25,7 @@ func FilterNodesByACL(
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
 | 
			
		||||
		if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
 | 
			
		||||
			result = append(result, peer)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			matchers := matcher.MatchesFromFilterRules(tt.args.rules)
 | 
			
		||||
			got := FilterNodesByACL(
 | 
			
		||||
				tt.args.node,
 | 
			
		||||
				tt.args.nodes,
 | 
			
		||||
				tt.args.rules,
 | 
			
		||||
				matchers,
 | 
			
		||||
			)
 | 
			
		||||
			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 | 
			
		||||
				t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package v1
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"os"
 | 
			
		||||
@ -92,6 +93,11 @@ func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
 | 
			
		||||
	return pm.filter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *PolicyManager) Matchers() []matcher.Match {
 | 
			
		||||
	filter := pm.Filter()
 | 
			
		||||
	return matcher.MatchesFromFilterRules(filter)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
 | 
			
		||||
	pm.mu.Lock()
 | 
			
		||||
	defer pm.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ package v2
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
@ -22,6 +23,7 @@ type PolicyManager struct {
 | 
			
		||||
 | 
			
		||||
	filterHash deephash.Sum
 | 
			
		||||
	filter     []tailcfg.FilterRule
 | 
			
		||||
	matchers   []matcher.Match
 | 
			
		||||
 | 
			
		||||
	tagOwnerMapHash deephash.Sum
 | 
			
		||||
	tagOwnerMap     map[Tag]*netipx.IPSet
 | 
			
		||||
@ -69,6 +71,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
 | 
			
		||||
	filterChanged := filterHash == pm.filterHash
 | 
			
		||||
	pm.filter = filter
 | 
			
		||||
	pm.filterHash = filterHash
 | 
			
		||||
	if filterChanged {
 | 
			
		||||
		pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Order matters, tags might be used in autoapprovers, so we need to ensure
 | 
			
		||||
	// that the map for tag owners is resolved before resolving autoapprovers.
 | 
			
		||||
@ -149,6 +154,12 @@ func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
 | 
			
		||||
	return pm.filter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *PolicyManager) Matchers() []matcher.Match {
 | 
			
		||||
	pm.mu.Lock()
 | 
			
		||||
	defer pm.mu.Unlock()
 | 
			
		||||
	return pm.matchers
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetUsers updates the users in the policy manager and updates the filter rules.
 | 
			
		||||
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
 | 
			
		||||
	pm.mu.Lock()
 | 
			
		||||
 | 
			
		||||
@ -258,18 +258,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
 | 
			
		||||
func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool {
 | 
			
		||||
	src := node.IPs()
 | 
			
		||||
	allowedIPs := node2.IPs()
 | 
			
		||||
 | 
			
		||||
	// TODO(kradalby): Regenerate this every time the filter change, instead of
 | 
			
		||||
	// every time we use it.
 | 
			
		||||
	// Part of #2416
 | 
			
		||||
	matchers := make([]matcher.Match, len(filter))
 | 
			
		||||
	for i, rule := range filter {
 | 
			
		||||
		matchers[i] = matcher.MatchFromFilterRule(rule)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, matcher := range matchers {
 | 
			
		||||
		if !matcher.SrcsContainsIPs(src...) {
 | 
			
		||||
			continue
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			got := tt.node1.CanAccess(tt.rules, &tt.node2)
 | 
			
		||||
			matchers := matcher.MatchesFromFilterRules(tt.rules)
 | 
			
		||||
			got := tt.node1.CanAccess(matchers, &tt.node2)
 | 
			
		||||
 | 
			
		||||
			if got != tt.want {
 | 
			
		||||
				t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user