diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 662e491c..54c80bac 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -546,7 +546,7 @@ func appendPeerChanges( // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. if len(filter) > 0 { - changed = policy.FilterNodesByACL(node, changed, matchers) + changed = policy.ReduceNodes(node, changed, matchers) } profiles := generateUserProfiles(node, changed) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 5d718b54..dfce60bb 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -348,6 +348,11 @@ func Test_fullMapResponse(t *testing.T) { "src": ["100.64.0.2"], "dst": ["user1@:*"], }, + { + "action": "accept", + "src": ["100.64.0.1"], + "dst": ["192.168.0.0/24:*"], + }, ], } `), @@ -380,6 +385,10 @@ func Test_fullMapResponse(t *testing.T) { {IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}, }, }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "192.168.0.0/24", Ports: tailcfg.PortRangeAny}}, + }, }, }, SSHPolicy: nil, diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 32905345..c9365f1a 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -81,7 +81,9 @@ func tailNode( } tags = lo.Uniq(append(tags, node.ForcedTags...)) - allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...) + _, matchers := polMan.Filter() + routes := policy.ReduceRoutes(node, primary.PrimaryRoutes(node.ID), matchers) + allowed := append(node.Prefixes(), routes...) allowed = append(allowed, node.ExitRoutes()...) tsaddr.SortPrefixes(allowed) diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 1c3c018f..30c6d4a8 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -269,10 +269,13 @@ func TestNodeExpiry(t *testing.T) { GivenName: "test", Expiry: tt.exp, } + polMan, err := policy.NewPolicyManager(nil, nil, nil) + require.NoError(t, err) + tn, err := tailNode( node, 0, - nil, // TODO(kradalby): removed in merge but error? + polMan, nil, &types.Config{}, ) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 0df1bcc4..b90d2efc 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -1,9 +1,10 @@ package policy import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" @@ -33,7 +34,7 @@ type PolicyManager interface { } // NewPolicyManager returns a new policy manager, the version is determined by -// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2". +// the environment flag "HEADSCALE_POLICY_V1". func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { var polMan PolicyManager var err error diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index d86de29b..5859a198 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -1,10 +1,11 @@ package policy import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "slices" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/samber/lo" @@ -12,8 +13,8 @@ import ( "tailscale.com/tailcfg" ) -// FilterNodesByACL returns the list of peers authorized to be accessed from a given node. -func FilterNodesByACL( +// ReduceNodes returns the list of peers authorized to be accessed from a given node. +func ReduceNodes( node *types.Node, nodes types.Nodes, matchers []matcher.Match, @@ -33,6 +34,23 @@ func FilterNodesByACL( return result } +// ReduceRoutes returns a reduced list of routes for a given node that it can access. +func ReduceRoutes( + node *types.Node, + routes []netip.Prefix, + matchers []matcher.Match, +) []netip.Prefix { + var result []netip.Prefix + + for _, route := range routes { + if node.CanAccessRoute(matchers, route) { + result = append(result, route) + } + } + + return result +} + // ReduceFilterRules takes a node and a set of rules and removes all rules and destinations // that are not relevant to that particular node. func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule { diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 5b3814a2..b1d44cc7 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -784,7 +784,7 @@ func TestReduceFilterRules(t *testing.T) { } } -func TestFilterNodesByACL(t *testing.T) { +func TestReduceNodes(t *testing.T) { type args struct { nodes types.Nodes rules []tailcfg.FilterRule @@ -1530,7 +1530,7 @@ func TestFilterNodesByACL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) - got := FilterNodesByACL( + got := ReduceNodes( tt.args.node, tt.args.nodes, matchers, @@ -1946,3 +1946,197 @@ func TestSSHPolicyRules(t *testing.T) { } } } +func TestReduceRoutes(t *testing.T) { + type args struct { + node *types.Node + routes []netip.Prefix + rules []tailcfg.FilterRule + } + tests := []struct { + name string + args args + want []netip.Prefix + }{ + { + name: "node can access all routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + }, + { + name: "node can access specific route", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + }, + { + name: "node can access multiple specific routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + {IP: "192.168.1.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "node can access overlapping routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), // Overlaps with the first one + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/16"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), + }, + }, + { + name: "node with no matching rules", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, // Different source IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: nil, + }, + { + name: "node with both IPv4 and IPv6", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"fd7a:115c:a1e0::1"}, // IPv6 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "2001:db8::/64"}, // IPv6 destination + }, + }, + { + SrcIPs: []string{"100.64.0.1"}, // IPv4 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, // IPv4 destination + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( + tt.args.node, + tt.args.routes, + matchers, + ) + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("ReduceRoutes() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index ec4b7737..5d7f7db1 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -152,6 +152,10 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { // Filter returns the current filter rules for the entire tailnet and the associated matchers. func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { + if pm == nil { + return nil, nil + } + pm.mu.Lock() defer pm.mu.Unlock() return pm.filter, pm.matchers @@ -159,6 +163,10 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { // SetUsers updates the users in the policy manager and updates the filter rules. func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { + if pm == nil { + return false, nil + } + pm.mu.Lock() defer pm.mu.Unlock() pm.users = users @@ -167,6 +175,10 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { // SetNodes updates the nodes in the policy manager and updates the filter rules. func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { + if pm == nil { + return false, nil + } + pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes @@ -238,6 +250,10 @@ func (pm *PolicyManager) Version() int { } func (pm *PolicyManager) DebugString() string { + if pm == nil { + return "PolicyManager is not setup" + } + var sb strings.Builder fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version()) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 826867eb..76770160 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -291,6 +291,22 @@ func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { return false } +func (node *Node) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { + src := node.IPs() + + for _, matcher := range matchers { + if !matcher.SrcsContainsIPs(src...) { + continue + } + + if matcher.DestsOverlapsPrefixes(route) { + return true + } + } + + return false +} + func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { var found Nodes