mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
policy: make ACL action and protocol strict types
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
462ef80f42
commit
64ad05f1c5
@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules(
|
||||
var rules []tailcfg.FilterRule
|
||||
|
||||
for _, acl := range pol.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
if acl.Action != ActionAccept {
|
||||
return nil, ErrInvalidAction
|
||||
}
|
||||
|
||||
@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules(
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(kradalby): integrate type into schema
|
||||
// TODO(kradalby): figure out the _ is wildcard stuff
|
||||
protocols, _, err := parseProtocol(acl.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
|
||||
}
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
for _, dest := range acl.Destinations {
|
||||
|
@ -1242,9 +1242,214 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
||||
return ret, exitNodeSet, nil
|
||||
}
|
||||
|
||||
// Action represents the action to take for an ACL rule.
|
||||
type Action string
|
||||
|
||||
const (
|
||||
ActionAccept Action = "accept"
|
||||
)
|
||||
|
||||
// String returns the string representation of the Action.
|
||||
func (a Action) String() string {
|
||||
return string(a)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements JSON unmarshaling for Action.
|
||||
func (a *Action) UnmarshalJSON(b []byte) error {
|
||||
str := strings.Trim(string(b), `"`)
|
||||
switch str {
|
||||
case "accept":
|
||||
*a = ActionAccept
|
||||
default:
|
||||
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements JSON marshaling for Action.
|
||||
func (a Action) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(a))
|
||||
}
|
||||
|
||||
// Protocol represents a network protocol with its IANA number and descriptions.
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
ProtocolIGMP Protocol = "igmp"
|
||||
ProtocolIPv4 Protocol = "ipv4"
|
||||
ProtocolIPInIP Protocol = "ip-in-ip"
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
ProtocolEGP Protocol = "egp"
|
||||
ProtocolIGP Protocol = "igp"
|
||||
ProtocolUDP Protocol = "udp"
|
||||
ProtocolGRE Protocol = "gre"
|
||||
ProtocolESP Protocol = "esp"
|
||||
ProtocolAH Protocol = "ah"
|
||||
ProtocolIPv6ICMP Protocol = "ipv6-icmp"
|
||||
ProtocolSCTP Protocol = "sctp"
|
||||
ProtocolFC Protocol = "fc"
|
||||
ProtocolWildcard Protocol = "*"
|
||||
)
|
||||
|
||||
// String returns the string representation of the Protocol.
|
||||
func (p Protocol) String() string {
|
||||
return string(p)
|
||||
}
|
||||
|
||||
// Description returns the human-readable description of the Protocol.
|
||||
func (p Protocol) Description() string {
|
||||
switch p {
|
||||
case ProtocolICMP:
|
||||
return "Internet Control Message Protocol"
|
||||
case ProtocolIGMP:
|
||||
return "Internet Group Management Protocol"
|
||||
case ProtocolIPv4:
|
||||
return "IPv4 encapsulation"
|
||||
case ProtocolTCP:
|
||||
return "Transmission Control Protocol"
|
||||
case ProtocolEGP:
|
||||
return "Exterior Gateway Protocol"
|
||||
case ProtocolIGP:
|
||||
return "Interior Gateway Protocol"
|
||||
case ProtocolUDP:
|
||||
return "User Datagram Protocol"
|
||||
case ProtocolGRE:
|
||||
return "Generic Routing Encapsulation"
|
||||
case ProtocolESP:
|
||||
return "Encapsulating Security Payload"
|
||||
case ProtocolAH:
|
||||
return "Authentication Header"
|
||||
case ProtocolIPv6ICMP:
|
||||
return "Internet Control Message Protocol for IPv6"
|
||||
case ProtocolSCTP:
|
||||
return "Stream Control Transmission Protocol"
|
||||
case ProtocolFC:
|
||||
return "Fibre Channel"
|
||||
case ProtocolWildcard:
|
||||
return "Wildcard (all protocols)"
|
||||
default:
|
||||
return "Unknown Protocol"
|
||||
}
|
||||
}
|
||||
|
||||
// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement.
|
||||
// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values.
|
||||
func (p Protocol) parseProtocol() ([]int, bool) {
|
||||
switch p {
|
||||
case "":
|
||||
return nil, false
|
||||
case ProtocolWildcard:
|
||||
// Wildcard protocol - allows all protocols like empty string
|
||||
return nil, false
|
||||
case ProtocolIGMP:
|
||||
return []int{protocolIGMP}, true
|
||||
case ProtocolIPv4, ProtocolIPInIP:
|
||||
return []int{protocolIPv4}, true
|
||||
case ProtocolTCP:
|
||||
return []int{protocolTCP}, false
|
||||
case ProtocolEGP:
|
||||
return []int{protocolEGP}, true
|
||||
case ProtocolIGP:
|
||||
return []int{protocolIGP}, true
|
||||
case ProtocolUDP:
|
||||
return []int{protocolUDP}, false
|
||||
case ProtocolGRE:
|
||||
return []int{protocolGRE}, true
|
||||
case ProtocolESP:
|
||||
return []int{protocolESP}, true
|
||||
case ProtocolAH:
|
||||
return []int{protocolAH}, true
|
||||
case ProtocolSCTP:
|
||||
return []int{protocolSCTP}, false
|
||||
case ProtocolICMP:
|
||||
return []int{protocolICMP, protocolIPv6ICMP}, true
|
||||
default:
|
||||
// Try to parse as a numeric protocol number
|
||||
// This should not fail since validation happened during unmarshaling
|
||||
protocolNumber, _ := strconv.Atoi(string(p))
|
||||
|
||||
// Determine if wildcard is needed based on protocol number
|
||||
needsWildcard := protocolNumber != protocolTCP &&
|
||||
protocolNumber != protocolUDP &&
|
||||
protocolNumber != protocolSCTP
|
||||
|
||||
return []int{protocolNumber}, needsWildcard
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements JSON unmarshaling for Protocol.
|
||||
func (p *Protocol) UnmarshalJSON(b []byte) error {
|
||||
str := strings.Trim(string(b), `"`)
|
||||
|
||||
// Normalize to lowercase for case-insensitive matching
|
||||
*p = Protocol(strings.ToLower(str))
|
||||
|
||||
// Validate the protocol
|
||||
if err := p.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate checks if the Protocol is valid.
|
||||
func (p Protocol) validate() error {
|
||||
switch p {
|
||||
case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP,
|
||||
ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE,
|
||||
ProtocolESP, ProtocolAH, ProtocolSCTP:
|
||||
return nil
|
||||
case ProtocolWildcard:
|
||||
// Wildcard "*" is not allowed - Tailscale rejects it
|
||||
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
|
||||
default:
|
||||
// Try to parse as a numeric protocol number
|
||||
str := string(p)
|
||||
|
||||
// Check for leading zeros (not allowed by Tailscale)
|
||||
if str == "0" || (len(str) > 1 && str[0] == '0') {
|
||||
return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str)
|
||||
}
|
||||
|
||||
protocolNumber, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p)
|
||||
}
|
||||
|
||||
if protocolNumber < 0 || protocolNumber > 255 {
|
||||
return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements JSON marshaling for Protocol.
|
||||
func (p Protocol) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(p))
|
||||
}
|
||||
|
||||
// Protocol constants matching the IANA numbers
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
protocolIGMP = 2 // Internet Group Management
|
||||
protocolIPv4 = 4 // IPv4 encapsulation
|
||||
protocolTCP = 6 // Transmission Control
|
||||
protocolEGP = 8 // Exterior Gateway Protocol
|
||||
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||
protocolUDP = 17 // User Datagram
|
||||
protocolGRE = 47 // Generic Routing Encapsulation
|
||||
protocolESP = 50 // Encap Security Payload
|
||||
protocolAH = 51 // Authentication Header
|
||||
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||
protocolFC = 133 // Fibre Channel
|
||||
)
|
||||
|
||||
type ACL struct {
|
||||
Action string `json:"action"` // TODO(kradalby): add strict type
|
||||
Protocol string `json:"proto"` // TODO(kradalby): add strict type
|
||||
Action Action `json:"action"`
|
||||
Protocol Protocol `json:"proto"`
|
||||
Sources Aliases `json:"src"`
|
||||
Destinations []AliasWithPorts `json:"dst"`
|
||||
}
|
||||
|
@ -2464,13 +2464,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
"tags": ["web", "production"],
|
||||
"created_by": "admin"
|
||||
},
|
||||
"action": "deny",
|
||||
"action": "accept",
|
||||
"proto": "udp",
|
||||
"src": ["*"],
|
||||
"dst": ["autogroup:internet:53"]
|
||||
}`,
|
||||
expected: ACL{
|
||||
Action: "deny",
|
||||
Action: ActionAccept,
|
||||
Protocol: "udp",
|
||||
Sources: []Alias{Wildcard},
|
||||
Destinations: []AliasWithPorts{
|
||||
@ -2482,6 +2482,16 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid action should fail",
|
||||
input: `{
|
||||
"action": "deny",
|
||||
"proto": "tcp",
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no comment fields",
|
||||
input: `{
|
||||
@ -2491,13 +2501,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
"dst": ["tag:server:*"]
|
||||
}`,
|
||||
expected: ACL{
|
||||
Action: "accept",
|
||||
Action: ActionAccept,
|
||||
Protocol: "icmp",
|
||||
Sources: []Alias{mustParseAlias("tag:client")},
|
||||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: mustParseAlias("tag:server"),
|
||||
Ports: []tailcfg.PortRange{{First: 0, Last: 65535}},
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -2510,8 +2520,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
"#reason": "Temporary disable for maintenance"
|
||||
}`,
|
||||
expected: ACL{
|
||||
Action: "",
|
||||
Protocol: "",
|
||||
Action: Action(""),
|
||||
Protocol: Protocol(""),
|
||||
Sources: nil,
|
||||
Destinations: nil,
|
||||
},
|
||||
@ -2623,9 +2633,9 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
|
||||
"dst": ["tag:server:22,80,443"]
|
||||
},
|
||||
{
|
||||
"#note": "Deny all other traffic",
|
||||
"action": "deny",
|
||||
"proto": "*",
|
||||
"#note": "Allow all other traffic",
|
||||
"action": "accept",
|
||||
"proto": "tcp",
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
@ -2641,19 +2651,37 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
|
||||
|
||||
// First ACL
|
||||
acl1 := policy.ACLs[0]
|
||||
assert.Equal(t, "accept", acl1.Action)
|
||||
assert.Equal(t, "tcp", acl1.Protocol)
|
||||
assert.Equal(t, ActionAccept, acl1.Action)
|
||||
assert.Equal(t, Protocol("tcp"), acl1.Protocol)
|
||||
require.Len(t, acl1.Sources, 1)
|
||||
require.Len(t, acl1.Destinations, 1)
|
||||
|
||||
// Second ACL
|
||||
acl2 := policy.ACLs[1]
|
||||
assert.Equal(t, "deny", acl2.Action)
|
||||
assert.Equal(t, "*", acl2.Protocol)
|
||||
assert.Equal(t, ActionAccept, acl2.Action)
|
||||
assert.Equal(t, Protocol("tcp"), acl2.Protocol)
|
||||
require.Len(t, acl2.Sources, 1)
|
||||
require.Len(t, acl2.Destinations, 1)
|
||||
}
|
||||
|
||||
func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
|
||||
// Test that invalid actions are rejected
|
||||
policyJSON := `{
|
||||
"acls": [
|
||||
{
|
||||
"action": "deny",
|
||||
"proto": "tcp",
|
||||
"src": ["*"],
|
||||
"dst": ["*:*"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
_, err := unmarshalPolicy([]byte(policyJSON))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `invalid action "deny"`)
|
||||
}
|
||||
|
||||
// Helper function to parse aliases for testing
|
||||
func mustParseAlias(s string) Alias {
|
||||
alias, err := parseAlias(s)
|
||||
|
@ -2,7 +2,6 @@ package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) {
|
||||
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
protocolIGMP = 2 // Internet Group Management
|
||||
protocolIPv4 = 4 // IPv4 encapsulation
|
||||
protocolTCP = 6 // Transmission Control
|
||||
protocolEGP = 8 // Exterior Gateway Protocol
|
||||
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||
protocolUDP = 17 // User Datagram
|
||||
protocolGRE = 47 // Generic Routing Encapsulation
|
||||
protocolESP = 50 // Encap Security Payload
|
||||
protocolAH = 51 // Authentication Header
|
||||
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||
ProtocolFC = 133 // Fibre Channel
|
||||
)
|
||||
|
||||
// parseProtocol reads the proto field of the ACL and generates a list of
|
||||
// protocols that will be allowed, following the IANA IP protocol number
|
||||
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||
//
|
||||
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
|
||||
// as per Tailscale behaviour (see tailcfg.FilterRule).
|
||||
//
|
||||
// Also returns a boolean indicating if the protocol
|
||||
// requires all the destinations to use wildcard as port number (only TCP,
|
||||
// UDP and SCTP support specifying ports).
|
||||
func parseProtocol(protocol string) ([]int, bool, error) {
|
||||
switch protocol {
|
||||
case "":
|
||||
return nil, false, nil
|
||||
case "igmp":
|
||||
return []int{protocolIGMP}, true, nil
|
||||
case "ipv4", "ip-in-ip":
|
||||
return []int{protocolIPv4}, true, nil
|
||||
case "tcp":
|
||||
return []int{protocolTCP}, false, nil
|
||||
case "egp":
|
||||
return []int{protocolEGP}, true, nil
|
||||
case "igp":
|
||||
return []int{protocolIGP}, true, nil
|
||||
case "udp":
|
||||
return []int{protocolUDP}, false, nil
|
||||
case "gre":
|
||||
return []int{protocolGRE}, true, nil
|
||||
case "esp":
|
||||
return []int{protocolESP}, true, nil
|
||||
case "ah":
|
||||
return []int{protocolAH}, true, nil
|
||||
case "sctp":
|
||||
return []int{protocolSCTP}, false, nil
|
||||
case "icmp":
|
||||
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
|
||||
|
||||
default:
|
||||
protocolNumber, err := strconv.Atoi(protocol)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): What is this?
|
||||
needsWildcard := protocolNumber != protocolTCP &&
|
||||
protocolNumber != protocolUDP &&
|
||||
protocolNumber != protocolSCTP
|
||||
|
||||
return []int{protocolNumber}, needsWildcard, nil
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user