1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-16 17:50:44 +02:00

policy: reject unsupported fields (#2764)

This commit is contained in:
Kristoffer Dalby 2025-09-12 14:47:56 +02:00 committed by GitHub
parent 1b1c989268
commit 2938d03878
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1177 additions and 133 deletions

View File

@ -89,6 +89,8 @@ upstream is changed.
[#2663](https://github.com/juanfont/headscale/pull/2663)
- OIDC: Update user with claims from UserInfo _before_ comparing with allowed
groups, email and domain [#2663](https://github.com/juanfont/headscale/pull/2663)
- Policy will now reject invalid fields, making it easier to spot spelling errors
[#2764](https://github.com/juanfont/headscale/pull/2764)
## 0.26.1 (2025-06-06)

2
go.mod
View File

@ -18,6 +18,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/glebarez/sqlite v1.11.0
github.com/go-gormigrate/gormigrate/v2 v2.1.4
github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874
github.com/gofrs/uuid/v5 v5.3.2
github.com/google/go-cmp v0.7.0
github.com/gorilla/mux v1.8.1
@ -131,7 +132,6 @@ require (
github.com/glebarez/go-sqlite v1.22.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
github.com/go-jose/go-jose/v4 v4.1.0 // indirect
github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect

View File

@ -222,6 +222,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{
@ -236,6 +237,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
@ -371,10 +373,12 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDestForTest,
IPProto: []int{6, 17},
},
},
},
@ -478,6 +482,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
@ -513,6 +518,7 @@ func TestReduceFilterRules(t *testing.T) {
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{6, 17},
},
},
},
@ -588,6 +594,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
@ -601,6 +608,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
@ -676,6 +684,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
@ -689,6 +698,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
@ -756,6 +766,7 @@ func TestReduceFilterRules(t *testing.T) {
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
@ -1736,7 +1747,7 @@ func TestSSHPolicyRules(t *testing.T) {
]
}`,
expectErr: true,
errorMessage: `SSH action "invalid" is not valid, must be accept or check`,
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
},
{
name: "invalid-check-period",

View File

@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules(
var rules []tailcfg.FilterRule
for _, acl := range pol.ACLs {
if acl.Action != "accept" {
if acl.Action != ActionAccept {
return nil, ErrInvalidAction
}
@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules(
continue
}
// TODO(kradalby): integrate type into schema
// TODO(kradalby): figure out the _ is wildcard stuff
protocols, _, err := parseProtocol(acl.Protocol)
if err != nil {
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
}
protocols, _ := acl.Protocol.parseProtocol()
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
@ -132,9 +127,9 @@ func (pol *Policy) compileSSHPolicy(
var action tailcfg.SSHAction
switch rule.Action {
case "accept":
case SSHActionAccept:
action = sshAction(true, 0)
case "check":
case SSHActionCheck:
action = sshAction(true, time.Duration(rule.CheckPeriod))
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)

View File

@ -92,6 +92,7 @@ func TestParsing(t *testing.T) {
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,
@ -193,6 +194,7 @@ func TestParsing(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,
@ -229,6 +231,7 @@ func TestParsing(t *testing.T) {
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,
@ -268,6 +271,7 @@ func TestParsing(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,
@ -301,6 +305,7 @@ func TestParsing(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,
@ -334,6 +339,7 @@ func TestParsing(t *testing.T) {
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{protocolTCP, protocolUDP},
},
},
wantErr: false,

View File

@ -1,8 +1,6 @@
package v2
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/netip"
@ -10,6 +8,8 @@ import (
"strconv"
"strings"
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
@ -23,6 +23,13 @@ import (
"tailscale.com/util/slicesx"
)
// Global JSON options for consistent parsing across all struct unmarshaling
var policyJSONOpts = []json.Options{
json.DefaultOptionsV2(),
json.MatchCaseInsensitiveNames(true),
json.RejectUnknownMembers(true),
}
const Wildcard = Asterix(0)
type Asterix int
@ -614,10 +621,8 @@ type AliasWithPorts struct {
}
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
if err := json.Unmarshal(b, &v); err != nil {
return err
}
@ -735,7 +740,7 @@ type Aliases []Alias
func (a *Aliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
}
@ -825,7 +830,7 @@ type AutoApprovers []AutoApprover
func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
var autoApprovers []AutoApproverEnc
err := json.Unmarshal(b, &autoApprovers)
err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...)
if err != nil {
return err
}
@ -920,7 +925,7 @@ type Owners []Owner
func (o *Owners) UnmarshalJSON(b []byte) error {
var owners []OwnerEnc
err := json.Unmarshal(b, &owners)
err := json.Unmarshal(b, &owners, policyJSONOpts...)
if err != nil {
return err
}
@ -994,18 +999,46 @@ func (g Groups) Contains(group *Group) error {
// that all group names conform to the expected format, which is always prefixed
// with "group:". If any group name is invalid, an error is returned.
func (g *Groups) UnmarshalJSON(b []byte) error {
var rawGroups map[string][]string
if err := json.Unmarshal(b, &rawGroups); err != nil {
// First unmarshal as a generic map to validate group names first
var rawMap map[string]interface{}
if err := json.Unmarshal(b, &rawMap); err != nil {
return err
}
// Validate group names first before checking data types
for key := range rawMap {
group := Group(key)
if err := group.Validate(); err != nil {
return err
}
}
// Then validate each field can be converted to []string
rawGroups := make(map[string][]string)
for key, value := range rawMap {
switch v := value.(type) {
case []interface{}:
// Convert []interface{} to []string
var stringSlice []string
for _, item := range v {
if str, ok := item.(string); ok {
stringSlice = append(stringSlice, str)
} else {
return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item)
}
}
rawGroups[key] = stringSlice
case string:
return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v)
default:
return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v)
}
}
*g = make(Groups)
for key, value := range rawGroups {
group := Group(key)
if err := group.Validate(); err != nil {
return err
}
// Group name already validated above
var usernames Usernames
for _, u := range value {
@ -1031,7 +1064,7 @@ type Hosts map[Host]Prefix
func (h *Hosts) UnmarshalJSON(b []byte) error {
var rawHosts map[string]string
if err := json.Unmarshal(b, &rawHosts); err != nil {
if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil {
return err
}
@ -1242,13 +1275,290 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
return ret, exitNodeSet, nil
}
// Action represents the action to take for an ACL rule.
type Action string
const (
ActionAccept Action = "accept"
)
// SSHAction represents the action to take for an SSH rule.
type SSHAction string
const (
SSHActionAccept SSHAction = "accept"
SSHActionCheck SSHAction = "check"
)
// String returns the string representation of the Action.
func (a Action) String() string {
return string(a)
}
// UnmarshalJSON implements JSON unmarshaling for Action.
func (a *Action) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
switch str {
case "accept":
*a = ActionAccept
default:
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
}
return nil
}
// MarshalJSON implements JSON marshaling for Action.
func (a Action) MarshalJSON() ([]byte, error) {
return json.Marshal(string(a))
}
// String returns the string representation of the SSHAction.
func (a SSHAction) String() string {
return string(a)
}
// UnmarshalJSON implements JSON unmarshaling for SSHAction.
func (a *SSHAction) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
switch str {
case "accept":
*a = SSHActionAccept
case "check":
*a = SSHActionCheck
default:
return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str)
}
return nil
}
// MarshalJSON implements JSON marshaling for SSHAction.
func (a SSHAction) MarshalJSON() ([]byte, error) {
return json.Marshal(string(a))
}
// Protocol represents a network protocol with its IANA number and descriptions.
type Protocol string
const (
ProtocolICMP Protocol = "icmp"
ProtocolIGMP Protocol = "igmp"
ProtocolIPv4 Protocol = "ipv4"
ProtocolIPInIP Protocol = "ip-in-ip"
ProtocolTCP Protocol = "tcp"
ProtocolEGP Protocol = "egp"
ProtocolIGP Protocol = "igp"
ProtocolUDP Protocol = "udp"
ProtocolGRE Protocol = "gre"
ProtocolESP Protocol = "esp"
ProtocolAH Protocol = "ah"
ProtocolIPv6ICMP Protocol = "ipv6-icmp"
ProtocolSCTP Protocol = "sctp"
ProtocolFC Protocol = "fc"
ProtocolWildcard Protocol = "*"
)
// String returns the string representation of the Protocol.
func (p Protocol) String() string {
return string(p)
}
// Description returns the human-readable description of the Protocol.
func (p Protocol) Description() string {
switch p {
case ProtocolICMP:
return "Internet Control Message Protocol"
case ProtocolIGMP:
return "Internet Group Management Protocol"
case ProtocolIPv4:
return "IPv4 encapsulation"
case ProtocolTCP:
return "Transmission Control Protocol"
case ProtocolEGP:
return "Exterior Gateway Protocol"
case ProtocolIGP:
return "Interior Gateway Protocol"
case ProtocolUDP:
return "User Datagram Protocol"
case ProtocolGRE:
return "Generic Routing Encapsulation"
case ProtocolESP:
return "Encapsulating Security Payload"
case ProtocolAH:
return "Authentication Header"
case ProtocolIPv6ICMP:
return "Internet Control Message Protocol for IPv6"
case ProtocolSCTP:
return "Stream Control Transmission Protocol"
case ProtocolFC:
return "Fibre Channel"
case ProtocolWildcard:
return "Wildcard (not supported - use specific protocol)"
default:
return "Unknown Protocol"
}
}
// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement.
// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values.
func (p Protocol) parseProtocol() ([]int, bool) {
switch p {
case "":
// Empty protocol applies to TCP and UDP traffic only
return []int{protocolTCP, protocolUDP}, false
case ProtocolWildcard:
// Wildcard protocol - defensive handling (should not reach here due to validation)
return nil, false
case ProtocolIGMP:
return []int{protocolIGMP}, true
case ProtocolIPv4, ProtocolIPInIP:
return []int{protocolIPv4}, true
case ProtocolTCP:
return []int{protocolTCP}, false
case ProtocolEGP:
return []int{protocolEGP}, true
case ProtocolIGP:
return []int{protocolIGP}, true
case ProtocolUDP:
return []int{protocolUDP}, false
case ProtocolGRE:
return []int{protocolGRE}, true
case ProtocolESP:
return []int{protocolESP}, true
case ProtocolAH:
return []int{protocolAH}, true
case ProtocolSCTP:
return []int{protocolSCTP}, false
case ProtocolICMP:
return []int{protocolICMP, protocolIPv6ICMP}, true
default:
// Try to parse as a numeric protocol number
// This should not fail since validation happened during unmarshaling
protocolNumber, _ := strconv.Atoi(string(p))
// Determine if wildcard is needed based on protocol number
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard
}
}
// UnmarshalJSON implements JSON unmarshaling for Protocol.
func (p *Protocol) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
// Normalize to lowercase for case-insensitive matching
*p = Protocol(strings.ToLower(str))
// Validate the protocol
if err := p.validate(); err != nil {
return err
}
return nil
}
// validate checks if the Protocol is valid.
func (p Protocol) validate() error {
switch p {
case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP,
ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE,
ProtocolESP, ProtocolAH, ProtocolSCTP:
return nil
case ProtocolWildcard:
// Wildcard "*" is not allowed - Tailscale rejects it
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
default:
// Try to parse as a numeric protocol number
str := string(p)
// Check for leading zeros (not allowed by Tailscale)
if str == "0" || (len(str) > 1 && str[0] == '0') {
return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str)
}
protocolNumber, err := strconv.Atoi(str)
if err != nil {
return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p)
}
if protocolNumber < 0 || protocolNumber > 255 {
return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber)
}
return nil
}
}
// MarshalJSON implements JSON marshaling for Protocol.
func (p Protocol) MarshalJSON() ([]byte, error) {
return json.Marshal(string(p))
}
// Protocol constants matching the IANA numbers
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
protocolFC = 133 // Fibre Channel
)
type ACL struct {
Action string `json:"action"` // TODO(kradalby): add strict type
Protocol string `json:"proto"` // TODO(kradalby): add strict type
Action Action `json:"action"`
Protocol Protocol `json:"proto"`
Sources Aliases `json:"src"`
Destinations []AliasWithPorts `json:"dst"`
}
// UnmarshalJSON implements custom unmarshalling for ACL that ignores fields starting with '#'.
// headscale-admin uses # in some field names to add metadata, so we will ignore
// those to ensure it doesnt break.
// https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38
func (a *ACL) UnmarshalJSON(b []byte) error {
// First unmarshal into a map to filter out comment fields
var raw map[string]any
if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil {
return err
}
// Remove any fields that start with '#'
filtered := make(map[string]any)
for key, value := range raw {
if !strings.HasPrefix(key, "#") {
filtered[key] = value
}
}
// Marshal the filtered map back to JSON
filteredBytes, err := json.Marshal(filtered)
if err != nil {
return err
}
// Create a type alias to avoid infinite recursion
type aclAlias ACL
var temp aclAlias
// Unmarshal into the temporary struct using the v2 JSON options
if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil {
return err
}
// Copy the result back to the original struct
*a = ACL(temp)
return nil
}
// Policy represents a Tailscale Network Policy.
// TODO(kradalby):
// Add validation method checking:
@ -1266,7 +1576,7 @@ type Policy struct {
Hosts Hosts `json:"hosts,omitempty"`
TagOwners TagOwners `json:"tagOwners,omitempty"`
ACLs []ACL `json:"acls,omitempty"`
AutoApprovers AutoApproverPolicy `json:"autoApprovers,omitempty"`
AutoApprovers AutoApproverPolicy `json:"autoApprovers"`
SSHs []SSH `json:"ssh,omitempty"`
}
@ -1444,13 +1754,14 @@ func (p *Policy) validate() error {
}
}
}
// Validate protocol-port compatibility
if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil {
errs = append(errs, err)
}
}
for _, ssh := range p.SSHs {
if ssh.Action != "accept" && ssh.Action != "check" {
errs = append(errs, fmt.Errorf("SSH action %q is not valid, must be accept or check", ssh.Action))
}
for _, user := range ssh.Users {
if strings.HasPrefix(string(user), "autogroup:") {
maybeAuto := AutoGroup(user)
@ -1564,7 +1875,7 @@ func (p *Policy) validate() error {
// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"`
Action SSHAction `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users SSHUsers `json:"users"`
@ -1595,7 +1906,7 @@ func (g Groups) MarshalJSON() ([]byte, error) {
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
}
@ -1618,7 +1929,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
}
@ -1762,9 +2073,13 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
}
ast.Standardize()
acl := ast.Pack()
if err = json.Unmarshal(acl, &policy); err != nil {
if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil {
var serr *json.SemanticError
if errors.As(err, &serr) && serr.Err == json.ErrUnknownName {
ptr := serr.JSONPointer
name := ptr.LastToken()
return nil, fmt.Errorf("unknown field %q", name)
}
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
}
@ -1775,6 +2090,25 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
return &policy, nil
}
const (
expectedTokenItems = 2
)
// validateProtocolPortCompatibility checks that only TCP, UDP, and SCTP protocols
// can have specific ports. All other protocols should only use wildcard ports.
func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWithPorts) error {
// Only TCP, UDP, and SCTP support specific ports
supportsSpecificPorts := protocol == ProtocolTCP || protocol == ProtocolUDP || protocol == ProtocolSCTP || protocol == ""
if supportsSpecificPorts {
return nil // No validation needed for these protocols
}
// For all other protocols, check that all destinations use wildcard ports
for _, dst := range destinations {
for _, portRange := range dst.Ports {
// Check if it's not a wildcard port (0-65535)
if !(portRange.First == 0 && portRange.Last == 65535) {
return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol)
}
}
}
return nil
}

View File

@ -352,20 +352,6 @@ func TestUnmarshalPolicy(t *testing.T) {
name: "2652-asterix-error-better-explain",
input: `
{
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"*:*"
],
"proto": [
"*:*"
]
}
],
"ssh": [
{
"action": "accept",
@ -375,9 +361,7 @@ func TestUnmarshalPolicy(t *testing.T) {
"dst": [
"*"
],
"proto": [
"*:*"
]
"users": ["root"]
}
]
}
@ -992,6 +976,500 @@ func TestUnmarshalPolicy(t *testing.T) {
`,
wantErr: `first port must be >0, or use '*' for wildcard`,
},
{
name: "disallow-unsupported-fields",
input: `
{
// rules doesnt exists, we have "acls"
"rules": [
]
}
`,
wantErr: `unknown field "rules"`,
},
{
name: "disallow-unsupported-fields-nested",
input: `
{
"acls": [
{ "action": "accept", "BAD": ["FOO:BAR:FOO:BAR"], "NOT": ["BAD:BAD:BAD:BAD"] }
]
}
`,
wantErr: `unknown field`,
},
{
name: "invalid-group-name",
input: `
{
"groups": {
"group:test": ["user@example.com"],
"INVALID_GROUP_FIELD": ["user@example.com"]
}
}
`,
wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`,
},
{
name: "invalid-group-datatype",
input: `
{
"groups": {
"group:test": ["user@example.com"],
"group:invalid": "should fail"
}
}
`,
wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`,
},
{
name: "invalid-group-name-and-datatype-fails-on-name-first",
input: `
{
"groups": {
"group:test": ["user@example.com"],
"INVALID_GROUP_FIELD": "should fail"
}
}
`,
wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`,
},
{
name: "disallow-unsupported-fields-hosts-level",
input: `
{
"hosts": {
"host1": "10.0.0.1",
"INVALID_HOST_FIELD": "should fail"
}
}
`,
wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`,
},
{
name: "disallow-unsupported-fields-tagowners-level",
input: `
{
"tagOwners": {
"tag:test": ["user@example.com"],
"INVALID_TAG_FIELD": "should fail"
}
}
`,
wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`,
},
{
name: "disallow-unsupported-fields-acls-level",
input: `
{
"acls": [
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"],
"INVALID_ACL_FIELD": "should fail"
}
]
}
`,
wantErr: `unknown field "INVALID_ACL_FIELD"`,
},
{
name: "disallow-unsupported-fields-ssh-level",
input: `
{
"ssh": [
{
"action": "accept",
"src": ["user@example.com"],
"dst": ["user@example.com"],
"users": ["root"],
"INVALID_SSH_FIELD": "should fail"
}
]
}
`,
wantErr: `unknown field "INVALID_SSH_FIELD"`,
},
{
name: "disallow-unsupported-fields-policy-level",
input: `
{
"acls": [
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}
],
"INVALID_POLICY_FIELD": "should fail at policy level"
}
`,
wantErr: `unknown field "INVALID_POLICY_FIELD"`,
},
{
name: "disallow-unsupported-fields-autoapprovers-level",
input: `
{
"autoApprovers": {
"routes": {
"10.0.0.0/8": ["user@example.com"]
},
"exitNode": ["user@example.com"],
"INVALID_AUTO_APPROVER_FIELD": "should fail"
}
}
`,
wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`,
},
// headscale-admin uses # in some field names to add metadata, so we will ignore
// those to ensure it doesnt break.
// https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38
{
name: "hash-fields-are-allowed-but-ignored",
input: `
{
"acls": [
{
"#ha-test": "SOME VALUE",
"action": "accept",
"src": [
"10.0.0.1"
],
"dst": [
"autogroup:internet:*"
]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Sources: Aliases{
pp("10.0.0.1/32"),
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(AutoGroup("autogroup:internet")),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
{
name: "ssh-asterix-invalid-acl-input",
input: `
{
"ssh": [
{
"action": "accept",
"src": [
"user@example.com"
],
"dst": [
"user@example.com"
],
"users": ["root"],
"proto": "tcp"
}
]
}
`,
wantErr: `unknown field "proto"`,
},
{
name: "protocol-wildcard-not-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "*",
"src": ["*"],
"dst": ["*:*"]
}
]
}
`,
wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`,
},
{
name: "protocol-case-insensitive-uppercase",
input: `
{
"acls": [
{
"action": "accept",
"proto": "ICMP",
"src": ["*"],
"dst": ["*:*"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "icmp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
{
name: "protocol-case-insensitive-mixed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "IcmP",
"src": ["*"],
"dst": ["*:*"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "icmp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
{
name: "protocol-leading-zero-not-permitted",
input: `
{
"acls": [
{
"action": "accept",
"proto": "0",
"src": ["*"],
"dst": ["*:*"]
}
]
}
`,
wantErr: `leading 0 not permitted in protocol number "0"`,
},
{
name: "protocol-empty-applies-to-tcp-udp-only",
input: `
{
"acls": [
{
"action": "accept",
"src": ["*"],
"dst": ["*:80"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
},
},
},
{
name: "protocol-icmp-with-specific-port-not-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "icmp",
"src": ["*"],
"dst": ["*:80"]
}
]
}
`,
wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`,
},
{
name: "protocol-icmp-with-wildcard-port-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "icmp",
"src": ["*"],
"dst": ["*:*"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "icmp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
{
name: "protocol-gre-with-specific-port-not-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "gre",
"src": ["*"],
"dst": ["*:443"]
}
]
}
`,
wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`,
},
{
name: "protocol-tcp-with-specific-port-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:80"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
},
},
},
{
name: "protocol-udp-with-specific-port-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "udp",
"src": ["*"],
"dst": ["*:53"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "udp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{{First: 53, Last: 53}},
},
},
},
},
},
},
{
name: "protocol-sctp-with-specific-port-allowed",
input: `
{
"acls": [
{
"action": "accept",
"proto": "sctp",
"src": ["*"],
"dst": ["*:9000"]
}
]
}
`,
want: &Policy{
ACLs: []ACL{
{
Action: "accept",
Protocol: "sctp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: Wildcard,
Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}},
},
},
},
},
},
},
}
cmps := append(util.Comparers,
@ -2091,3 +2569,291 @@ func TestNodeCanHaveTag(t *testing.T) {
})
}
}
func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
tests := []struct {
name string
input string
expected ACL
wantErr bool
}{
{
name: "basic ACL with comment fields",
input: `{
"#comment": "This is a comment",
"action": "accept",
"proto": "tcp",
"src": ["user1@example.com"],
"dst": ["tag:server:80"]
}`,
expected: ACL{
Action: "accept",
Protocol: "tcp",
Sources: []Alias{mustParseAlias("user1@example.com")},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("tag:server"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
wantErr: false,
},
{
name: "multiple comment fields",
input: `{
"#description": "Allow access to web servers",
"#note": "Created by admin",
"#created_date": "2024-01-15",
"action": "accept",
"proto": "tcp",
"src": ["group:developers"],
"dst": ["10.0.0.0/24:443"]
}`,
expected: ACL{
Action: "accept",
Protocol: "tcp",
Sources: []Alias{mustParseAlias("group:developers")},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("10.0.0.0/24"),
Ports: []tailcfg.PortRange{{First: 443, Last: 443}},
},
},
},
wantErr: false,
},
{
name: "comment field with complex object value",
input: `{
"#metadata": {
"description": "Complex comment object",
"tags": ["web", "production"],
"created_by": "admin"
},
"action": "accept",
"proto": "udp",
"src": ["*"],
"dst": ["autogroup:internet:53"]
}`,
expected: ACL{
Action: ActionAccept,
Protocol: "udp",
Sources: []Alias{Wildcard},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("autogroup:internet"),
Ports: []tailcfg.PortRange{{First: 53, Last: 53}},
},
},
},
wantErr: false,
},
{
name: "invalid action should fail",
input: `{
"action": "deny",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}`,
wantErr: true,
},
{
name: "no comment fields",
input: `{
"action": "accept",
"proto": "icmp",
"src": ["tag:client"],
"dst": ["tag:server:*"]
}`,
expected: ACL{
Action: ActionAccept,
Protocol: "icmp",
Sources: []Alias{mustParseAlias("tag:client")},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("tag:server"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
wantErr: false,
},
{
name: "only comment fields",
input: `{
"#comment": "This rule is disabled",
"#reason": "Temporary disable for maintenance"
}`,
expected: ACL{
Action: Action(""),
Protocol: Protocol(""),
Sources: nil,
Destinations: nil,
},
wantErr: false,
},
{
name: "invalid JSON",
input: `{
"#comment": "This is a comment",
"action": "accept",
"proto": "tcp"
"src": ["invalid json"]
}`,
wantErr: true,
},
{
name: "invalid field after comment filtering",
input: `{
"#comment": "This is a comment",
"action": "accept",
"proto": "tcp",
"src": ["user1@example.com"],
"dst": ["invalid-destination"]
}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var acl ACL
err := json.Unmarshal([]byte(tt.input), &acl)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected.Action, acl.Action)
assert.Equal(t, tt.expected.Protocol, acl.Protocol)
assert.Equal(t, len(tt.expected.Sources), len(acl.Sources))
assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations))
// Compare sources
for i, expectedSrc := range tt.expected.Sources {
if i < len(acl.Sources) {
assert.Equal(t, expectedSrc, acl.Sources[i])
}
}
// Compare destinations
for i, expectedDst := range tt.expected.Destinations {
if i < len(acl.Destinations) {
assert.Equal(t, expectedDst.Alias, acl.Destinations[i].Alias)
assert.Equal(t, expectedDst.Ports, acl.Destinations[i].Ports)
}
}
})
}
}
func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) {
// Test that marshaling and unmarshaling preserves data (excluding comments)
original := ACL{
Action: "accept",
Protocol: "tcp",
Sources: []Alias{mustParseAlias("group:admins")},
Destinations: []AliasWithPorts{
{
Alias: mustParseAlias("tag:server"),
Ports: []tailcfg.PortRange{{First: 22, Last: 22}, {First: 80, Last: 80}},
},
},
}
// Marshal to JSON
jsonBytes, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal back
var unmarshaled ACL
err = json.Unmarshal(jsonBytes, &unmarshaled)
require.NoError(t, err)
// Should be equal
assert.Equal(t, original.Action, unmarshaled.Action)
assert.Equal(t, original.Protocol, unmarshaled.Protocol)
assert.Equal(t, len(original.Sources), len(unmarshaled.Sources))
assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations))
}
func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
// Test that ACL unmarshaling works within a Policy context
policyJSON := `{
"groups": {
"group:developers": ["user1@example.com", "user2@example.com"]
},
"tagOwners": {
"tag:server": ["group:developers"]
},
"acls": [
{
"#description": "Allow developers to access servers",
"#priority": "high",
"action": "accept",
"proto": "tcp",
"src": ["group:developers"],
"dst": ["tag:server:22,80,443"]
},
{
"#note": "Allow all other traffic",
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}
]
}`
policy, err := unmarshalPolicy([]byte(policyJSON))
require.NoError(t, err)
require.NotNil(t, policy)
// Check that ACLs were parsed correctly
require.Len(t, policy.ACLs, 2)
// First ACL
acl1 := policy.ACLs[0]
assert.Equal(t, ActionAccept, acl1.Action)
assert.Equal(t, Protocol("tcp"), acl1.Protocol)
require.Len(t, acl1.Sources, 1)
require.Len(t, acl1.Destinations, 1)
// Second ACL
acl2 := policy.ACLs[1]
assert.Equal(t, ActionAccept, acl2.Action)
assert.Equal(t, Protocol("tcp"), acl2.Protocol)
require.Len(t, acl2.Sources, 1)
require.Len(t, acl2.Destinations, 1)
}
func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
// Test that invalid actions are rejected
policyJSON := `{
"acls": [
{
"action": "deny",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"]
}
]
}`
_, err := unmarshalPolicy([]byte(policyJSON))
require.Error(t, err)
assert.Contains(t, err.Error(), `invalid action "deny"`)
}
// Helper function to parse aliases for testing
func mustParseAlias(s string) Alias {
alias, err := parseAlias(s)
if err != nil {
panic(err)
}
return alias
}

View File

@ -2,7 +2,6 @@ package v2
import (
"errors"
"fmt"
"slices"
"strconv"
"strings"
@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) {
return uint16(port), nil
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
}
// TODO(kradalby): What is this?
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}

View File

@ -1885,7 +1885,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath,
},
)
assert.ErrorContains(t, err, "compiling filter rules: invalid action")
assert.ErrorContains(t, err, `invalid action "unknown-action"`)
// The new policy was invalid, the old one should still be in place, which
// is none.

View File

@ -1481,7 +1481,7 @@ func TestSubnetRouteACL(t *testing.T) {
wantClientFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
ipproto.TCP, ipproto.UDP,
}),
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
@ -1513,7 +1513,7 @@ func TestSubnetRouteACL(t *testing.T) {
wantSubnetFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
ipproto.TCP, ipproto.UDP,
}),
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
@ -1535,7 +1535,7 @@ func TestSubnetRouteACL(t *testing.T) {
},
{
IPProto: views.SliceOf([]ipproto.Proto{
ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6,
ipproto.TCP, ipproto.UDP,
}),
Srcs: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),