1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-28 13:49:04 +02:00

introduce policy v2 package

policy v2 is built from the ground up to be stricter
and follow the same pattern for all types of resolvers.

TODO introduce
aliass
resolver

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-26 19:31:08 +01:00
parent c0f856256a
commit 72049e94b1
No known key found for this signature in database
8 changed files with 3321 additions and 0 deletions

View File

@ -0,0 +1,169 @@
package v2
import (
"errors"
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) compileFilterRules(
users types.Users,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
return tailcfg.FilterAllowAll, nil
}
var rules []tailcfg.FilterRule
for _, acl := range pol.ACLs {
if acl.Action != "accept" {
return nil, ErrInvalidAction
}
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
}
if len(srcIPs.Prefixes()) == 0 {
continue
}
// TODO(kradalby): integrate type into schema
// TODO(kradalby): figure out the _ is wildcard stuff
protocols, _, err := parseProtocol(acl.Protocol)
if err != nil {
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
}
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
for _, pref := range ips.Prefixes() {
for _, port := range dest.Ports {
pr := tailcfg.NetPortRange{
IP: pref.String(),
Ports: port,
}
destPorts = append(destPorts, pr)
}
}
}
if len(destPorts) == 0 {
continue
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: ipSetToPrefixStringList(srcIPs),
DstPorts: destPorts,
IPProto: protocols,
})
}
return rules, nil
}
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
}
func (pol *Policy) compileSSHPolicy(
users types.Users,
node *types.Node,
nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
return nil, nil
}
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
dest.AddSet(ips)
}
destSet, err := dest.IPSet()
if err != nil {
return nil, err
}
if !node.InIPSet(destSet) {
continue
}
var action tailcfg.SSHAction
switch rule.Action {
case "accept":
action = sshAction(true, 0)
case "check":
action = sshAction(true, rule.CheckPeriod)
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
}
for addr := range util.IPSetAddrIter(srcIPs) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
userMap := make(map[string]string, len(rule.Users))
for _, user := range rule.Users {
userMap[user.String()] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

View File

@ -0,0 +1,378 @@
package v2
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func TestParsing(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser"},
}
tests := []struct {
name string
format string
acl string
want []tailcfg.FilterRule
wantErr bool
}{
{
name: "invalid-hujson",
format: "hujson",
acl: `
{
`,
want: []tailcfg.FilterRule{},
wantErr: true,
},
// The new parser will ignore all that is irrelevant
// {
// name: "valid-hujson-invalid-content",
// format: "hujson",
// acl: `
// {
// "valid_json": true,
// "but_a_policy_though": false
// }
// `,
// want: []tailcfg.FilterRule{},
// wantErr: true,
// },
// {
// name: "invalid-cidr",
// format: "hujson",
// acl: `
// {"example-host-1": "100.100.100.100/42"}
// `,
// want: []tailcfg.FilterRule{},
// wantErr: true,
// },
{
name: "basic-rule",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"subnet-1",
"192.168.1.0/24"
],
"dst": [
"*:22,3389",
"host-1:*",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
{
name: "parse-protocol",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"Action": "accept",
"src": [
"*",
],
"proto": "tcp",
"dst": [
"host-1:*",
],
},
{
"Action": "accept",
"src": [
"*",
],
"proto": "udp",
"dst": [
"host-1:53",
],
},
{
"Action": "accept",
"src": [
"*",
],
"proto": "icmp",
"dst": [
"host-1:*",
],
},
],
}`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP},
},
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
},
IPProto: []int{protocolUDP},
},
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolICMP, protocolIPv6ICMP},
},
},
wantErr: false,
},
{
name: "port-wildcard",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"Action": "accept",
"src": [
"*",
],
"dst": [
"host-1:*",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
{
name: "port-range",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"subnet-1",
],
"dst": [
"host-1:5400-5500",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.100.101.0/24"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.100.100.100/32",
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
},
},
},
},
wantErr: false,
},
{
name: "port-group",
format: "hujson",
acl: `
{
"groups": {
"group:example": [
"testuser@",
],
},
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"200.200.200.200/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
{
name: "port-user",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser@",
],
"dst": [
"host-1:*",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"200.200.200.200/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
{
name: "ipv6",
format: "hujson",
acl: `
{
"hosts": {
"host-1": "100.100.100.100/32",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"*",
],
"dst": [
"host-1:*",
],
},
],
}
`,
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol, err := policyFromBytes([]byte(tt.acl))
if tt.wantErr && err == nil {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
return
} else if !tt.wantErr && err != nil {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
return
}
rules, err := pol.compileFilterRules(
users,
types.Nodes{
&types.Node{
IPv4: ap("100.100.100.100"),
},
&types.Node{
IPv4: ap("200.200.200.200"),
User: users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
})
if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, rules); diff != "" {
t.Errorf("parsing() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View File

@ -0,0 +1,283 @@
package v2
import (
"encoding/json"
"fmt"
"net/netip"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
tagOwnerMapHash deephash.Sum
tagOwnerMap map[Tag]*netipx.IPSet
autoApproveMapHash deephash.Sum
autoApproveMap map[netip.Prefix]*netipx.IPSet
// Lazy map of SSH policies
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
}
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
}
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() (bool, error) {
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
filterChanged := filterHash == pm.filterHash
pm.filter = filter
pm.filterHash = filterHash
// Order matters, tags might be used in autoapprovers, so we need to ensure
// that the map for tag owners is resolved before resolving autoapprovers.
// TODO(kradalby): Order might not matter after #2417
tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("resolving tag owners map: %w", err)
}
tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash
autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("resolving auto approvers map: %w", err)
}
autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash
// If neither of the calculated values changed, no need to update nodes
if !filterChanged && !tagOwnerChanged && !autoApproveChanged {
return false, nil
}
// Clear the SSH policy map to ensure it's recalculated with the new policy.
// TODO(kradalby): This could potentially be optimized by only clearing the
// policies for nodes that have changed. Particularly if the only difference is
// that nodes has been added or removed.
clear(pm.sshPolicyMap)
return true, nil
}
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
return sshPol, nil
}
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}
pm.sshPolicyMap[node.ID] = sshPol
return sshPol, nil
}
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
if len(polB) == 0 {
return false, nil
}
pol, err := policyFromBytes(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
defer pm.mu.Unlock()
pm.pol = pol
return pm.updateLocked()
}
// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
if pm == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
for _, nodeAddr := range node.IPs() {
if ips.Contains(nodeAddr) {
return true
}
}
}
return false
}
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
if pm == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
// The fast path is that a node requests to approve a prefix
// where there is an exact entry, e.g. 10.0.0.0/8, then
// check and return quickly
if _, ok := pm.autoApproveMap[route]; ok {
for _, nodeAddr := range node.IPs() {
if pm.autoApproveMap[route].Contains(nodeAddr) {
return true
}
}
}
// The slow path is that the node tries to approve
// 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
// cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap {
// We do not want the exit node entry to approve all
// sorts of routes. The logic here is that it would be
// unexpected behaviour to have specific routes approved
// just because the node is allowed to designate itself as
// an exit.
if tsaddr.IsExitRoute(prefix) {
continue
}
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
for _, nodeAddr := range node.IPs() {
if approveAddrs.Contains(nodeAddr) {
return true
}
}
}
}
return false
}
func (pm *PolicyManager) Version() int {
return 2
}
func (pm *PolicyManager) DebugString() string {
var sb strings.Builder
fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version())
sb.WriteString("\n\n")
if pm.pol != nil {
pol, err := json.MarshalIndent(pm.pol, "", " ")
if err == nil {
sb.WriteString("Policy:\n")
sb.Write(pol)
sb.WriteString("\n\n")
}
}
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
for prefix, approveAddrs := range pm.autoApproveMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range approveAddrs.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
}
sb.WriteString("\n\n")
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
for prefix, tagOwners := range pm.tagOwnerMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range tagOwners.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
}
sb.WriteString("\n\n")
if pm.filter != nil {
filter, err := json.MarshalIndent(pm.filter, "", " ")
if err == nil {
sb.WriteString("Compiled filter:\n")
sb.Write(filter)
sb.WriteString("\n\n")
}
}
return sb.String()
}

