1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

policy: make ACL action and protocol strict types

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-09-10 11:17:26 +02:00
parent 462ef80f42
commit 64ad05f1c5
No known key found for this signature in database
4 changed files with 250 additions and 92 deletions

View File

@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules(
var rules []tailcfg.FilterRule
for _, acl := range pol.ACLs {
if acl.Action != "accept" {
if acl.Action != ActionAccept {
return nil, ErrInvalidAction
}
@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules(
continue
}
// TODO(kradalby): integrate type into schema
// TODO(kradalby): figure out the _ is wildcard stuff
protocols, _, err := parseProtocol(acl.Protocol)
if err != nil {
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
}
protocols, _ := acl.Protocol.parseProtocol()
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {

View File

@ -1242,9 +1242,214 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
return ret, exitNodeSet, nil
}
// Action represents the action to take for an ACL rule.
type Action string
const (
ActionAccept Action = "accept"
)
// String returns the string representation of the Action.
func (a Action) String() string {
return string(a)
}
// UnmarshalJSON implements JSON unmarshaling for Action.
func (a *Action) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
switch str {
case "accept":
*a = ActionAccept
default:
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
}
return nil
}
// MarshalJSON implements JSON marshaling for Action.
func (a Action) MarshalJSON() ([]byte, error) {
return json.Marshal(string(a))
}
// Protocol represents a network protocol with its IANA number and descriptions.
type Protocol string
const (
ProtocolICMP Protocol = "icmp"
ProtocolIGMP Protocol = "igmp"
ProtocolIPv4 Protocol = "ipv4"
ProtocolIPInIP Protocol = "ip-in-ip"
ProtocolTCP Protocol = "tcp"
ProtocolEGP Protocol = "egp"
ProtocolIGP Protocol = "igp"
ProtocolUDP Protocol = "udp"
ProtocolGRE Protocol = "gre"
ProtocolESP Protocol = "esp"
ProtocolAH Protocol = "ah"
ProtocolIPv6ICMP Protocol = "ipv6-icmp"
ProtocolSCTP Protocol = "sctp"
ProtocolFC Protocol = "fc"
ProtocolWildcard Protocol = "*"
)
// String returns the string representation of the Protocol.
func (p Protocol) String() string {
return string(p)
}
// Description returns the human-readable description of the Protocol.
func (p Protocol) Description() string {
switch p {
case ProtocolICMP:
return "Internet Control Message Protocol"
case ProtocolIGMP:
return "Internet Group Management Protocol"
case ProtocolIPv4:
return "IPv4 encapsulation"
case ProtocolTCP:
return "Transmission Control Protocol"
case ProtocolEGP:
return "Exterior Gateway Protocol"
case ProtocolIGP:
return "Interior Gateway Protocol"
case ProtocolUDP:
return "User Datagram Protocol"
case ProtocolGRE:
return "Generic Routing Encapsulation"
case ProtocolESP:
return "Encapsulating Security Payload"
case ProtocolAH:
return "Authentication Header"
case ProtocolIPv6ICMP:
return "Internet Control Message Protocol for IPv6"
case ProtocolSCTP:
return "Stream Control Transmission Protocol"
case ProtocolFC:
return "Fibre Channel"
case ProtocolWildcard:
return "Wildcard (all protocols)"
default:
return "Unknown Protocol"
}
}
// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement.
// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values.
func (p Protocol) parseProtocol() ([]int, bool) {
switch p {
case "":
return nil, false
case ProtocolWildcard:
// Wildcard protocol - allows all protocols like empty string
return nil, false
case ProtocolIGMP:
return []int{protocolIGMP}, true
case ProtocolIPv4, ProtocolIPInIP:
return []int{protocolIPv4}, true
case ProtocolTCP:
return []int{protocolTCP}, false
case ProtocolEGP:
return []int{protocolEGP}, true
case ProtocolIGP:
return []int{protocolIGP}, true
case ProtocolUDP:
return []int{protocolUDP}, false
case ProtocolGRE:
return []int{protocolGRE}, true
case ProtocolESP:
return []int{protocolESP}, true
case ProtocolAH:
return []int{protocolAH}, true
case ProtocolSCTP:
return []int{protocolSCTP}, false
case ProtocolICMP:
return []int{protocolICMP, protocolIPv6ICMP}, true
default:
// Try to parse as a numeric protocol number
// This should not fail since validation happened during unmarshaling
protocolNumber, _ := strconv.Atoi(string(p))
// Determine if wildcard is needed based on protocol number
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard
}
}
// UnmarshalJSON implements JSON unmarshaling for Protocol.
func (p *Protocol) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
// Normalize to lowercase for case-insensitive matching
*p = Protocol(strings.ToLower(str))
// Validate the protocol
if err := p.validate(); err != nil {
return err
}
return nil
}
// validate checks if the Protocol is valid.
func (p Protocol) validate() error {
switch p {
case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP,
ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE,
ProtocolESP, ProtocolAH, ProtocolSCTP:
return nil
case ProtocolWildcard:
// Wildcard "*" is not allowed - Tailscale rejects it
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
default:
// Try to parse as a numeric protocol number
str := string(p)
// Check for leading zeros (not allowed by Tailscale)
if str == "0" || (len(str) > 1 && str[0] == '0') {
return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str)
}
protocolNumber, err := strconv.Atoi(str)
if err != nil {
return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p)
}
if protocolNumber < 0 || protocolNumber > 255 {
return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber)
}
return nil
}
}
// MarshalJSON implements JSON marshaling for Protocol.
func (p Protocol) MarshalJSON() ([]byte, error) {
return json.Marshal(string(p))
}
// Protocol constants matching the IANA numbers
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
protocolFC = 133 // Fibre Channel
)
type ACL struct {
Action string `json:"action"` // TODO(kradalby): add strict type
Protocol string `json:"proto"` // TODO(kradalby): add strict type
Action Action `json:"action"`
Protocol Protocol `json:"proto"`
Sources Aliases `json:"src"`
Destinations []AliasWithPorts `json:"dst"`
}

