diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 338e513b..8dac82ae 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -28,7 +28,7 @@ func (pol *Policy) compileFilterRules( var rules []tailcfg.FilterRule for _, acl := range pol.ACLs { - if acl.Action != "accept" { + if acl.Action != ActionAccept { return nil, ErrInvalidAction } @@ -41,12 +41,7 @@ func (pol *Policy) compileFilterRules( continue } - // TODO(kradalby): integrate type into schema - // TODO(kradalby): figure out the _ is wildcard stuff - protocols, _, err := parseProtocol(acl.Protocol) - if err != nil { - return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) - } + protocols, _ := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange for _, dest := range acl.Destinations { diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 59e297a9..f400b9c5 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1242,9 +1242,214 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. return ret, exitNodeSet, nil } +// Action represents the action to take for an ACL rule. +type Action string + +const ( + ActionAccept Action = "accept" +) + +// String returns the string representation of the Action. +func (a Action) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for Action. +func (a *Action) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = ActionAccept + default: + return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + } + return nil +} + +// MarshalJSON implements JSON marshaling for Action. +func (a Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + +// Protocol represents a network protocol with its IANA number and descriptions. +type Protocol string + +const ( + ProtocolICMP Protocol = "icmp" + ProtocolIGMP Protocol = "igmp" + ProtocolIPv4 Protocol = "ipv4" + ProtocolIPInIP Protocol = "ip-in-ip" + ProtocolTCP Protocol = "tcp" + ProtocolEGP Protocol = "egp" + ProtocolIGP Protocol = "igp" + ProtocolUDP Protocol = "udp" + ProtocolGRE Protocol = "gre" + ProtocolESP Protocol = "esp" + ProtocolAH Protocol = "ah" + ProtocolIPv6ICMP Protocol = "ipv6-icmp" + ProtocolSCTP Protocol = "sctp" + ProtocolFC Protocol = "fc" + ProtocolWildcard Protocol = "*" +) + +// String returns the string representation of the Protocol. +func (p Protocol) String() string { + return string(p) +} + +// Description returns the human-readable description of the Protocol. +func (p Protocol) Description() string { + switch p { + case ProtocolICMP: + return "Internet Control Message Protocol" + case ProtocolIGMP: + return "Internet Group Management Protocol" + case ProtocolIPv4: + return "IPv4 encapsulation" + case ProtocolTCP: + return "Transmission Control Protocol" + case ProtocolEGP: + return "Exterior Gateway Protocol" + case ProtocolIGP: + return "Interior Gateway Protocol" + case ProtocolUDP: + return "User Datagram Protocol" + case ProtocolGRE: + return "Generic Routing Encapsulation" + case ProtocolESP: + return "Encapsulating Security Payload" + case ProtocolAH: + return "Authentication Header" + case ProtocolIPv6ICMP: + return "Internet Control Message Protocol for IPv6" + case ProtocolSCTP: + return "Stream Control Transmission Protocol" + case ProtocolFC: + return "Fibre Channel" + case ProtocolWildcard: + return "Wildcard (all protocols)" + default: + return "Unknown Protocol" + } +} + +// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. +func (p Protocol) parseProtocol() ([]int, bool) { + switch p { + case "": + return nil, false + case ProtocolWildcard: + // Wildcard protocol - allows all protocols like empty string + return nil, false + case ProtocolIGMP: + return []int{protocolIGMP}, true + case ProtocolIPv4, ProtocolIPInIP: + return []int{protocolIPv4}, true + case ProtocolTCP: + return []int{protocolTCP}, false + case ProtocolEGP: + return []int{protocolEGP}, true + case ProtocolIGP: + return []int{protocolIGP}, true + case ProtocolUDP: + return []int{protocolUDP}, false + case ProtocolGRE: + return []int{protocolGRE}, true + case ProtocolESP: + return []int{protocolESP}, true + case ProtocolAH: + return []int{protocolAH}, true + case ProtocolSCTP: + return []int{protocolSCTP}, false + case ProtocolICMP: + return []int{protocolICMP, protocolIPv6ICMP}, true + default: + // Try to parse as a numeric protocol number + // This should not fail since validation happened during unmarshaling + protocolNumber, _ := strconv.Atoi(string(p)) + + // Determine if wildcard is needed based on protocol number + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard + } +} + +// UnmarshalJSON implements JSON unmarshaling for Protocol. +func (p *Protocol) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + + // Normalize to lowercase for case-insensitive matching + *p = Protocol(strings.ToLower(str)) + + // Validate the protocol + if err := p.validate(); err != nil { + return err + } + + return nil +} + +// validate checks if the Protocol is valid. +func (p Protocol) validate() error { + switch p { + case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP, + ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE, + ProtocolESP, ProtocolAH, ProtocolSCTP: + return nil + case ProtocolWildcard: + // Wildcard "*" is not allowed - Tailscale rejects it + return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + default: + // Try to parse as a numeric protocol number + str := string(p) + + // Check for leading zeros (not allowed by Tailscale) + if str == "0" || (len(str) > 1 && str[0] == '0') { + return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + } + + protocolNumber, err := strconv.Atoi(str) + if err != nil { + return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + } + + if protocolNumber < 0 || protocolNumber > 255 { + return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + } + + return nil + } +} + +// MarshalJSON implements JSON marshaling for Protocol. +func (p Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(string(p)) +} + +// Protocol constants matching the IANA numbers +const ( + protocolICMP = 1 // Internet Control Message + protocolIGMP = 2 // Internet Group Management + protocolIPv4 = 4 // IPv4 encapsulation + protocolTCP = 6 // Transmission Control + protocolEGP = 8 // Exterior Gateway Protocol + protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) + protocolUDP = 17 // User Datagram + protocolGRE = 47 // Generic Routing Encapsulation + protocolESP = 50 // Encap Security Payload + protocolAH = 51 // Authentication Header + protocolIPv6ICMP = 58 // ICMP for IPv6 + protocolSCTP = 132 // Stream Control Transmission Protocol + protocolFC = 133 // Fibre Channel +) + type ACL struct { - Action string `json:"action"` // TODO(kradalby): add strict type - Protocol string `json:"proto"` // TODO(kradalby): add strict type + Action Action `json:"action"` + Protocol Protocol `json:"proto"` Sources Aliases `json:"src"` Destinations []AliasWithPorts `json:"dst"` } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 3669e790..d47c8cf5 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -2464,13 +2464,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { "tags": ["web", "production"], "created_by": "admin" }, - "action": "deny", + "action": "accept", "proto": "udp", "src": ["*"], "dst": ["autogroup:internet:53"] }`, expected: ACL{ - Action: "deny", + Action: ActionAccept, Protocol: "udp", Sources: []Alias{Wildcard}, Destinations: []AliasWithPorts{ @@ -2482,6 +2482,16 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { }, wantErr: false, }, + { + name: "invalid action should fail", + input: `{ + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + }`, + wantErr: true, + }, { name: "no comment fields", input: `{ @@ -2491,13 +2501,13 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { "dst": ["tag:server:*"] }`, expected: ACL{ - Action: "accept", + Action: ActionAccept, Protocol: "icmp", Sources: []Alias{mustParseAlias("tag:client")}, Destinations: []AliasWithPorts{ { Alias: mustParseAlias("tag:server"), - Ports: []tailcfg.PortRange{{First: 0, Last: 65535}}, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, }, @@ -2510,8 +2520,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { "#reason": "Temporary disable for maintenance" }`, expected: ACL{ - Action: "", - Protocol: "", + Action: Action(""), + Protocol: Protocol(""), Sources: nil, Destinations: nil, }, @@ -2623,9 +2633,9 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { "dst": ["tag:server:22,80,443"] }, { - "#note": "Deny all other traffic", - "action": "deny", - "proto": "*", + "#note": "Allow all other traffic", + "action": "accept", + "proto": "tcp", "src": ["*"], "dst": ["*:*"] } @@ -2641,19 +2651,37 @@ func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { // First ACL acl1 := policy.ACLs[0] - assert.Equal(t, "accept", acl1.Action) - assert.Equal(t, "tcp", acl1.Protocol) + assert.Equal(t, ActionAccept, acl1.Action) + assert.Equal(t, Protocol("tcp"), acl1.Protocol) require.Len(t, acl1.Sources, 1) require.Len(t, acl1.Destinations, 1) // Second ACL acl2 := policy.ACLs[1] - assert.Equal(t, "deny", acl2.Action) - assert.Equal(t, "*", acl2.Protocol) + assert.Equal(t, ActionAccept, acl2.Action) + assert.Equal(t, Protocol("tcp"), acl2.Protocol) require.Len(t, acl2.Sources, 1) require.Len(t, acl2.Destinations, 1) } +func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { + // Test that invalid actions are rejected + policyJSON := `{ + "acls": [ + { + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + _, err := unmarshalPolicy([]byte(policyJSON)) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid action "deny"`) +} + // Helper function to parse aliases for testing func mustParseAlias(s string) Alias { alias, err := parseAlias(s) diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index 2c551eda..7482c97b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -2,7 +2,6 @@ package v2 import ( "errors" - "fmt" "slices" "strconv" "strings" @@ -97,72 +96,3 @@ func parsePort(portStr string) (uint16, error) { return uint16(port), nil } - -// For some reason golang.org/x/net/internal/iana is an internal package. -const ( - protocolICMP = 1 // Internet Control Message - protocolIGMP = 2 // Internet Group Management - protocolIPv4 = 4 // IPv4 encapsulation - protocolTCP = 6 // Transmission Control - protocolEGP = 8 // Exterior Gateway Protocol - protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) - protocolUDP = 17 // User Datagram - protocolGRE = 47 // Generic Routing Encapsulation - protocolESP = 50 // Encap Security Payload - protocolAH = 51 // Authentication Header - protocolIPv6ICMP = 58 // ICMP for IPv6 - protocolSCTP = 132 // Stream Control Transmission Protocol - ProtocolFC = 133 // Fibre Channel -) - -// parseProtocol reads the proto field of the ACL and generates a list of -// protocols that will be allowed, following the IANA IP protocol number -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// -// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, -// as per Tailscale behaviour (see tailcfg.FilterRule). -// -// Also returns a boolean indicating if the protocol -// requires all the destinations to use wildcard as port number (only TCP, -// UDP and SCTP support specifying ports). -func parseProtocol(protocol string) ([]int, bool, error) { - switch protocol { - case "": - return nil, false, nil - case "igmp": - return []int{protocolIGMP}, true, nil - case "ipv4", "ip-in-ip": - return []int{protocolIPv4}, true, nil - case "tcp": - return []int{protocolTCP}, false, nil - case "egp": - return []int{protocolEGP}, true, nil - case "igp": - return []int{protocolIGP}, true, nil - case "udp": - return []int{protocolUDP}, false, nil - case "gre": - return []int{protocolGRE}, true, nil - case "esp": - return []int{protocolESP}, true, nil - case "ah": - return []int{protocolAH}, true, nil - case "sctp": - return []int{protocolSCTP}, false, nil - case "icmp": - return []int{protocolICMP, protocolIPv6ICMP}, true, nil - - default: - protocolNumber, err := strconv.Atoi(protocol) - if err != nil { - return nil, false, fmt.Errorf("parsing protocol number: %w", err) - } - - // TODO(kradalby): What is this? - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard, nil - } -}