View File

@ -0,0 +1,58 @@
package v2
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
Hostinfo: hostinfo,
}
}
func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}
tests := []struct {
name string
pol string
nodes types.Nodes
wantFilter []tailcfg.FilterRule
}{
{
name: "empty-policy",
pol: "{}",
nodes: types.Nodes{},
wantFilter: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err)
filter := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
}
// TODO(kradalby): Test SSH Policy
})
}
}

1005
hscontrol/policy/v2/types.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,164 @@
package v2
import (
"errors"
"fmt"
"slices"
"strconv"
"strings"
"tailscale.com/tailcfg"
)
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
func splitDestinationAndPort(input string) (string, string, error) {
// Find the last occurrence of the colon character
lastColonIndex := strings.LastIndex(input, ":")
// Check if the colon character is present and not at the beginning or end of the string
if lastColonIndex == -1 {
return "", "", errors.New("input must contain a colon character separating destination and port")
}
if lastColonIndex == 0 {
return "", "", errors.New("input cannot start with a colon character")
}
if lastColonIndex == len(input)-1 {
return "", "", errors.New("input cannot end with a colon character")
}
// Split the string into destination and port based on the last colon
destination := input[:lastColonIndex]
port := input[lastColonIndex+1:]
return destination, port, nil
}
// parsePortRange parses a port definition string and returns a slice of PortRange structs.
func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
if portDef == "*" {
return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil
}
var portRanges []tailcfg.PortRange
parts := strings.Split(portDef, ",")
for _, part := range parts {
if strings.Contains(part, "-") {
rangeParts := strings.Split(part, "-")
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
return e == ""
})
if len(rangeParts) != 2 {
return nil, errors.New("invalid port range format")
}
first, err := parsePort(rangeParts[0])
if err != nil {
return nil, err
}
last, err := parsePort(rangeParts[1])
if err != nil {
return nil, err
}
if first > last {
return nil, errors.New("invalid port range: first port is greater than last port")
}
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
} else {
port, err := parsePort(part)
if err != nil {
return nil, err
}
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
}
}
return portRanges, nil
}
// parsePort parses a single port number from a string.
func parsePort(portStr string) (uint16, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, errors.New("invalid port number")
}
if port < 0 || port > 65535 {
return 0, errors.New("port number out of range")
}
return uint16(port), nil
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
}
// TODO(kradalby): What is this?
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}

