1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-04-25 01:19:15 +02:00
This commit is contained in:
aergus-tng 2025-04-11 12:39:22 +02:00 committed by GitHub
commit 5c919d08a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 82 additions and 37 deletions

View File

@ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(pol) w.Write(pol)
})) }))
debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
filter := h.polMan.Filter() filter, _ := h.polMan.Filter()
filterJSON, err := json.MarshalIndent(filter, "", " ") filterJSON, err := json.MarshalIndent(filter, "", " ")
if err != nil { if err != nil {

View File

@ -536,7 +536,7 @@ func appendPeerChanges(
changed types.Nodes, changed types.Nodes,
cfg *types.Config, cfg *types.Config,
) error { ) error {
filter := polMan.Filter() filter, matchers := polMan.Filter()
sshPolicy, err := polMan.SSHPolicy(node) sshPolicy, err := polMan.SSHPolicy(node)
if err != nil { if err != nil {
@ -546,7 +546,7 @@ func appendPeerChanges(
// If there are filter rules present, see if there are any nodes that cannot // If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers. // access each-other at all and remove them from the peers.
if len(filter) > 0 { if len(filter) > 0 {
changed = policy.FilterNodesByACL(node, changed, filter) changed = policy.FilterNodesByACL(node, changed, matchers)
} }
profiles := generateUserProfiles(node, changed) profiles := generateUserProfiles(node, changed)

View File

@ -13,6 +13,14 @@ type Match struct {
dests *netipx.IPSet dests *netipx.IPSet
} }
func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
matches := make([]Match, 0, len(rules))
for _, rule := range rules {
matches = append(matches, MatchFromFilterRule(rule))
}
return matches
}
func MatchFromFilterRule(rule tailcfg.FilterRule) Match { func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
dests := []string{} dests := []string{}
for _, dest := range rule.DstPorts { for _, dest := range rule.DstPorts {

View File

@ -1,6 +1,7 @@
package policy package policy
import ( import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
@ -15,7 +16,8 @@ var (
) )
type PolicyManager interface { type PolicyManager interface {
Filter() []tailcfg.FilterRule // Filter returns the current filter rules for the entire tailnet and the associated matchers.
Filter() ([]tailcfg.FilterRule, []matcher.Match)
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error) SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error) SetUsers(users []types.User) (bool, error)

View File

@ -1,6 +1,7 @@
package policy package policy
import ( import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
"slices" "slices"
@ -15,7 +16,7 @@ import (
func FilterNodesByACL( func FilterNodesByACL(
node *types.Node, node *types.Node,
nodes types.Nodes, nodes types.Nodes,
filter []tailcfg.FilterRule, matchers []matcher.Match,
) types.Nodes { ) types.Nodes {
var result types.Nodes var result types.Nodes
@ -24,7 +25,7 @@ func FilterNodesByACL(
continue continue
} }
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
result = append(result, peer) result = append(result, peer)
} }
} }

View File

@ -2,6 +2,7 @@ package policy
import ( import (
"fmt" "fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
"testing" "testing"
@ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) {
var err error var err error
pm, err = pmf(users, append(tt.peers, tt.node)) pm, err = pmf(users, append(tt.peers, tt.node))
require.NoError(t, err) require.NoError(t, err)
got := pm.Filter() got, _ := pm.Filter()
got = ReduceFilterRules(tt.node, got) got = ReduceFilterRules(tt.node, got)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
got := FilterNodesByACL( got := FilterNodesByACL(
tt.args.node, tt.args.node,
tt.args.nodes, tt.args.nodes,
tt.args.rules, matchers,
) )
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)

View File

@ -2,6 +2,7 @@ package v1
import ( import (
"fmt" "fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"io" "io"
"net/netip" "net/netip"
"os" "os"
@ -86,10 +87,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil return true, nil
} }
func (pm *PolicyManager) Filter() []tailcfg.FilterRule { func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
return pm.filter return pm.filter, matcher.MatchesFromFilterRules(pm.filter)
} }
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {

View File

@ -1,6 +1,7 @@
package v1 package v1
import ( import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) {
wantNodesChange bool wantNodesChange bool
wantPolicyChange bool wantPolicyChange bool
wantFilter []tailcfg.FilterRule wantFilter []tailcfg.FilterRule
wantMatchers []matcher.Match
}{ }{
{ {
name: "set-nodes", name: "set-nodes",
@ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
}, },
}, },
wantMatchers: []matcher.Match{
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
},
}, },
{ {
name: "set-users", name: "set-users",
@ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
}, },
}, },
wantMatchers: []matcher.Match{
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
},
}, },
{ {
name: "set-users-and-node", name: "set-users-and-node",
@ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
}, },
}, },
wantMatchers: []matcher.Match{
matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}),
},
}, },
{ {
name: "set-policy", name: "set-policy",
@ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
}, },
}, },
wantMatchers: []matcher.Match{
matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}),
},
}, },
} }
@ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) {
assert.Equal(t, tt.wantNodesChange, change) assert.Equal(t, tt.wantNodesChange, change)
} }
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { filter, matchers := pm.Filter()
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff)
}
if diff := cmp.Diff(
tt.wantMatchers,
matchers,
cmp.AllowUnexported(matcher.Match{}),
); diff != "" {
t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff)
} }
}) })
} }

