1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-07-27 13:48:02 +02:00

policy: add ssh unmarshal tests

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-05-16 10:28:26 +02:00
parent 065c41d1d8
commit 122308ba36
No known key found for this signature in database
2 changed files with 152 additions and 9 deletions

View File

@ -1567,6 +1567,8 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) {
aliases[i] = string(*v)
case *Host:
aliases[i] = string(*v)
case Asterix:
aliases[i] = "*"
default:
return nil, fmt.Errorf("unknown SSH destination alias type: %T", v)
}
@ -1592,6 +1594,8 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
aliases[i] = string(*v)
case *AutoGroup:
aliases[i] = string(*v)
case Asterix:
aliases[i] = "*"
default:
return nil, fmt.Errorf("unknown SSH source alias type: %T", v)
}

View File

@ -10,6 +10,8 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@ -84,10 +86,10 @@ func TestMarshalJSON(t *testing.T) {
require.NoError(t, err)
// Compare the original and round-tripped policies
cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
return x == y
}))
cmps = append(cmps,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(),
)
@ -589,6 +591,138 @@ func TestUnmarshalPolicy(t *testing.T) {
`,
wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`,
},
{
name: "ssh-basic",
input: `
{
"groups": {
"group:admins": ["admin@example.com"]
},
"tagOwners": {
"tag:servers": ["group:admins"]
},
"ssh": [
{
"action": "accept",
"src": [
"group:admins"
],
"dst": [
"tag:servers"
],
"users": ["root", "admin"]
}
]
}
`,
want: &Policy{
Groups: Groups{
Group("group:admins"): []Username{Username("admin@example.com")},
},
TagOwners: TagOwners{
Tag("tag:servers"): Owners{gp("group:admins")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{
gp("group:admins"),
},
Destinations: SSHDstAliases{
tp("tag:servers"),
},
Users: []SSHUser{
SSHUser("root"),
SSHUser("admin"),
},
},
},
},
},
{
name: "ssh-with-tag-and-user",
input: `
{
"tagOwners": {
"tag:web": ["admin@example.com"]
},
"ssh": [
{
"action": "accept",
"src": [
"tag:web"
],
"dst": [
"admin@example.com"
],
"users": ["*"]
}
]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:web"): Owners{ptr.To(Username("admin@example.com"))},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{
tp("tag:web"),
},
Destinations: SSHDstAliases{
ptr.To(Username("admin@example.com")),
},
Users: []SSHUser{
SSHUser("*"),
},
},
},
},
},
{
name: "ssh-with-check-period",
input: `
{
"groups": {
"group:admins": ["admin@example.com"]
},
"ssh": [
{
"action": "accept",
"src": [
"group:admins"
],
"dst": [
"admin@example.com"
],
"users": ["root"],
"checkPeriod": "24h"
}
]
}
`,
want: &Policy{
Groups: Groups{
Group("group:admins"): []Username{Username("admin@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{
gp("group:admins"),
},
Destinations: SSHDstAliases{
ptr.To(Username("admin@example.com")),
},
Users: []SSHUser{
SSHUser("root"),
},
CheckPeriod: model.Duration(24 * time.Hour),
},
},
},
},
{
name: "group-must-be-defined-acl-src",
input: `
@ -786,10 +920,12 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
return x == y
}))
cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{}))
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
)
// For round-trip testing, we'll normalize the policies before comparing
@ -829,7 +965,10 @@ func TestUnmarshalPolicy(t *testing.T) {
}
// Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps, cmpopts.EquateEmpty())
roundTripCmps := append(cmps,
cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}),
)
// Compare using the enhanced comparers for round-trip testing
if diff := cmp.Diff(policy, roundTripped, roundTripCmps...); diff != "" {