View File

@ -0,0 +1,102 @@
package v2
import (
"errors"
"testing"
"github.com/google/go-cmp/cmp"
"tailscale.com/tailcfg"
)
// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests.
func TestParseDestinationAndPort(t *testing.T) {
testCases := []struct {
input string
expectedDst string
expectedPort string
expectedErr error
}{
{"git-server:*", "git-server", "*", nil},
{"192.168.1.0/24:22", "192.168.1.0/24", "22", nil},
{"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil},
{"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil},
{"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil},
{"tag:api-server:443", "tag:api-server", "443", nil},
{"example-host-1:*", "example-host-1", "*", nil},
{"hostname:80-90", "hostname", "80-90", nil},
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
{":invalid", "", "", errors.New("input cannot start with a colon character")},
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
}
for _, testCase := range testCases {
dst, port, err := splitDestinationAndPort(testCase.input)
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
}
}
}
func TestParsePort(t *testing.T) {
tests := []struct {
input string
expected uint16
err string
}{
{"80", 80, ""},
{"0", 0, ""},
{"65535", 65535, ""},
{"-1", 0, "port number out of range"},
{"65536", 0, "port number out of range"},
{"abc", 0, "invalid port number"},
{"", 0, "invalid port number"},
}
for _, test := range tests {
result, err := parsePort(test.input)
if err != nil && err.Error() != test.err {
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
}
if result != test.expected {
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
}
}
}
func TestParsePortRange(t *testing.T) {
tests := []struct {
input string
expected []tailcfg.PortRange
err string
}{
{"80", []tailcfg.PortRange{{80, 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},
{"80-90,", nil, "invalid port number"},
{"80,90-", nil, "invalid port range format"},
{"80-90,abc", nil, "invalid port number"},
{"80-90,65536", nil, "port number out of range"},
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
}
for _, test := range tests {
result, err := parsePortRange(test.input)
if err != nil && err.Error() != test.err {
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
}
if diff := cmp.Diff(result, test.expected); diff != "" {
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
}
}
}