View File

@ -2464,13 +2464,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
"tags": ["web", "production"],
"created_by": "admin"
},
"action": "deny",
"action": "accept",
"proto": "udp",
"src": ["*"],
"dst": ["autogroup:internet:53"]
}`,
expected: ACL{
Action: "deny",
Action: ActionAccept,
Protocol: "udp",
Sources: []Alias{Wildcard},
Destinations: []AliasWithPorts{
@ -2482,6 +2482,16 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
},
wantErr: false,
},
{
name: "invalid action should fail",
input: `{
"action": "deny",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}`,
wantErr: true,
},
{
name: "no comment fields",
input: `{
@ -2491,13 +2501,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
"dst": ["tag:server:*"]
}`,
expected: ACL{
Action: "accept",
Action: ActionAccept,
Protocol: "icmp",
Sources: []Alias{mustParseAlias("tag:client")},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("tag:server"),
Ports: []tailcfg.PortRange{{First: 0, Last: 65535}},
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
@ -2510,8 +2520,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
"#reason": "Temporary disable for maintenance"
}`,
expected: ACL{
Action: "",
Protocol: "",
Action: Action(""),
Protocol: Protocol(""),
Sources: nil,
Destinations: nil,
},
@ -2623,9 +2633,9 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
"dst": ["tag:server:22,80,443"]
},
{
"#note": "Deny all other traffic",
"action": "deny",
"proto": "*",
"#note": "Allow all other traffic",
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}
@ -2641,19 +2651,37 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
// First ACL
acl1 := policy.ACLs[0]
assert.Equal(t, "accept", acl1.Action)
assert.Equal(t, "tcp", acl1.Protocol)
assert.Equal(t, ActionAccept, acl1.Action)
assert.Equal(t, Protocol("tcp"), acl1.Protocol)
require.Len(t, acl1.Sources, 1)
require.Len(t, acl1.Destinations, 1)
// Second ACL
acl2 := policy.ACLs[1]
assert.Equal(t, "deny", acl2.Action)
assert.Equal(t, "*", acl2.Protocol)
assert.Equal(t, ActionAccept, acl2.Action)
assert.Equal(t, Protocol("tcp"), acl2.Protocol)
require.Len(t, acl2.Sources, 1)
require.Len(t, acl2.Destinations, 1)
}
func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
// Test that invalid actions are rejected
policyJSON := `{
"acls": [
{
"action": "deny",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}
]
}`
_, err := unmarshalPolicy([]byte(policyJSON))
require.Error(t, err)
assert.Contains(t, err.Error(), `invalid action "deny"`)
}
// Helper function to parse aliases for testing
func mustParseAlias(s string) Alias {
alias, err := parseAlias(s)

View File

@ -2,7 +2,6 @@ package v2
import (
"errors"
"fmt"
"slices"
"strconv"
"strings"
@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) {
return uint16(port), nil
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
}
// TODO(kradalby): What is this?
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}