diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index afc3cf68..1e6312b8 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -93,11 +93,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix) } -// DestsIsTheInternet reports if the destination is equal to "the internet" +// DestsIsTheInternet reports if the destination contains "the internet" // which is a IPSet that represents "autogroup:internet" and is special // cased for exit nodes. +// This checks if dests is a superset of TheInternet(), which handles +// merged filter rules where TheInternet is combined with other destinations. func (m Match) DestsIsTheInternet() bool { - return m.dests.Equal(util.TheInternet()) || - m.dests.ContainsPrefix(tsaddr.AllIPv4()) || - m.dests.ContainsPrefix(tsaddr.AllIPv6()) + if m.dests.ContainsPrefix(tsaddr.AllIPv4()) || + m.dests.ContainsPrefix(tsaddr.AllIPv6()) { + return true + } + + // Check if dests contains all prefixes of TheInternet (superset check) + theInternet := util.TheInternet() + for _, prefix := range theInternet.Prefixes() { + if !m.dests.ContainsPrefix(prefix) { + return false + } + } + + return true } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index bdcbb1f2..63950de6 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -5,6 +5,8 @@ import ( "fmt" "net/netip" "slices" + "strconv" + "strings" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -81,7 +83,7 @@ func (pol *Policy) compileFilterRules( }) } - return rules, nil + return mergeFilterRules(rules), nil } // compileFilterRulesForNode compiles filter rules for a specific node. @@ -114,7 +116,7 @@ func (pol *Policy) compileFilterRulesForNode( } } - return rules, nil + return mergeFilterRules(rules), nil } // compileACLWithAutogroupSelf compiles a single ACL rule, handling @@ -460,3 +462,45 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string { return out } + +// filterRuleKey generates a unique key for merging based on SrcIPs and IPProto. +func filterRuleKey(rule tailcfg.FilterRule) string { + srcKey := strings.Join(rule.SrcIPs, ",") + + protoStrs := make([]string, len(rule.IPProto)) + for i, p := range rule.IPProto { + protoStrs[i] = strconv.Itoa(p) + } + + return srcKey + "|" + strings.Join(protoStrs, ",") +} + +// mergeFilterRules merges rules with identical SrcIPs and IPProto by combining +// their DstPorts. DstPorts are NOT deduplicated to match Tailscale behavior. +func mergeFilterRules(rules []tailcfg.FilterRule) []tailcfg.FilterRule { + if len(rules) <= 1 { + return rules + } + + keyToIdx := make(map[string]int) + result := make([]tailcfg.FilterRule, 0, len(rules)) + + for _, rule := range rules { + key := filterRuleKey(rule) + + if idx, exists := keyToIdx[key]; exists { + // Merge: append DstPorts to existing rule + result[idx].DstPorts = append(result[idx].DstPorts, rule.DstPorts...) + } else { + // New unique combination + keyToIdx[key] = len(result) + result = append(result, tailcfg.FilterRule{ + SrcIPs: rule.SrcIPs, + DstPorts: slices.Clone(rule.DstPorts), + IPProto: rule.IPProto, + }) + } + } + + return result +} diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index eb3afb39..c5a4bbd1 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -1861,3 +1861,212 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) { assert.True(t, containsSrcIP(directionRules, "100.64.0.1"), "superadmin's IP should be in sources for rule 1 (partial resolution preserved)") } + +func TestMergeFilterRules(t *testing.T) { + tests := []struct { + name string + input []tailcfg.FilterRule + want []tailcfg.FilterRule + }{ + { + name: "empty input", + input: []tailcfg.FilterRule{}, + want: []tailcfg.FilterRule{}, + }, + { + name: "single rule unchanged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "merge two rules with same key", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP}, + }, + }, + }, + { + name: "different SrcIPs not merged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "different IPProto not merged", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{ProtocolUDP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{ProtocolUDP}, + }, + }, + }, + { + name: "DstPorts combined without deduplication", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP}, + }, + }, + }, + { + name: "merge three rules with same key", + input: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + {IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, + }, + IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeFilterRules(tt.input) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("mergeFilterRules() mismatch (-want +got):\n%s", diff) + } + }) + } +}