From 2938d03878d44491924563e26e408d2b5d51e668 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 12 Sep 2025 14:47:56 +0200 Subject: [PATCH] policy: reject unsupported fields (#2764) --- CHANGELOG.md | 2 + go.mod | 2 +- hscontrol/policy/policy_test.go | 13 +- hscontrol/policy/v2/filter.go | 13 +- hscontrol/policy/v2/filter_test.go | 6 + hscontrol/policy/v2/types.go | 396 ++++++++++++-- hscontrol/policy/v2/types_test.go | 800 ++++++++++++++++++++++++++++- hscontrol/policy/v2/utils.go | 70 --- integration/cli_test.go | 2 +- integration/route_test.go | 6 +- 10 files changed, 1177 insertions(+), 133 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ab70873..e56dd827 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,8 @@ upstream is changed. [#2663](https://github.com/juanfont/headscale/pull/2663) - OIDC: Update user with claims from UserInfo _before_ comparing with allowed groups, email and domain [#2663](https://github.com/juanfont/headscale/pull/2663) +- Policy will now reject invalid fields, making it easier to spot spelling errors + [#2764](https://github.com/juanfont/headscale/pull/2764) ## 0.26.1 (2025-06-06) diff --git a/go.mod b/go.mod index f719bc0b..3af028b9 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/fsnotify/fsnotify v1.9.0 github.com/glebarez/sqlite v1.11.0 github.com/go-gormigrate/gormigrate/v2 v2.1.4 + github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 github.com/gofrs/uuid/v5 v5.3.2 github.com/google/go-cmp v0.7.0 github.com/gorilla/mux v1.8.1 @@ -131,7 +132,6 @@ require ( github.com/glebarez/go-sqlite v1.22.0 // indirect github.com/go-jose/go-jose/v3 v3.0.4 // indirect github.com/go-jose/go-jose/v4 v4.1.0 // indirect - github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index d2ff019d..c7cd3bcf 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -222,6 +222,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, { SrcIPs: []string{ @@ -236,6 +237,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, }, }, @@ -371,10 +373,12 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, { SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, DstPorts: hsExitNodeDestForTest, + IPProto: []int{6, 17}, }, }, }, @@ -478,6 +482,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, { SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, @@ -513,6 +518,7 @@ func TestReduceFilterRules(t *testing.T) { {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{6, 17}, }, }, }, @@ -588,6 +594,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, { SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, @@ -601,6 +608,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, }, }, @@ -676,6 +684,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, { SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, @@ -689,6 +698,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, }, }, @@ -756,6 +766,7 @@ func TestReduceFilterRules(t *testing.T) { Ports: tailcfg.PortRangeAny, }, }, + IPProto: []int{6, 17}, }, }, }, @@ -1736,7 +1747,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: `SSH action "invalid" is not valid, must be accept or check`, + errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, }, { name: "invalid-check-period", diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 5793d96c..139b46a3 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules( var rules []tailcfg.FilterRule for _, acl := range pol.ACLs { - if acl.Action != "accept" { + if acl.Action != ActionAccept { return nil, ErrInvalidAction } @@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules( continue } - // TODO(kradalby): integrate type into schema - // TODO(kradalby): figure out the _ is wildcard stuff - protocols, _, err := parseProtocol(acl.Protocol) - if err != nil { - return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) - } + protocols, _ := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { @@ -132,9 +127,9 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction switch rule.Action { - case "accept": + case SSHActionAccept: action = sshAction(true, 0) - case "check": + case SSHActionCheck: action = sshAction(true, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 8b73a6f5..37dcf149 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -92,6 +92,7 @@ func TestParsing(t *testing.T) { {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, @@ -193,6 +194,7 @@ func TestParsing(t *testing.T) { DstPorts: []tailcfg.NetPortRange{ {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, @@ -229,6 +231,7 @@ func TestParsing(t *testing.T) { Ports: tailcfg.PortRange{First: 5400, Last: 5500}, }, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, @@ -268,6 +271,7 @@ func TestParsing(t *testing.T) { DstPorts: []tailcfg.NetPortRange{ {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, @@ -301,6 +305,7 @@ func TestParsing(t *testing.T) { DstPorts: []tailcfg.NetPortRange{ {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, @@ -334,6 +339,7 @@ func TestParsing(t *testing.T) { DstPorts: []tailcfg.NetPortRange{ {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, }, + IPProto: []int{protocolTCP, protocolUDP}, }, }, wantErr: false, diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 80797e17..2ce85927 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1,8 +1,6 @@ package v2 import ( - "bytes" - "encoding/json" "errors" "fmt" "net/netip" @@ -10,6 +8,8 @@ import ( "strconv" "strings" + "github.com/go-json-experiment/json" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" @@ -23,6 +23,13 @@ import ( "tailscale.com/util/slicesx" ) +// Global JSON options for consistent parsing across all struct unmarshaling +var policyJSONOpts = []json.Options{ + json.DefaultOptionsV2(), + json.MatchCaseInsensitiveNames(true), + json.RejectUnknownMembers(true), +} + const Wildcard = Asterix(0) type Asterix int @@ -614,10 +621,8 @@ type AliasWithPorts struct { } func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { - // TODO(kradalby): use encoding/json/v2 (go-json-experiment) - dec := json.NewDecoder(bytes.NewReader(b)) var v any - if err := dec.Decode(&v); err != nil { + if err := json.Unmarshal(b, &v); err != nil { return err } @@ -735,7 +740,7 @@ type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc - err := json.Unmarshal(b, &aliases) + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err } @@ -825,7 +830,7 @@ type AutoApprovers []AutoApprover func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { var autoApprovers []AutoApproverEnc - err := json.Unmarshal(b, &autoApprovers) + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) if err != nil { return err } @@ -920,7 +925,7 @@ type Owners []Owner func (o *Owners) UnmarshalJSON(b []byte) error { var owners []OwnerEnc - err := json.Unmarshal(b, &owners) + err := json.Unmarshal(b, &owners, policyJSONOpts...) if err != nil { return err } @@ -994,18 +999,46 @@ func (g Groups) Contains(group *Group) error { // that all group names conform to the expected format, which is always prefixed // with "group:". If any group name is invalid, an error is returned. func (g *Groups) UnmarshalJSON(b []byte) error { - var rawGroups map[string][]string - if err := json.Unmarshal(b, &rawGroups); err != nil { + // First unmarshal as a generic map to validate group names first + var rawMap map[string]interface{} + if err := json.Unmarshal(b, &rawMap); err != nil { return err } + // Validate group names first before checking data types + for key := range rawMap { + group := Group(key) + if err := group.Validate(); err != nil { + return err + } + } + + // Then validate each field can be converted to []string + rawGroups := make(map[string][]string) + for key, value := range rawMap { + switch v := value.(type) { + case []interface{}: + // Convert []interface{} to []string + var stringSlice []string + for _, item := range v { + if str, ok := item.(string); ok { + stringSlice = append(stringSlice, str) + } else { + return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) + } + } + rawGroups[key] = stringSlice + case string: + return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) + default: + return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v) + } + } + *g = make(Groups) for key, value := range rawGroups { group := Group(key) - if err := group.Validate(); err != nil { - return err - } - + // Group name already validated above var usernames Usernames for _, u := range value { @@ -1031,7 +1064,7 @@ type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { var rawHosts map[string]string - if err := json.Unmarshal(b, &rawHosts); err != nil { + if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil { return err } @@ -1242,13 +1275,290 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. return ret, exitNodeSet, nil } +// Action represents the action to take for an ACL rule. +type Action string + +const ( + ActionAccept Action = "accept" +) + +// SSHAction represents the action to take for an SSH rule. +type SSHAction string + +const ( + SSHActionAccept SSHAction = "accept" + SSHActionCheck SSHAction = "check" +) + +// String returns the string representation of the Action. +func (a Action) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for Action. +func (a *Action) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = ActionAccept + default: + return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + } + return nil +} + +// MarshalJSON implements JSON marshaling for Action. +func (a Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + +// String returns the string representation of the SSHAction. +func (a SSHAction) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for SSHAction. +func (a *SSHAction) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = SSHActionAccept + case "check": + *a = SSHActionCheck + default: + return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + } + return nil +} + +// MarshalJSON implements JSON marshaling for SSHAction. +func (a SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + +// Protocol represents a network protocol with its IANA number and descriptions. +type Protocol string + +const ( + ProtocolICMP Protocol = "icmp" + ProtocolIGMP Protocol = "igmp" + ProtocolIPv4 Protocol = "ipv4" + ProtocolIPInIP Protocol = "ip-in-ip" + ProtocolTCP Protocol = "tcp" + ProtocolEGP Protocol = "egp" + ProtocolIGP Protocol = "igp" + ProtocolUDP Protocol = "udp" + ProtocolGRE Protocol = "gre" + ProtocolESP Protocol = "esp" + ProtocolAH Protocol = "ah" + ProtocolIPv6ICMP Protocol = "ipv6-icmp" + ProtocolSCTP Protocol = "sctp" + ProtocolFC Protocol = "fc" + ProtocolWildcard Protocol = "*" +) + +// String returns the string representation of the Protocol. +func (p Protocol) String() string { + return string(p) +} + +// Description returns the human-readable description of the Protocol. +func (p Protocol) Description() string { + switch p { + case ProtocolICMP: + return "Internet Control Message Protocol" + case ProtocolIGMP: + return "Internet Group Management Protocol" + case ProtocolIPv4: + return "IPv4 encapsulation" + case ProtocolTCP: + return "Transmission Control Protocol" + case ProtocolEGP: + return "Exterior Gateway Protocol" + case ProtocolIGP: + return "Interior Gateway Protocol" + case ProtocolUDP: + return "User Datagram Protocol" + case ProtocolGRE: + return "Generic Routing Encapsulation" + case ProtocolESP: + return "Encapsulating Security Payload" + case ProtocolAH: + return "Authentication Header" + case ProtocolIPv6ICMP: + return "Internet Control Message Protocol for IPv6" + case ProtocolSCTP: + return "Stream Control Transmission Protocol" + case ProtocolFC: + return "Fibre Channel" + case ProtocolWildcard: + return "Wildcard (not supported - use specific protocol)" + default: + return "Unknown Protocol" + } +} + +// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. +func (p Protocol) parseProtocol() ([]int, bool) { + switch p { + case "": + // Empty protocol applies to TCP and UDP traffic only + return []int{protocolTCP, protocolUDP}, false + case ProtocolWildcard: + // Wildcard protocol - defensive handling (should not reach here due to validation) + return nil, false + case ProtocolIGMP: + return []int{protocolIGMP}, true + case ProtocolIPv4, ProtocolIPInIP: + return []int{protocolIPv4}, true + case ProtocolTCP: + return []int{protocolTCP}, false + case ProtocolEGP: + return []int{protocolEGP}, true + case ProtocolIGP: + return []int{protocolIGP}, true + case ProtocolUDP: + return []int{protocolUDP}, false + case ProtocolGRE: + return []int{protocolGRE}, true + case ProtocolESP: + return []int{protocolESP}, true + case ProtocolAH: + return []int{protocolAH}, true + case ProtocolSCTP: + return []int{protocolSCTP}, false + case ProtocolICMP: + return []int{protocolICMP, protocolIPv6ICMP}, true + default: + // Try to parse as a numeric protocol number + // This should not fail since validation happened during unmarshaling + protocolNumber, _ := strconv.Atoi(string(p)) + + // Determine if wildcard is needed based on protocol number + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard + } +} + +// UnmarshalJSON implements JSON unmarshaling for Protocol. +func (p *Protocol) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + + // Normalize to lowercase for case-insensitive matching + *p = Protocol(strings.ToLower(str)) + + // Validate the protocol + if err := p.validate(); err != nil { + return err + } + + return nil +} + +// validate checks if the Protocol is valid. +func (p Protocol) validate() error { + switch p { + case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP, + ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE, + ProtocolESP, ProtocolAH, ProtocolSCTP: + return nil + case ProtocolWildcard: + // Wildcard "*" is not allowed - Tailscale rejects it + return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + default: + // Try to parse as a numeric protocol number + str := string(p) + + // Check for leading zeros (not allowed by Tailscale) + if str == "0" || (len(str) > 1 && str[0] == '0') { + return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + } + + protocolNumber, err := strconv.Atoi(str) + if err != nil { + return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + } + + if protocolNumber < 0 || protocolNumber > 255 { + return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + } + + return nil + } +} + +// MarshalJSON implements JSON marshaling for Protocol. +func (p Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(string(p)) +} + +// Protocol constants matching the IANA numbers +const ( + protocolICMP = 1 // Internet Control Message + protocolIGMP = 2 // Internet Group Management + protocolIPv4 = 4 // IPv4 encapsulation + protocolTCP = 6 // Transmission Control + protocolEGP = 8 // Exterior Gateway Protocol + protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) + protocolUDP = 17 // User Datagram + protocolGRE = 47 // Generic Routing Encapsulation + protocolESP = 50 // Encap Security Payload + protocolAH = 51 // Authentication Header + protocolIPv6ICMP = 58 // ICMP for IPv6 + protocolSCTP = 132 // Stream Control Transmission Protocol + protocolFC = 133 // Fibre Channel +) + type ACL struct { - Action string `json:"action"` // TODO(kradalby): add strict type - Protocol string `json:"proto"` // TODO(kradalby): add strict type + Action Action `json:"action"` + Protocol Protocol `json:"proto"` Sources Aliases `json:"src"` Destinations []AliasWithPorts `json:"dst"` } +// UnmarshalJSON implements custom unmarshalling for ACL that ignores fields starting with '#'. +// headscale-admin uses # in some field names to add metadata, so we will ignore +// those to ensure it doesnt break. +// https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38 +func (a *ACL) UnmarshalJSON(b []byte) error { + // First unmarshal into a map to filter out comment fields + var raw map[string]any + if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil { + return err + } + + // Remove any fields that start with '#' + filtered := make(map[string]any) + for key, value := range raw { + if !strings.HasPrefix(key, "#") { + filtered[key] = value + } + } + + // Marshal the filtered map back to JSON + filteredBytes, err := json.Marshal(filtered) + if err != nil { + return err + } + + // Create a type alias to avoid infinite recursion + type aclAlias ACL + var temp aclAlias + + // Unmarshal into the temporary struct using the v2 JSON options + if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil { + return err + } + + // Copy the result back to the original struct + *a = ACL(temp) + return nil +} + // Policy represents a Tailscale Network Policy. // TODO(kradalby): // Add validation method checking: @@ -1266,7 +1576,7 @@ type Policy struct { Hosts Hosts `json:"hosts,omitempty"` TagOwners TagOwners `json:"tagOwners,omitempty"` ACLs []ACL `json:"acls,omitempty"` - AutoApprovers AutoApproverPolicy `json:"autoApprovers,omitempty"` + AutoApprovers AutoApproverPolicy `json:"autoApprovers"` SSHs []SSH `json:"ssh,omitempty"` } @@ -1444,13 +1754,14 @@ func (p *Policy) validate() error { } } } + + // Validate protocol-port compatibility + if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { + errs = append(errs, err) + } } for _, ssh := range p.SSHs { - if ssh.Action != "accept" && ssh.Action != "check" { - errs = append(errs, fmt.Errorf("SSH action %q is not valid, must be accept or check", ssh.Action)) - } - for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) @@ -1564,7 +1875,7 @@ func (p *Policy) validate() error { // SSH controls who can ssh into which machines. type SSH struct { - Action string `json:"action"` + Action SSHAction `json:"action"` Sources SSHSrcAliases `json:"src"` Destinations SSHDstAliases `json:"dst"` Users SSHUsers `json:"users"` @@ -1595,7 +1906,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc - err := json.Unmarshal(b, &aliases) + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err } @@ -1618,7 +1929,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc - err := json.Unmarshal(b, &aliases) + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err } @@ -1762,9 +2073,13 @@ func unmarshalPolicy(b []byte) (*Policy, error) { } ast.Standardize() - acl := ast.Pack() - - if err = json.Unmarshal(acl, &policy); err != nil { + if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { + var serr *json.SemanticError + if errors.As(err, &serr) && serr.Err == json.ErrUnknownName { + ptr := serr.JSONPointer + name := ptr.LastToken() + return nil, fmt.Errorf("unknown field %q", name) + } return nil, fmt.Errorf("parsing policy from bytes: %w", err) } @@ -1775,6 +2090,25 @@ func unmarshalPolicy(b []byte) (*Policy, error) { return &policy, nil } -const ( - expectedTokenItems = 2 -) +// validateProtocolPortCompatibility checks that only TCP, UDP, and SCTP protocols +// can have specific ports. All other protocols should only use wildcard ports. +func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWithPorts) error { + // Only TCP, UDP, and SCTP support specific ports + supportsSpecificPorts := protocol == ProtocolTCP || protocol == ProtocolUDP || protocol == ProtocolSCTP || protocol == "" + + if supportsSpecificPorts { + return nil // No validation needed for these protocols + } + + // For all other protocols, check that all destinations use wildcard ports + for _, dst := range destinations { + for _, portRange := range dst.Ports { + // Check if it's not a wildcard port (0-65535) + if !(portRange.First == 0 && portRange.Last == 65535) { + return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + } + } + } + + return nil +} diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 2a3ab578..38c2adf3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -352,20 +352,6 @@ func TestUnmarshalPolicy(t *testing.T) { name: "2652-asterix-error-better-explain", input: ` { - "acls": [ - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "*:*" - ], - "proto": [ - "*:*" - ] - } - ], "ssh": [ { "action": "accept", @@ -375,9 +361,7 @@ func TestUnmarshalPolicy(t *testing.T) { "dst": [ "*" ], - "proto": [ - "*:*" - ] + "users": ["root"] } ] } @@ -992,6 +976,500 @@ func TestUnmarshalPolicy(t *testing.T) { `, wantErr: `first port must be >0, or use '*' for wildcard`, }, + { + name: "disallow-unsupported-fields", + input: ` +{ + // rules doesnt exists, we have "acls" + "rules": [ + ] +} +`, + wantErr: `unknown field "rules"`, + }, + { + name: "disallow-unsupported-fields-nested", + input: ` +{ + "acls": [ + { "action": "accept", "BAD": ["FOO:BAR:FOO:BAR"], "NOT": ["BAD:BAD:BAD:BAD"] } + ] +} +`, + wantErr: `unknown field`, + }, + { + name: "invalid-group-name", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "INVALID_GROUP_FIELD": ["user@example.com"] + } +} +`, + wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + }, + { + name: "invalid-group-datatype", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "group:invalid": "should fail" + } +} +`, + wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`, + }, + { + name: "invalid-group-name-and-datatype-fails-on-name-first", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "INVALID_GROUP_FIELD": "should fail" + } +} +`, + wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + }, + { + name: "disallow-unsupported-fields-hosts-level", + input: ` +{ + "hosts": { + "host1": "10.0.0.1", + "INVALID_HOST_FIELD": "should fail" + } +} +`, + wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + }, + { + name: "disallow-unsupported-fields-tagowners-level", + input: ` +{ + "tagOwners": { + "tag:test": ["user@example.com"], + "INVALID_TAG_FIELD": "should fail" + } +} +`, + wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + }, + { + name: "disallow-unsupported-fields-acls-level", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"], + "INVALID_ACL_FIELD": "should fail" + } + ] +} +`, + wantErr: `unknown field "INVALID_ACL_FIELD"`, + }, + { + name: "disallow-unsupported-fields-ssh-level", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": ["user@example.com"], + "dst": ["user@example.com"], + "users": ["root"], + "INVALID_SSH_FIELD": "should fail" + } + ] +} +`, + wantErr: `unknown field "INVALID_SSH_FIELD"`, + }, + { + name: "disallow-unsupported-fields-policy-level", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ], + "INVALID_POLICY_FIELD": "should fail at policy level" +} +`, + wantErr: `unknown field "INVALID_POLICY_FIELD"`, + }, + { + name: "disallow-unsupported-fields-autoapprovers-level", + input: ` +{ + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["user@example.com"] + }, + "exitNode": ["user@example.com"], + "INVALID_AUTO_APPROVER_FIELD": "should fail" + } +} +`, + wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + }, + // headscale-admin uses # in some field names to add metadata, so we will ignore + // those to ensure it doesnt break. + // https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38 + { + name: "hash-fields-are-allowed-but-ignored", + input: ` +{ + "acls": [ + { + "#ha-test": "SOME VALUE", + "action": "accept", + "src": [ + "10.0.0.1" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + pp("10.0.0.1/32"), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(AutoGroup("autogroup:internet")), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "ssh-asterix-invalid-acl-input", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "user@example.com" + ], + "dst": [ + "user@example.com" + ], + "users": ["root"], + "proto": "tcp" + } + ] +} +`, + wantErr: `unknown field "proto"`, + }, + { + name: "protocol-wildcard-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "*", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`, + }, + { + name: "protocol-case-insensitive-uppercase", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "ICMP", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-case-insensitive-mixed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "IcmP", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-leading-zero-not-permitted", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "0", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + wantErr: `leading 0 not permitted in protocol number "0"`, + }, + { + name: "protocol-empty-applies-to-tcp-udp-only", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:80"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-icmp-with-specific-port-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "icmp", + "src": ["*"], + "dst": ["*:80"] + } + ] +} +`, + wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + }, + { + name: "protocol-icmp-with-wildcard-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "icmp", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-gre-with-specific-port-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "gre", + "src": ["*"], + "dst": ["*:443"] + } + ] +} +`, + wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + }, + { + name: "protocol-tcp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:80"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-udp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "udp", + "src": ["*"], + "dst": ["*:53"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "udp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 53, Last: 53}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-sctp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "sctp", + "src": ["*"], + "dst": ["*:9000"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "sctp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, + }, + }, + }, + }, + }, + }, } cmps := append(util.Comparers, @@ -2091,3 +2569,291 @@ func TestNodeCanHaveTag(t *testing.T) { }) } } + +func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { + tests := []struct { + name string + input string + expected ACL + wantErr bool + }{ + { + name: "basic ACL with comment fields", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["tag:server:80"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("user1@example.com")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple comment fields", + input: `{ + "#description": "Allow access to web servers", + "#note": "Created by admin", + "#created_date": "2024-01-15", + "action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["10.0.0.0/24:443"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:developers")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("10.0.0.0/24"), + Ports: []tailcfg.PortRange{{First: 443, Last: 443}}, + }, + }, + }, + wantErr: false, + }, + { + name: "comment field with complex object value", + input: `{ + "#metadata": { + "description": "Complex comment object", + "tags": ["web", "production"], + "created_by": "admin" + }, + "action": "accept", + "proto": "udp", + "src": ["*"], + "dst": ["autogroup:internet:53"] + }`, + expected: ACL{ + Action: ActionAccept, + Protocol: "udp", + Sources: []Alias{Wildcard}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("autogroup:internet"), + Ports: []tailcfg.PortRange{{First: 53, Last: 53}}, + }, + }, + }, + wantErr: false, + }, + { + name: "invalid action should fail", + input: `{ + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + }`, + wantErr: true, + }, + { + name: "no comment fields", + input: `{ + "action": "accept", + "proto": "icmp", + "src": ["tag:client"], + "dst": ["tag:server:*"] + }`, + expected: ACL{ + Action: ActionAccept, + Protocol: "icmp", + Sources: []Alias{mustParseAlias("tag:client")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "only comment fields", + input: `{ + "#comment": "This rule is disabled", + "#reason": "Temporary disable for maintenance" + }`, + expected: ACL{ + Action: Action(""), + Protocol: Protocol(""), + Sources: nil, + Destinations: nil, + }, + wantErr: false, + }, + { + name: "invalid JSON", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp" + "src": ["invalid json"] + }`, + wantErr: true, + }, + { + name: "invalid field after comment filtering", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["invalid-destination"] + }`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected.Action, acl.Action) + assert.Equal(t, tt.expected.Protocol, acl.Protocol) + assert.Equal(t, len(tt.expected.Sources), len(acl.Sources)) + assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations)) + + // Compare sources + for i, expectedSrc := range tt.expected.Sources { + if i < len(acl.Sources) { + assert.Equal(t, expectedSrc, acl.Sources[i]) + } + } + + // Compare destinations + for i, expectedDst := range tt.expected.Destinations { + if i < len(acl.Destinations) { + assert.Equal(t, expectedDst.Alias, acl.Destinations[i].Alias) + assert.Equal(t, expectedDst.Ports, acl.Destinations[i].Ports) + } + } + }) + } +} + +func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves data (excluding comments) + original := ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:admins")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 22, Last: 22}, {First: 80, Last: 80}}, + }, + }, + } + + // Marshal to JSON + jsonBytes, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal back + var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) + require.NoError(t, err) + + // Should be equal + assert.Equal(t, original.Action, unmarshaled.Action) + assert.Equal(t, original.Protocol, unmarshaled.Protocol) + assert.Equal(t, len(original.Sources), len(unmarshaled.Sources)) + assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations)) +} + +func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { + // Test that ACL unmarshaling works within a Policy context + policyJSON := `{ + "groups": { + "group:developers": ["user1@example.com", "user2@example.com"] + }, + "tagOwners": { + "tag:server": ["group:developers"] + }, + "acls": [ + { + "#description": "Allow developers to access servers", + "#priority": "high", + "action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["tag:server:22,80,443"] + }, + { + "#note": "Allow all other traffic", + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + policy, err := unmarshalPolicy([]byte(policyJSON)) + require.NoError(t, err) + require.NotNil(t, policy) + + // Check that ACLs were parsed correctly + require.Len(t, policy.ACLs, 2) + + // First ACL + acl1 := policy.ACLs[0] + assert.Equal(t, ActionAccept, acl1.Action) + assert.Equal(t, Protocol("tcp"), acl1.Protocol) + require.Len(t, acl1.Sources, 1) + require.Len(t, acl1.Destinations, 1) + + // Second ACL + acl2 := policy.ACLs[1] + assert.Equal(t, ActionAccept, acl2.Action) + assert.Equal(t, Protocol("tcp"), acl2.Protocol) + require.Len(t, acl2.Sources, 1) + require.Len(t, acl2.Destinations, 1) +} + +func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { + // Test that invalid actions are rejected + policyJSON := `{ + "acls": [ + { + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + _, err := unmarshalPolicy([]byte(policyJSON)) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid action "deny"`) +} + +// Helper function to parse aliases for testing +func mustParseAlias(s string) Alias { + alias, err := parseAlias(s) + if err != nil { + panic(err) + } + return alias +} diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index 2c551eda..7482c97b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -2,7 +2,6 @@ package v2 import ( "errors" - "fmt" "slices" "strconv" "strings" @@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) { return uint16(port), nil } - -// For some reason golang.org/x/net/internal/iana is an internal package. -const ( - protocolICMP = 1 // Internet Control Message - protocolIGMP = 2 // Internet Group Management - protocolIPv4 = 4 // IPv4 encapsulation - protocolTCP = 6 // Transmission Control - protocolEGP = 8 // Exterior Gateway Protocol - protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) - protocolUDP = 17 // User Datagram - protocolGRE = 47 // Generic Routing Encapsulation - protocolESP = 50 // Encap Security Payload - protocolAH = 51 // Authentication Header - protocolIPv6ICMP = 58 // ICMP for IPv6 - protocolSCTP = 132 // Stream Control Transmission Protocol - ProtocolFC = 133 // Fibre Channel -) - -// parseProtocol reads the proto field of the ACL and generates a list of -// protocols that will be allowed, following the IANA IP protocol number -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// -// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, -// as per Tailscale behaviour (see tailcfg.FilterRule). -// -// Also returns a boolean indicating if the protocol -// requires all the destinations to use wildcard as port number (only TCP, -// UDP and SCTP support specifying ports). -func parseProtocol(protocol string) ([]int, bool, error) { - switch protocol { - case "": - return nil, false, nil - case "igmp": - return []int{protocolIGMP}, true, nil - case "ipv4", "ip-in-ip": - return []int{protocolIPv4}, true, nil - case "tcp": - return []int{protocolTCP}, false, nil - case "egp": - return []int{protocolEGP}, true, nil - case "igp": - return []int{protocolIGP}, true, nil - case "udp": - return []int{protocolUDP}, false, nil - case "gre": - return []int{protocolGRE}, true, nil - case "esp": - return []int{protocolESP}, true, nil - case "ah": - return []int{protocolAH}, true, nil - case "sctp": - return []int{protocolSCTP}, false, nil - case "icmp": - return []int{protocolICMP, protocolIPv6ICMP}, true, nil - - default: - protocolNumber, err := strconv.Atoi(protocol) - if err != nil { - return nil, false, fmt.Errorf("parsing protocol number: %w", err) - } - - // TODO(kradalby): What is this? - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard, nil - } -} diff --git a/integration/cli_test.go b/integration/cli_test.go index 83ab74cf..98e2ddf3 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1885,7 +1885,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, "compiling filter rules: invalid action") + assert.ErrorContains(t, err, `invalid action "unknown-action"`) // The new policy was invalid, the old one should still be in place, which // is none. diff --git a/integration/route_test.go b/integration/route_test.go index 9af24f77..9aced164 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1481,7 +1481,7 @@ func TestSubnetRouteACL(t *testing.T) { wantClientFilter := []filter.Match{ { IPProto: views.SliceOf([]ipproto.Proto{ - ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + ipproto.TCP, ipproto.UDP, }), Srcs: []netip.Prefix{ netip.MustParsePrefix("100.64.0.1/32"), @@ -1513,7 +1513,7 @@ func TestSubnetRouteACL(t *testing.T) { wantSubnetFilter := []filter.Match{ { IPProto: views.SliceOf([]ipproto.Proto{ - ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + ipproto.TCP, ipproto.UDP, }), Srcs: []netip.Prefix{ netip.MustParsePrefix("100.64.0.1/32"), @@ -1535,7 +1535,7 @@ func TestSubnetRouteACL(t *testing.T) { }, { IPProto: views.SliceOf([]ipproto.Proto{ - ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, + ipproto.TCP, ipproto.UDP, }), Srcs: []netip.Prefix{ netip.MustParsePrefix("100.64.0.1/32"),