From 0b1727c3378f69340149d3dc7470d7d969384f2b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 24 Jan 2026 07:49:21 +0000 Subject: [PATCH] policy: merge filter rules with identical SrcIPs and IPProto Tailscale merges multiple ACL rules into fewer FilterRule entries when they have identical SrcIPs and IPProto, combining their DstPorts arrays. This change implements the same behavior in Headscale. Add mergeFilterRules() which uses O(n) hash map lookup to merge rules with identical keys. DstPorts are NOT deduplicated to match Tailscale behavior. Also fix DestsIsTheInternet() to handle merged filter rules where TheInternet is combined with other destinations - now uses superset check instead of equality check. Updates #3036 --- hscontrol/policy/matcher/matcher.go | 21 ++- hscontrol/policy/v2/filter.go | 48 ++++++- hscontrol/policy/v2/filter_test.go | 209 ++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 6 deletions(-) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index afc3cf68..1e6312b8 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -93,11 +93,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix) } -// DestsIsTheInternet reports if the destination is equal to "the internet" +// DestsIsTheInternet reports if the destination contains "the internet" // which is a IPSet that represents "autogroup:internet" and is special // cased for exit nodes. +// This checks if dests is a superset of TheInternet(), which handles +// merged filter rules where TheInternet is combined with other destinations. func (m Match) DestsIsTheInternet() bool { - return m.dests.Equal(util.TheInternet()) || - m.dests.ContainsPrefix(tsaddr.AllIPv4()) || - m.dests.ContainsPrefix(tsaddr.AllIPv6()) + if m.dests.ContainsPrefix(tsaddr.AllIPv4()) || + m.dests.ContainsPrefix(tsaddr.AllIPv6()) { + return true + } + + // Check if dests contains all prefixes of TheInternet (superset check) + theInternet := util.TheInternet() + for _, prefix := range theInternet.Prefixes() { + if !m.dests.ContainsPrefix(prefix) { + return false + } + } + + return true } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index bdcbb1f2..63950de6 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -5,6 +5,8 @@ import ( "fmt" "net/netip" "slices" + "strconv" + "strings" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -81,7 +83,7 @@ func (pol *Policy) compileFilterRules( }) } - return rules, nil + return mergeFilterRules(rules), nil } // compileFilterRulesForNode compiles filter rules for a specific node. @@ -114,7 +116,7 @@ func (pol *Policy) compileFilterRulesForNode( } } - return rules, nil + return mergeFilterRules(rules), nil } // compileACLWithAutogroupSelf compiles a single ACL rule, handling @@ -460,3 +462,45 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string { return out } + +// filterRuleKey generates a unique key for merging based on SrcIPs and IPProto. +func filterRuleKey(rule tailcfg.FilterRule) string { + srcKey := strings.Join(rule.SrcIPs, ",") + + protoStrs := make([]string, len(rule.IPProto)) + for i, p := range rule.IPProto { + protoStrs[i] = strconv.Itoa(p) + } + + return srcKey + "|" + strings.Join(protoStrs, ",") +} + +// mergeFilterRules merges rules with identical SrcIPs and IPProto by combining +// their DstPorts. DstPorts are NOT deduplicated to match Tailscale behavior. +func mergeFilterRules(rules []tailcfg.FilterRule) []tailcfg.FilterRule { + if len(rules) <= 1 { + return rules + } + + keyToIdx := make(map[string]int) + result := make([]tailcfg.FilterRule, 0, len(rules)) + + for _, rule := range rules { + key := filterRuleKey(rule) + + if idx, exists := keyToIdx[key]; exists { + // Merge: append DstPorts to existing rule + result[idx].DstPorts = append(result[idx].DstPorts, rule.DstPorts...) + } else { + // New unique combination + keyToIdx[key] = len(result) + result = append(result, tailcfg.FilterRule{ + SrcIPs: rule.SrcIPs, + DstPorts: slices.Clone(rule.DstPorts), + IPProto: rule.IPProto, + }) + } + } + + return result +} diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index eb3afb39..c5a4bbd1 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -1861,3 +1861,212 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) { assert.True(t, containsSrcIP(directionRules, "100.64.0.1"), "superadmin's IP should be in sources for rule 1 (partial resolution preserved)") } + +func TestMergeFilterRules(t *testing.T) { + tests := []struct { + name string + input []tailcfg.FilterRule + want []tailcfg.FilterRule + }{ + { + name: "empty input", + input: []tailcfg.FilterRule{}, + want: []tailcfg.FilterRule{}, + }, + { + name: "single rule unchanged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "merge two rules with same key", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + }, + }, + { + name: "different SrcIPs not merged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "different IPProto not merged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{ProtocolUDP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{ProtocolUDP}, + }, + }, + }, + { + name: "DstPorts combined without deduplication", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "merge three rules with same key", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + {IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeFilterRules(tt.input) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("mergeFilterRules() mismatch (-want +got):\n%s", diff) + } + }) + } +}