View File

@ -3,6 +3,7 @@ package v2
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
@ -22,6 +23,7 @@ type PolicyManager struct {
filterHash deephash.Sum filterHash deephash.Sum
filter []tailcfg.FilterRule filter []tailcfg.FilterRule
matchers []matcher.Match
tagOwnerMapHash deephash.Sum tagOwnerMapHash deephash.Sum
tagOwnerMap map[Tag]*netipx.IPSet tagOwnerMap map[Tag]*netipx.IPSet
@ -66,9 +68,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
} }
filterHash := deephash.Hash(&filter) filterHash := deephash.Hash(&filter)
filterChanged := filterHash == pm.filterHash filterChanged := filterHash != pm.filterHash
pm.filter = filter pm.filter = filter
pm.filterHash = filterHash pm.filterHash = filterHash
if filterChanged {
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
}
// Order matters, tags might be used in autoapprovers, so we need to ensure // Order matters, tags might be used in autoapprovers, so we need to ensure
// that the map for tag owners is resolved before resolving autoapprovers. // that the map for tag owners is resolved before resolving autoapprovers.
@ -142,11 +147,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
return pm.updateLocked() return pm.updateLocked()
} }
// Filter returns the current filter rules for the entire tailnet. // Filter returns the current filter rules for the entire tailnet and the associated matchers.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule { func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
return pm.filter return pm.filter, pm.matchers
} }
// SetUsers updates the users in the policy manager and updates the filter rules. // SetUsers updates the users in the policy manager and updates the filter rules.

View File

@ -1,6 +1,7 @@
package v2 package v2
import ( import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
pol string pol string
nodes types.Nodes nodes types.Nodes
wantFilter []tailcfg.FilterRule wantFilter []tailcfg.FilterRule
wantMatchers []matcher.Match
}{ }{
{ {
name: "empty-policy", name: "empty-policy",
pol: "{}", pol: "{}",
nodes: types.Nodes{}, nodes: types.Nodes{},
wantFilter: nil, wantFilter: nil,
wantMatchers: []matcher.Match{},
}, },
} }
@ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err) require.NoError(t, err)
filter := pm.Filter() filter, matchers := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff) t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(
tt.wantMatchers,
matchers,
cmp.AllowUnexported(matcher.Match{}),
); diff != "" {
t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff)
} }
// TODO(kradalby): Test SSH Policy // TODO(kradalby): Test SSH Policy

View File

@ -258,18 +258,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
} }
} }
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool {
src := node.IPs() src := node.IPs()
allowedIPs := node2.IPs() allowedIPs := node2.IPs()
// TODO(kradalby): Regenerate this every time the filter change, instead of
// every time we use it.
// Part of #2416
matchers := make([]matcher.Match, len(filter))
for i, rule := range filter {
matchers[i] = matcher.MatchFromFilterRule(rule)
}
for _, matcher := range matchers { for _, matcher := range matchers {
if !matcher.SrcsContainsIPs(src...) { if !matcher.SrcsContainsIPs(src...) {
continue continue

View File

@ -2,6 +2,7 @@ package types
import ( import (
"fmt" "fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := tt.node1.CanAccess(tt.rules, &tt.node2) matchers := matcher.MatchesFromFilterRules(tt.rules)
got := tt.node1.CanAccess(matchers, &tt.node2)
if got != tt.want { if got != tt.want {
t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got) t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)