mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-07 20:04:00 +01:00
policy: merge filter rules with identical SrcIPs and IPProto
Tailscale merges multiple ACL rules into fewer FilterRule entries when they have identical SrcIPs and IPProto, combining their DstPorts arrays. This change implements the same behavior in Headscale. Add mergeFilterRules() which uses O(n) hash map lookup to merge rules with identical keys. DstPorts are NOT deduplicated to match Tailscale behavior. Also fix DestsIsTheInternet() to handle merged filter rules where TheInternet is combined with other destinations - now uses superset check instead of equality check. Updates #3036
This commit is contained in:
parent
08fe2e4d6c
commit
0b1727c337
@ -93,11 +93,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
|
||||
return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix)
|
||||
}
|
||||
|
||||
// DestsIsTheInternet reports if the destination is equal to "the internet"
|
||||
// DestsIsTheInternet reports if the destination contains "the internet"
|
||||
// which is a IPSet that represents "autogroup:internet" and is special
|
||||
// cased for exit nodes.
|
||||
// This checks if dests is a superset of TheInternet(), which handles
|
||||
// merged filter rules where TheInternet is combined with other destinations.
|
||||
func (m Match) DestsIsTheInternet() bool {
|
||||
return m.dests.Equal(util.TheInternet()) ||
|
||||
m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
|
||||
m.dests.ContainsPrefix(tsaddr.AllIPv6())
|
||||
if m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
|
||||
m.dests.ContainsPrefix(tsaddr.AllIPv6()) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if dests contains all prefixes of TheInternet (superset check)
|
||||
theInternet := util.TheInternet()
|
||||
for _, prefix := range theInternet.Prefixes() {
|
||||
if !m.dests.ContainsPrefix(prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@ -81,7 +83,7 @@ func (pol *Policy) compileFilterRules(
|
||||
})
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
return mergeFilterRules(rules), nil
|
||||
}
|
||||
|
||||
// compileFilterRulesForNode compiles filter rules for a specific node.
|
||||
@ -114,7 +116,7 @@ func (pol *Policy) compileFilterRulesForNode(
|
||||
}
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
return mergeFilterRules(rules), nil
|
||||
}
|
||||
|
||||
// compileACLWithAutogroupSelf compiles a single ACL rule, handling
|
||||
@ -460,3 +462,45 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// filterRuleKey generates a unique key for merging based on SrcIPs and IPProto.
|
||||
func filterRuleKey(rule tailcfg.FilterRule) string {
|
||||
srcKey := strings.Join(rule.SrcIPs, ",")
|
||||
|
||||
protoStrs := make([]string, len(rule.IPProto))
|
||||
for i, p := range rule.IPProto {
|
||||
protoStrs[i] = strconv.Itoa(p)
|
||||
}
|
||||
|
||||
return srcKey + "|" + strings.Join(protoStrs, ",")
|
||||
}
|
||||
|
||||
// mergeFilterRules merges rules with identical SrcIPs and IPProto by combining
|
||||
// their DstPorts. DstPorts are NOT deduplicated to match Tailscale behavior.
|
||||
func mergeFilterRules(rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
if len(rules) <= 1 {
|
||||
return rules
|
||||
}
|
||||
|
||||
keyToIdx := make(map[string]int)
|
||||
result := make([]tailcfg.FilterRule, 0, len(rules))
|
||||
|
||||
for _, rule := range rules {
|
||||
key := filterRuleKey(rule)
|
||||
|
||||
if idx, exists := keyToIdx[key]; exists {
|
||||
// Merge: append DstPorts to existing rule
|
||||
result[idx].DstPorts = append(result[idx].DstPorts, rule.DstPorts...)
|
||||
} else {
|
||||
// New unique combination
|
||||
keyToIdx[key] = len(result)
|
||||
result = append(result, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: slices.Clone(rule.DstPorts),
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@ -1861,3 +1861,212 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) {
|
||||
assert.True(t, containsSrcIP(directionRules, "100.64.0.1"),
|
||||
"superadmin's IP should be in sources for rule 1 (partial resolution preserved)")
|
||||
}
|
||||
|
||||
func TestMergeFilterRules(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []tailcfg.FilterRule
|
||||
want []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []tailcfg.FilterRule{},
|
||||
want: []tailcfg.FilterRule{},
|
||||
},
|
||||
{
|
||||
name: "single rule unchanged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge two rules with same key",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different SrcIPs not merged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IPProto not merged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{ProtocolUDP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{ProtocolUDP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DstPorts combined without deduplication",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge three rules with same key",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
{IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mergeFilterRules(tt.input)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("mergeFilterRules() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user