diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index ad822feb..1101cdd7 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1242,6 +1242,45 @@ type ACL struct { Destinations []AliasWithPorts `json:"dst"` } +// UnmarshalJSON implements custom unmarshalling for ACL that ignores fields starting with '#'. +// headscale-admin uses # in some field names to add metadata, so we will ignore +// those to ensure it doesnt break. +// https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38 +func (a *ACL) UnmarshalJSON(b []byte) error { + // First unmarshal into a map to filter out comment fields + var raw map[string]any + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + + // Remove any fields that start with '#' + filtered := make(map[string]any) + for key, value := range raw { + if !strings.HasPrefix(key, "#") { + filtered[key] = value + } + } + + // Marshal the filtered map back to JSON + filteredBytes, err := json.Marshal(filtered) + if err != nil { + return err + } + + // Create a type alias to avoid infinite recursion + type aclAlias ACL + var temp aclAlias + + // Unmarshal into the temporary struct using the v2 JSON options + if err := json.Unmarshal(filteredBytes, &temp, json.DefaultOptionsV2(), json.MatchCaseInsensitiveNames(true)); err != nil { + return err + } + + // Copy the result back to the original struct + *a = ACL(temp) + return nil +} + // Policy represents a Tailscale Network Policy. // TODO(kradalby): // Add validation method checking: diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index d2ef5502..80cdd02d 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -2300,3 +2300,263 @@ func TestNodeCanHaveTag(t *testing.T) { }) } } + +func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { + tests := []struct { + name string + input string + expected ACL + wantErr bool + }{ + { + name: "basic ACL with comment fields", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["tag:server:80"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("user1@example.com")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple comment fields", + input: `{ + "#description": "Allow access to web servers", + "#note": "Created by admin", + "#created_date": "2024-01-15", + "action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["10.0.0.0/24:443"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:developers")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("10.0.0.0/24"), + Ports: []tailcfg.PortRange{{First: 443, Last: 443}}, + }, + }, + }, + wantErr: false, + }, + { + name: "comment field with complex object value", + input: `{ + "#metadata": { + "description": "Complex comment object", + "tags": ["web", "production"], + "created_by": "admin" + }, + "action": "deny", + "proto": "udp", + "src": ["*"], + "dst": ["autogroup:internet:53"] + }`, + expected: ACL{ + Action: "deny", + Protocol: "udp", + Sources: []Alias{Wildcard}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("autogroup:internet"), + Ports: []tailcfg.PortRange{{First: 53, Last: 53}}, + }, + }, + }, + wantErr: false, + }, + { + name: "no comment fields", + input: `{ + "action": "accept", + "proto": "icmp", + "src": ["tag:client"], + "dst": ["tag:server:*"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "icmp", + Sources: []Alias{mustParseAlias("tag:client")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 0, Last: 65535}}, + }, + }, + }, + wantErr: false, + }, + { + name: "only comment fields", + input: `{ + "#comment": "This rule is disabled", + "#reason": "Temporary disable for maintenance" + }`, + expected: ACL{ + Action: "", + Protocol: "", + Sources: nil, + Destinations: nil, + }, + wantErr: false, + }, + { + name: "invalid JSON", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp" + "src": ["invalid json"] + }`, + wantErr: true, + }, + { + name: "invalid field after comment filtering", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["invalid-destination"] + }`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected.Action, acl.Action) + assert.Equal(t, tt.expected.Protocol, acl.Protocol) + assert.Equal(t, len(tt.expected.Sources), len(acl.Sources)) + assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations)) + + // Compare sources + for i, expectedSrc := range tt.expected.Sources { + if i < len(acl.Sources) { + assert.Equal(t, expectedSrc, acl.Sources[i]) + } + } + + // Compare destinations + for i, expectedDst := range tt.expected.Destinations { + if i < len(acl.Destinations) { + assert.Equal(t, expectedDst.Alias, acl.Destinations[i].Alias) + assert.Equal(t, expectedDst.Ports, acl.Destinations[i].Ports) + } + } + }) + } +} + +func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves data (excluding comments) + original := ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:admins")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 22, Last: 22}, {First: 80, Last: 80}}, + }, + }, + } + + // Marshal to JSON + jsonBytes, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal back + var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) + require.NoError(t, err) + + // Should be equal + assert.Equal(t, original.Action, unmarshaled.Action) + assert.Equal(t, original.Protocol, unmarshaled.Protocol) + assert.Equal(t, len(original.Sources), len(unmarshaled.Sources)) + assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations)) +} + +func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { + // Test that ACL unmarshaling works within a Policy context + policyJSON := `{ + "groups": { + "group:developers": ["user1@example.com", "user2@example.com"] + }, + "tagOwners": { + "tag:server": ["group:developers"] + }, + "acls": [ + { + "#description": "Allow developers to access servers", + "#priority": "high", + "action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["tag:server:22,80,443"] + }, + { + "#note": "Deny all other traffic", + "action": "deny", + "proto": "*", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + policy, err := unmarshalPolicy([]byte(policyJSON)) + require.NoError(t, err) + require.NotNil(t, policy) + + // Check that ACLs were parsed correctly + require.Len(t, policy.ACLs, 2) + + // First ACL + acl1 := policy.ACLs[0] + assert.Equal(t, "accept", acl1.Action) + assert.Equal(t, "tcp", acl1.Protocol) + require.Len(t, acl1.Sources, 1) + require.Len(t, acl1.Destinations, 1) + + // Second ACL + acl2 := policy.ACLs[1] + assert.Equal(t, "deny", acl2.Action) + assert.Equal(t, "*", acl2.Protocol) + require.Len(t, acl2.Sources, 1) + require.Len(t, acl2.Destinations, 1) +} + +// Helper function to parse aliases for testing +func mustParseAlias(s string) Alias { + alias, err := parseAlias(s) + if err != nil { + panic(err) + } + return alias +}