mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
policy: make ACL action and protocol strict types
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
462ef80f42
commit
64ad05f1c5
@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules(
|
|||||||
var rules []tailcfg.FilterRule
|
var rules []tailcfg.FilterRule
|
||||||
|
|
||||||
for _, acl := range pol.ACLs {
|
for _, acl := range pol.ACLs {
|
||||||
if acl.Action != "accept" {
|
if acl.Action != ActionAccept {
|
||||||
return nil, ErrInvalidAction
|
return nil, ErrInvalidAction
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): integrate type into schema
|
protocols, _ := acl.Protocol.parseProtocol()
|
||||||
// TODO(kradalby): figure out the _ is wildcard stuff
|
|
||||||
protocols, _, err := parseProtocol(acl.Protocol)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var destPorts []tailcfg.NetPortRange
|
var destPorts []tailcfg.NetPortRange
|
||||||
for _, dest := range acl.Destinations {
|
for _, dest := range acl.Destinations {
|
||||||
|
@ -1242,9 +1242,214 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||||||
return ret, exitNodeSet, nil
|
return ret, exitNodeSet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Action represents the action to take for an ACL rule.
|
||||||
|
type Action string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ActionAccept Action = "accept"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of the Action.
|
||||||
|
func (a Action) String() string {
|
||||||
|
return string(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements JSON unmarshaling for Action.
|
||||||
|
func (a *Action) UnmarshalJSON(b []byte) error {
|
||||||
|
str := strings.Trim(string(b), `"`)
|
||||||
|
switch str {
|
||||||
|
case "accept":
|
||||||
|
*a = ActionAccept
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements JSON marshaling for Action.
|
||||||
|
func (a Action) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(string(a))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol represents a network protocol with its IANA number and descriptions.
|
||||||
|
type Protocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolICMP Protocol = "icmp"
|
||||||
|
ProtocolIGMP Protocol = "igmp"
|
||||||
|
ProtocolIPv4 Protocol = "ipv4"
|
||||||
|
ProtocolIPInIP Protocol = "ip-in-ip"
|
||||||
|
ProtocolTCP Protocol = "tcp"
|
||||||
|
ProtocolEGP Protocol = "egp"
|
||||||
|
ProtocolIGP Protocol = "igp"
|
||||||
|
ProtocolUDP Protocol = "udp"
|
||||||
|
ProtocolGRE Protocol = "gre"
|
||||||
|
ProtocolESP Protocol = "esp"
|
||||||
|
ProtocolAH Protocol = "ah"
|
||||||
|
ProtocolIPv6ICMP Protocol = "ipv6-icmp"
|
||||||
|
ProtocolSCTP Protocol = "sctp"
|
||||||
|
ProtocolFC Protocol = "fc"
|
||||||
|
ProtocolWildcard Protocol = "*"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of the Protocol.
|
||||||
|
func (p Protocol) String() string {
|
||||||
|
return string(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description returns the human-readable description of the Protocol.
|
||||||
|
func (p Protocol) Description() string {
|
||||||
|
switch p {
|
||||||
|
case ProtocolICMP:
|
||||||
|
return "Internet Control Message Protocol"
|
||||||
|
case ProtocolIGMP:
|
||||||
|
return "Internet Group Management Protocol"
|
||||||
|
case ProtocolIPv4:
|
||||||
|
return "IPv4 encapsulation"
|
||||||
|
case ProtocolTCP:
|
||||||
|
return "Transmission Control Protocol"
|
||||||
|
case ProtocolEGP:
|
||||||
|
return "Exterior Gateway Protocol"
|
||||||
|
case ProtocolIGP:
|
||||||
|
return "Interior Gateway Protocol"
|
||||||
|
case ProtocolUDP:
|
||||||
|
return "User Datagram Protocol"
|
||||||
|
case ProtocolGRE:
|
||||||
|
return "Generic Routing Encapsulation"
|
||||||
|
case ProtocolESP:
|
||||||
|
return "Encapsulating Security Payload"
|
||||||
|
case ProtocolAH:
|
||||||
|
return "Authentication Header"
|
||||||
|
case ProtocolIPv6ICMP:
|
||||||
|
return "Internet Control Message Protocol for IPv6"
|
||||||
|
case ProtocolSCTP:
|
||||||
|
return "Stream Control Transmission Protocol"
|
||||||
|
case ProtocolFC:
|
||||||
|
return "Fibre Channel"
|
||||||
|
case ProtocolWildcard:
|
||||||
|
return "Wildcard (all protocols)"
|
||||||
|
default:
|
||||||
|
return "Unknown Protocol"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement.
|
||||||
|
// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values.
|
||||||
|
func (p Protocol) parseProtocol() ([]int, bool) {
|
||||||
|
switch p {
|
||||||
|
case "":
|
||||||
|
return nil, false
|
||||||
|
case ProtocolWildcard:
|
||||||
|
// Wildcard protocol - allows all protocols like empty string
|
||||||
|
return nil, false
|
||||||
|
case ProtocolIGMP:
|
||||||
|
return []int{protocolIGMP}, true
|
||||||
|
case ProtocolIPv4, ProtocolIPInIP:
|
||||||
|
return []int{protocolIPv4}, true
|
||||||
|
case ProtocolTCP:
|
||||||
|
return []int{protocolTCP}, false
|
||||||
|
case ProtocolEGP:
|
||||||
|
return []int{protocolEGP}, true
|
||||||
|
case ProtocolIGP:
|
||||||
|
return []int{protocolIGP}, true
|
||||||
|
case ProtocolUDP:
|
||||||
|
return []int{protocolUDP}, false
|
||||||
|
case ProtocolGRE:
|
||||||
|
return []int{protocolGRE}, true
|
||||||
|
case ProtocolESP:
|
||||||
|
return []int{protocolESP}, true
|
||||||
|
case ProtocolAH:
|
||||||
|
return []int{protocolAH}, true
|
||||||
|
case ProtocolSCTP:
|
||||||
|
return []int{protocolSCTP}, false
|
||||||
|
case ProtocolICMP:
|
||||||
|
return []int{protocolICMP, protocolIPv6ICMP}, true
|
||||||
|
default:
|
||||||
|
// Try to parse as a numeric protocol number
|
||||||
|
// This should not fail since validation happened during unmarshaling
|
||||||
|
protocolNumber, _ := strconv.Atoi(string(p))
|
||||||
|
|
||||||
|
// Determine if wildcard is needed based on protocol number
|
||||||
|
needsWildcard := protocolNumber != protocolTCP &&
|
||||||
|
protocolNumber != protocolUDP &&
|
||||||
|
protocolNumber != protocolSCTP
|
||||||
|
|
||||||
|
return []int{protocolNumber}, needsWildcard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements JSON unmarshaling for Protocol.
|
||||||
|
func (p *Protocol) UnmarshalJSON(b []byte) error {
|
||||||
|
str := strings.Trim(string(b), `"`)
|
||||||
|
|
||||||
|
// Normalize to lowercase for case-insensitive matching
|
||||||
|
*p = Protocol(strings.ToLower(str))
|
||||||
|
|
||||||
|
// Validate the protocol
|
||||||
|
if err := p.validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate checks if the Protocol is valid.
|
||||||
|
func (p Protocol) validate() error {
|
||||||
|
switch p {
|
||||||
|
case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP,
|
||||||
|
ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE,
|
||||||
|
ProtocolESP, ProtocolAH, ProtocolSCTP:
|
||||||
|
return nil
|
||||||
|
case ProtocolWildcard:
|
||||||
|
// Wildcard "*" is not allowed - Tailscale rejects it
|
||||||
|
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
|
||||||
|
default:
|
||||||
|
// Try to parse as a numeric protocol number
|
||||||
|
str := string(p)
|
||||||
|
|
||||||
|
// Check for leading zeros (not allowed by Tailscale)
|
||||||
|
if str == "0" || (len(str) > 1 && str[0] == '0') {
|
||||||
|
return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str)
|
||||||
|
}
|
||||||
|
|
||||||
|
protocolNumber, err := strconv.Atoi(str)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocolNumber < 0 || protocolNumber > 255 {
|
||||||
|
return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements JSON marshaling for Protocol.
|
||||||
|
func (p Protocol) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(string(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol constants matching the IANA numbers
|
||||||
|
const (
|
||||||
|
protocolICMP = 1 // Internet Control Message
|
||||||
|
protocolIGMP = 2 // Internet Group Management
|
||||||
|
protocolIPv4 = 4 // IPv4 encapsulation
|
||||||
|
protocolTCP = 6 // Transmission Control
|
||||||
|
protocolEGP = 8 // Exterior Gateway Protocol
|
||||||
|
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||||
|
protocolUDP = 17 // User Datagram
|
||||||
|
protocolGRE = 47 // Generic Routing Encapsulation
|
||||||
|
protocolESP = 50 // Encap Security Payload
|
||||||
|
protocolAH = 51 // Authentication Header
|
||||||
|
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||||
|
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||||
|
protocolFC = 133 // Fibre Channel
|
||||||
|
)
|
||||||
|
|
||||||
type ACL struct {
|
type ACL struct {
|
||||||
Action string `json:"action"` // TODO(kradalby): add strict type
|
Action Action `json:"action"`
|
||||||
Protocol string `json:"proto"` // TODO(kradalby): add strict type
|
Protocol Protocol `json:"proto"`
|
||||||
Sources Aliases `json:"src"`
|
Sources Aliases `json:"src"`
|
||||||
Destinations []AliasWithPorts `json:"dst"`
|
Destinations []AliasWithPorts `json:"dst"`
|
||||||
}
|
}
|
||||||
|
@ -2464,13 +2464,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
|||||||
"tags": ["web", "production"],
|
"tags": ["web", "production"],
|
||||||
"created_by": "admin"
|
"created_by": "admin"
|
||||||
},
|
},
|
||||||
"action": "deny",
|
"action": "accept",
|
||||||
"proto": "udp",
|
"proto": "udp",
|
||||||
"src": ["*"],
|
"src": ["*"],
|
||||||
"dst": ["autogroup:internet:53"]
|
"dst": ["autogroup:internet:53"]
|
||||||
}`,
|
}`,
|
||||||
expected: ACL{
|
expected: ACL{
|
||||||
Action: "deny",
|
Action: ActionAccept,
|
||||||
Protocol: "udp",
|
Protocol: "udp",
|
||||||
Sources: []Alias{Wildcard},
|
Sources: []Alias{Wildcard},
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
@ -2482,6 +2482,16 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "invalid action should fail",
|
||||||
|
input: `{
|
||||||
|
"action": "deny",
|
||||||
|
"proto": "tcp",
|
||||||
|
"src": ["*"],
|
||||||
|
"dst": ["*:*"]
|
||||||
|
}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no comment fields",
|
name: "no comment fields",
|
||||||
input: `{
|
input: `{
|
||||||
@ -2491,13 +2501,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
|||||||
"dst": ["tag:server:*"]
|
"dst": ["tag:server:*"]
|
||||||
}`,
|
}`,
|
||||||
expected: ACL{
|
expected: ACL{
|
||||||
Action: "accept",
|
Action: ActionAccept,
|
||||||
Protocol: "icmp",
|
Protocol: "icmp",
|
||||||
Sources: []Alias{mustParseAlias("tag:client")},
|
Sources: []Alias{mustParseAlias("tag:client")},
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: mustParseAlias("tag:server"),
|
Alias: mustParseAlias("tag:server"),
|
||||||
Ports: []tailcfg.PortRange{{First: 0, Last: 65535}},
|
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -2510,8 +2520,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
|||||||
"#reason": "Temporary disable for maintenance"
|
"#reason": "Temporary disable for maintenance"
|
||||||
}`,
|
}`,
|
||||||
expected: ACL{
|
expected: ACL{
|
||||||
Action: "",
|
Action: Action(""),
|
||||||
Protocol: "",
|
Protocol: Protocol(""),
|
||||||
Sources: nil,
|
Sources: nil,
|
||||||
Destinations: nil,
|
Destinations: nil,
|
||||||
},
|
},
|
||||||
@ -2623,9 +2633,9 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
|
|||||||
"dst": ["tag:server:22,80,443"]
|
"dst": ["tag:server:22,80,443"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"#note": "Deny all other traffic",
|
"#note": "Allow all other traffic",
|
||||||
"action": "deny",
|
"action": "accept",
|
||||||
"proto": "*",
|
"proto": "tcp",
|
||||||
"src": ["*"],
|
"src": ["*"],
|
||||||
"dst": ["*:*"]
|
"dst": ["*:*"]
|
||||||
}
|
}
|
||||||
@ -2641,19 +2651,37 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
|
|||||||
|
|
||||||
// First ACL
|
// First ACL
|
||||||
acl1 := policy.ACLs[0]
|
acl1 := policy.ACLs[0]
|
||||||
assert.Equal(t, "accept", acl1.Action)
|
assert.Equal(t, ActionAccept, acl1.Action)
|
||||||
assert.Equal(t, "tcp", acl1.Protocol)
|
assert.Equal(t, Protocol("tcp"), acl1.Protocol)
|
||||||
require.Len(t, acl1.Sources, 1)
|
require.Len(t, acl1.Sources, 1)
|
||||||
require.Len(t, acl1.Destinations, 1)
|
require.Len(t, acl1.Destinations, 1)
|
||||||
|
|
||||||
// Second ACL
|
// Second ACL
|
||||||
acl2 := policy.ACLs[1]
|
acl2 := policy.ACLs[1]
|
||||||
assert.Equal(t, "deny", acl2.Action)
|
assert.Equal(t, ActionAccept, acl2.Action)
|
||||||
assert.Equal(t, "*", acl2.Protocol)
|
assert.Equal(t, Protocol("tcp"), acl2.Protocol)
|
||||||
require.Len(t, acl2.Sources, 1)
|
require.Len(t, acl2.Sources, 1)
|
||||||
require.Len(t, acl2.Destinations, 1)
|
require.Len(t, acl2.Destinations, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
|
||||||
|
// Test that invalid actions are rejected
|
||||||
|
policyJSON := `{
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "deny",
|
||||||
|
"proto": "tcp",
|
||||||
|
"src": ["*"],
|
||||||
|
"dst": ["*:*"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
_, err := unmarshalPolicy([]byte(policyJSON))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `invalid action "deny"`)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to parse aliases for testing
|
// Helper function to parse aliases for testing
|
||||||
func mustParseAlias(s string) Alias {
|
func mustParseAlias(s string) Alias {
|
||||||
alias, err := parseAlias(s)
|
alias, err := parseAlias(s)
|
||||||
|
@ -2,7 +2,6 @@ package v2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) {
|
|||||||
|
|
||||||
return uint16(port), nil
|
return uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
|
||||||
const (
|
|
||||||
protocolICMP = 1 // Internet Control Message
|
|
||||||
protocolIGMP = 2 // Internet Group Management
|
|
||||||
protocolIPv4 = 4 // IPv4 encapsulation
|
|
||||||
protocolTCP = 6 // Transmission Control
|
|
||||||
protocolEGP = 8 // Exterior Gateway Protocol
|
|
||||||
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
|
||||||
protocolUDP = 17 // User Datagram
|
|
||||||
protocolGRE = 47 // Generic Routing Encapsulation
|
|
||||||
protocolESP = 50 // Encap Security Payload
|
|
||||||
protocolAH = 51 // Authentication Header
|
|
||||||
protocolIPv6ICMP = 58 // ICMP for IPv6
|
|
||||||
protocolSCTP = 132 // Stream Control Transmission Protocol
|
|
||||||
ProtocolFC = 133 // Fibre Channel
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseProtocol reads the proto field of the ACL and generates a list of
|
|
||||||
// protocols that will be allowed, following the IANA IP protocol number
|
|
||||||
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
|
||||||
//
|
|
||||||
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
|
|
||||||
// as per Tailscale behaviour (see tailcfg.FilterRule).
|
|
||||||
//
|
|
||||||
// Also returns a boolean indicating if the protocol
|
|
||||||
// requires all the destinations to use wildcard as port number (only TCP,
|
|
||||||
// UDP and SCTP support specifying ports).
|
|
||||||
func parseProtocol(protocol string) ([]int, bool, error) {
|
|
||||||
switch protocol {
|
|
||||||
case "":
|
|
||||||
return nil, false, nil
|
|
||||||
case "igmp":
|
|
||||||
return []int{protocolIGMP}, true, nil
|
|
||||||
case "ipv4", "ip-in-ip":
|
|
||||||
return []int{protocolIPv4}, true, nil
|
|
||||||
case "tcp":
|
|
||||||
return []int{protocolTCP}, false, nil
|
|
||||||
case "egp":
|
|
||||||
return []int{protocolEGP}, true, nil
|
|
||||||
case "igp":
|
|
||||||
return []int{protocolIGP}, true, nil
|
|
||||||
case "udp":
|
|
||||||
return []int{protocolUDP}, false, nil
|
|
||||||
case "gre":
|
|
||||||
return []int{protocolGRE}, true, nil
|
|
||||||
case "esp":
|
|
||||||
return []int{protocolESP}, true, nil
|
|
||||||
case "ah":
|
|
||||||
return []int{protocolAH}, true, nil
|
|
||||||
case "sctp":
|
|
||||||
return []int{protocolSCTP}, false, nil
|
|
||||||
case "icmp":
|
|
||||||
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
protocolNumber, err := strconv.Atoi(protocol)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(kradalby): What is this?
|
|
||||||
needsWildcard := protocolNumber != protocolTCP &&
|
|
||||||
protocolNumber != protocolUDP &&
|
|
||||||
protocolNumber != protocolSCTP
|
|
||||||
|
|
||||||
return []int{protocolNumber}, needsWildcard, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user