mirror of
https://github.com/juanfont/headscale.git
synced 2025-04-25 01:19:15 +02:00
Merge 7de444f5e5
into 109989005d
This commit is contained in:
commit
5c919d08a9
@ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
w.Write(pol)
|
w.Write(pol)
|
||||||
}))
|
}))
|
||||||
debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
filter := h.polMan.Filter()
|
filter, _ := h.polMan.Filter()
|
||||||
|
|
||||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -536,7 +536,7 @@ func appendPeerChanges(
|
|||||||
changed types.Nodes,
|
changed types.Nodes,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
) error {
|
) error {
|
||||||
filter := polMan.Filter()
|
filter, matchers := polMan.Filter()
|
||||||
|
|
||||||
sshPolicy, err := polMan.SSHPolicy(node)
|
sshPolicy, err := polMan.SSHPolicy(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -546,7 +546,7 @@ func appendPeerChanges(
|
|||||||
// If there are filter rules present, see if there are any nodes that cannot
|
// If there are filter rules present, see if there are any nodes that cannot
|
||||||
// access each-other at all and remove them from the peers.
|
// access each-other at all and remove them from the peers.
|
||||||
if len(filter) > 0 {
|
if len(filter) > 0 {
|
||||||
changed = policy.FilterNodesByACL(node, changed, filter)
|
changed = policy.FilterNodesByACL(node, changed, matchers)
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := generateUserProfiles(node, changed)
|
profiles := generateUserProfiles(node, changed)
|
||||||
|
@ -13,6 +13,14 @@ type Match struct {
|
|||||||
dests *netipx.IPSet
|
dests *netipx.IPSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
|
||||||
|
matches := make([]Match, 0, len(rules))
|
||||||
|
for _, rule := range rules {
|
||||||
|
matches = append(matches, MatchFromFilterRule(rule))
|
||||||
|
}
|
||||||
|
return matches
|
||||||
|
}
|
||||||
|
|
||||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||||
dests := []string{}
|
dests := []string{}
|
||||||
for _, dest := range rule.DstPorts {
|
for _, dest := range rule.DstPorts {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package policy
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||||
@ -15,7 +16,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type PolicyManager interface {
|
type PolicyManager interface {
|
||||||
Filter() []tailcfg.FilterRule
|
// Filter returns the current filter rules for the entire tailnet and the associated matchers.
|
||||||
|
Filter() ([]tailcfg.FilterRule, []matcher.Match)
|
||||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||||
SetPolicy([]byte) (bool, error)
|
SetPolicy([]byte) (bool, error)
|
||||||
SetUsers(users []types.User) (bool, error)
|
SetUsers(users []types.User) (bool, error)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package policy
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@ -15,7 +16,7 @@ import (
|
|||||||
func FilterNodesByACL(
|
func FilterNodesByACL(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
filter []tailcfg.FilterRule,
|
matchers []matcher.Match,
|
||||||
) types.Nodes {
|
) types.Nodes {
|
||||||
var result types.Nodes
|
var result types.Nodes
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ func FilterNodesByACL(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
|
||||||
result = append(result, peer)
|
result = append(result, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package policy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -769,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||||||
var err error
|
var err error
|
||||||
pm, err = pmf(users, append(tt.peers, tt.node))
|
pm, err = pmf(users, append(tt.peers, tt.node))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
got := pm.Filter()
|
got, _ := pm.Filter()
|
||||||
got = ReduceFilterRules(tt.node, got)
|
got = ReduceFilterRules(tt.node, got)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
@ -1425,10 +1426,11 @@ func TestFilterNodesByACL(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||||
got := FilterNodesByACL(
|
got := FilterNodesByACL(
|
||||||
tt.args.node,
|
tt.args.node,
|
||||||
tt.args.nodes,
|
tt.args.nodes,
|
||||||
tt.args.rules,
|
matchers,
|
||||||
)
|
)
|
||||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||||
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
||||||
|
@ -2,6 +2,7 @@ package v1
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@ -86,10 +87,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
return pm.filter
|
return pm.filter, matcher.MatchesFromFilterRules(pm.filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@ -27,6 +28,7 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
wantNodesChange bool
|
wantNodesChange bool
|
||||||
wantPolicyChange bool
|
wantPolicyChange bool
|
||||||
wantFilter []tailcfg.FilterRule
|
wantFilter []tailcfg.FilterRule
|
||||||
|
wantMatchers []matcher.Match
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "set-nodes",
|
name: "set-nodes",
|
||||||
@ -42,6 +44,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
wantMatchers: []matcher.Match{
|
||||||
|
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set-users",
|
name: "set-users",
|
||||||
@ -52,6 +57,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
wantMatchers: []matcher.Match{
|
||||||
|
matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set-users-and-node",
|
name: "set-users-and-node",
|
||||||
@ -70,6 +78,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
wantMatchers: []matcher.Match{
|
||||||
|
matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set-policy",
|
name: "set-policy",
|
||||||
@ -95,6 +106,9 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
wantMatchers: []matcher.Match{
|
||||||
|
matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,8 +164,16 @@ func TestPolicySetChange(t *testing.T) {
|
|||||||
assert.Equal(t, tt.wantNodesChange, change)
|
assert.Equal(t, tt.wantNodesChange, change)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
filter, matchers := pm.Filter()
|
||||||
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||||
|
t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(
|
||||||
|
tt.wantMatchers,
|
||||||
|
matchers,
|
||||||
|
cmp.AllowUnexported(matcher.Match{}),
|
||||||
|
); diff != "" {
|
||||||
|
t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package v2
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -22,6 +23,7 @@ type PolicyManager struct {
|
|||||||
|
|
||||||
filterHash deephash.Sum
|
filterHash deephash.Sum
|
||||||
filter []tailcfg.FilterRule
|
filter []tailcfg.FilterRule
|
||||||
|
matchers []matcher.Match
|
||||||
|
|
||||||
tagOwnerMapHash deephash.Sum
|
tagOwnerMapHash deephash.Sum
|
||||||
tagOwnerMap map[Tag]*netipx.IPSet
|
tagOwnerMap map[Tag]*netipx.IPSet
|
||||||
@ -66,9 +68,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
filterHash := deephash.Hash(&filter)
|
filterHash := deephash.Hash(&filter)
|
||||||
filterChanged := filterHash == pm.filterHash
|
filterChanged := filterHash != pm.filterHash
|
||||||
pm.filter = filter
|
pm.filter = filter
|
||||||
pm.filterHash = filterHash
|
pm.filterHash = filterHash
|
||||||
|
if filterChanged {
|
||||||
|
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||||
|
}
|
||||||
|
|
||||||
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
||||||
// that the map for tag owners is resolved before resolving autoapprovers.
|
// that the map for tag owners is resolved before resolving autoapprovers.
|
||||||
@ -142,11 +147,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
|||||||
return pm.updateLocked()
|
return pm.updateLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter returns the current filter rules for the entire tailnet.
|
// Filter returns the current filter rules for the entire tailnet and the associated matchers.
|
||||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
return pm.filter
|
return pm.filter, pm.matchers
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package v2
|
package v2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@ -29,16 +30,18 @@ func TestPolicyManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
pol string
|
pol string
|
||||||
nodes types.Nodes
|
nodes types.Nodes
|
||||||
wantFilter []tailcfg.FilterRule
|
wantFilter []tailcfg.FilterRule
|
||||||
|
wantMatchers []matcher.Match
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty-policy",
|
name: "empty-policy",
|
||||||
pol: "{}",
|
pol: "{}",
|
||||||
nodes: types.Nodes{},
|
nodes: types.Nodes{},
|
||||||
wantFilter: nil,
|
wantFilter: nil,
|
||||||
|
wantMatchers: []matcher.Match{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,9 +50,16 @@ func TestPolicyManager(t *testing.T) {
|
|||||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
filter := pm.Filter()
|
filter, matchers := pm.Filter()
|
||||||
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||||
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(
|
||||||
|
tt.wantMatchers,
|
||||||
|
matchers,
|
||||||
|
cmp.AllowUnexported(matcher.Match{}),
|
||||||
|
); diff != "" {
|
||||||
|
t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Test SSH Policy
|
// TODO(kradalby): Test SSH Policy
|
||||||
|
@ -258,18 +258,10 @@ func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
|
func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool {
|
||||||
src := node.IPs()
|
src := node.IPs()
|
||||||
allowedIPs := node2.IPs()
|
allowedIPs := node2.IPs()
|
||||||
|
|
||||||
// TODO(kradalby): Regenerate this every time the filter change, instead of
|
|
||||||
// every time we use it.
|
|
||||||
// Part of #2416
|
|
||||||
matchers := make([]matcher.Match, len(filter))
|
|
||||||
for i, rule := range filter {
|
|
||||||
matchers[i] = matcher.MatchFromFilterRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, matcher := range matchers {
|
for _, matcher := range matchers {
|
||||||
if !matcher.SrcsContainsIPs(src...) {
|
if !matcher.SrcsContainsIPs(src...) {
|
||||||
continue
|
continue
|
||||||
|
@ -2,6 +2,7 @@ package types
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -116,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := tt.node1.CanAccess(tt.rules, &tt.node2)
|
matchers := matcher.MatchesFromFilterRules(tt.rules)
|
||||||
|
got := tt.node1.CanAccess(matchers, &tt.node2)
|
||||||
|
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)
|
t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got)
|
||||||
|
Loading…
Reference in New Issue
Block a user