1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-02 13:47:00 +02:00

policy: add user helper for tests

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-05-21 23:18:56 +02:00
parent 19add15927
commit 27a518a2fa
No known key found for this signature in database
4 changed files with 208 additions and 130 deletions

View File

@ -5,14 +5,11 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func TestParsing(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser"},
}
tu := GetTestUsers()
tests := []struct {
name string
format string
@ -352,18 +349,19 @@ func TestParsing(t *testing.T) {
}
rules, err := pol.compileFilterRules(
users,
tu.FilteredSlice("testuser"),
types.Nodes{
&types.Node{
IPv4: ap("100.100.100.100"),
},
&types.Node{
IPv4: ap("200.200.200.200"),
User: &users[0],
UserID: &users[0].ID,
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{},
},
})
},
)
if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

View File

@ -8,7 +8,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
@ -25,10 +24,7 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo)
}
func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}
tu := GetTestUsers()
tests := []struct {
name string
@ -48,7 +44,7 @@ func TestPolicyManager(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
pm, err := NewPolicyManager([]byte(tt.pol), tu.FilteredSlice("testuser", "otheruser"), tt.nodes)
require.NoError(t, err)
filter, matchers := pm.Filter()

View File

@ -0,0 +1,96 @@
package v2
import (
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
)
// TestUsers provides a convenient way to manage test users across tests
type TestUsers struct {
users map[string]*types.User
once sync.Once
}
var defaultTestUsers TestUsers
// GetTestUsers returns a singleton instance of TestUsers with predefined test users
func GetTestUsers() *TestUsers {
defaultTestUsers.once.Do(func() {
defaultTestUsers.users = map[string]*types.User{
"testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"},
"groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"},
"groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"},
"groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"},
"notme": {Model: gorm.Model{ID: 5}, Name: "notme"},
"user1": {Model: gorm.Model{ID: 6}, Name: "user1"},
"user2": {Model: gorm.Model{ID: 7}, Name: "user2"},
"user3": {Model: gorm.Model{ID: 8}, Name: "user3"},
"otheruser": {Model: gorm.Model{ID: 9}, Name: "otheruser", Email: "otheruser@headscale.net"},
"mickael": {Model: gorm.Model{ID: 10}, Name: "mickael"},
"user100": {Model: gorm.Model{ID: 11}, Name: "user100"},
}
})
return &defaultTestUsers
}
// User returns a copy of the User with the given name
func (tu *TestUsers) User(name string) types.User {
if user, ok := tu.users[name]; ok {
return *user
}
// Return empty user if not found
return types.User{}
}
// UserPtr returns a pointer to the User with the given name
func (tu *TestUsers) UserPtr(name string) *types.User {
return tu.users[name]
}
// ID returns the ID for the given user name
func (tu *TestUsers) ID(name string) uint {
if user, ok := tu.users[name]; ok {
return user.ID
}
return 0
}
// IDPtr returns a pointer to the ID for the given user name
func (tu *TestUsers) IDPtr(name string) *uint {
if user, ok := tu.users[name]; ok {
id := user.ID
return &id
}
return nil
}
// AsMap returns all users as a map
func (tu *TestUsers) AsMap() map[string]types.User {
result := make(map[string]types.User, len(tu.users))
for name, user := range tu.users {
result[name] = *user
}
return result
}
// AsSlice returns all users as a slice
func (tu *TestUsers) AsSlice() types.Users {
result := make(types.Users, 0, len(tu.users))
for _, user := range tu.users {
result = append(result, *user)
}
return result
}
// FilteredSlice returns a slice with only the specified user names
func (tu *TestUsers) FilteredSlice(names ...string) types.Users {
result := make(types.Users, 0, len(names))
for _, name := range names {
if user, ok := tu.users[name]; ok {
result = append(result, *user)
}
}
return result
}

View File

@ -16,8 +16,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
xmaps "golang.org/x/exp/maps"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
@ -1029,13 +1027,7 @@ func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) }
func p(pref string) Prefix { return Prefix(mp(pref)) }
func TestResolvePolicy(t *testing.T) {
users := map[string]types.User{
"testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"},
"groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"},
"groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"},
"groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"},
"notme": {Model: gorm.Model{ID: 5}, Name: "notme"},
}
tu := GetTestUsers()
tests := []struct {
name string
nodes types.Nodes
@ -1065,14 +1057,14 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: ptr.To(users["notme"]),
UserID: ptr.To(users["notme"].ID),
User: tu.UserPtr("notme"),
UserID: tu.IDPtr("notme"),
IPv4: ap("100.100.101.1"),
},
// Not matching tags, usernames are ignored if a node is tagged
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
},
@ -1080,21 +1072,21 @@ func TestResolvePolicy(t *testing.T) {
// since 0.27.0, tags are only considered when
// set directly on the node, not via pak.
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
AuthKey: &types.PreAuthKey{
Tags: []string{"alsotagged"},
},
IPv4: ap("100.100.101.3"),
},
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
IPv4: ap("100.100.101.103"),
},
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
IPv4: ap("100.100.101.104"),
},
},
@ -1106,14 +1098,14 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: ptr.To(users["notme"]),
UserID: ptr.To(users["notme"].ID),
User: tu.UserPtr("notme"),
UserID: tu.IDPtr("notme"),
IPv4: ap("100.100.101.4"),
},
// Not matching forced tags
{
User: ptr.To(users["groupuser"]),
UserID: ptr.To(users["groupuser"].ID),
User: tu.UserPtr("groupuser"),
UserID: tu.IDPtr("groupuser"),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
},
@ -1121,30 +1113,30 @@ func TestResolvePolicy(t *testing.T) {
// since 0.27.0, tags are only considered when
// set directly on the node, not via pak.
{
User: ptr.To(users["groupuser"]),
UserID: ptr.To(users["groupuser"].ID),
User: tu.UserPtr("groupuser"),
UserID: tu.IDPtr("groupuser"),
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
IPv4: ap("100.100.101.6"),
},
{
User: ptr.To(users["groupuser"]),
UserID: ptr.To(users["groupuser"].ID),
User: tu.UserPtr("groupuser"),
UserID: tu.IDPtr("groupuser"),
IPv4: ap("100.100.101.203"),
},
// not matchin username because tagged
// since 0.27.0, tags are only considered when
// set directly on the node, not via pak.
{
User: ptr.To(users["groupuser"]),
UserID: ptr.To(users["groupuser"].ID),
User: tu.UserPtr("groupuser"),
UserID: tu.IDPtr("groupuser"),
Tags: []string{"tag:taggg"},
IPv4: ap("100.100.101.209"),
},
{
User: ptr.To(users["groupuser1"]),
UserID: ptr.To(users["groupuser1"].ID),
User: tu.UserPtr("groupuser1"),
UserID: tu.IDPtr("groupuser1"),
IPv4: ap("100.100.101.204"),
},
},
@ -1162,8 +1154,8 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: ptr.To(users["notme"]),
UserID: ptr.To(users["notme"].ID),
User: tu.UserPtr("notme"),
UserID: tu.IDPtr("notme"),
IPv4: ap("100.100.101.9"),
},
// Not matching forced tags
@ -1185,8 +1177,8 @@ func TestResolvePolicy(t *testing.T) {
},
// matching tag with user (user is ignored)
{
User: ptr.To(users["notme"]),
UserID: ptr.To(users["notme"].ID),
User: tu.UserPtr("notme"),
UserID: tu.IDPtr("notme"),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.109"),
},
@ -1225,13 +1217,13 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Group("group:testgroup")),
nodes: types.Nodes{
{
User: ptr.To(users["groupuser1"]),
UserID: ptr.To(users["groupuser1"].ID),
User: tu.UserPtr("groupuser1"),
UserID: tu.IDPtr("groupuser1"),
IPv4: ap("100.100.101.203"),
},
{
User: ptr.To(users["groupuser2"]),
UserID: ptr.To(users["groupuser2"].ID),
User: tu.UserPtr("groupuser2"),
UserID: tu.IDPtr("groupuser2"),
IPv4: ap("100.100.101.204"),
},
},
@ -1252,8 +1244,8 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Username("invaliduser@")),
nodes: types.Nodes{
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
IPv4: ap("100.100.101.103"),
},
},
@ -1285,21 +1277,21 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
IPv4: ap("100.100.101.1"),
},
// Node with forced tags (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@ -1307,8 +1299,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@ -1316,8 +1308,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@ -1325,8 +1317,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@ -1350,21 +1342,21 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
IPv4: ap("100.100.101.1"),
},
// Node with forced tag (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@ -1372,8 +1364,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@ -1381,8 +1373,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@ -1390,8 +1382,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be excluded)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@ -1399,8 +1391,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple forced tags (should be included)
{
User: ptr.To(users["testuser"]),
UserID: ptr.To(users["testuser"].ID),
User: tu.UserPtr("testuser"),
UserID: tu.IDPtr("testuser"),
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.7"),
},
@ -1426,7 +1418,7 @@ func TestResolvePolicy(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol,
xmaps.Values(users),
tu.AsSlice(),
tt.nodes)
if tt.wantErr == "" {
if err != nil {
@ -1455,27 +1447,23 @@ func TestResolvePolicy(t *testing.T) {
}
func TestResolveAutoApprovers(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
{Model: gorm.Model{ID: 3}, Name: "user3"},
}
tu := GetTestUsers()
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: &users[0],
UserID: &users[0].ID,
User: tu.UserPtr("user1"),
UserID: tu.IDPtr("user1"),
},
{
IPv4: ap("100.64.0.2"),
User: &users[1],
UserID: &users[1].ID,
User: tu.UserPtr("user2"),
UserID: tu.IDPtr("user2"),
},
{
IPv4: ap("100.64.0.3"),
User: &users[2],
UserID: &users[2].ID,
User: tu.UserPtr("user3"),
UserID: tu.IDPtr("user3"),
},
{
IPv4: ap("100.64.0.4"),
@ -1610,7 +1598,7 @@ func TestResolveAutoApprovers(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes)
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes)
if (err != nil) != tt.wantErr {
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
return
@ -1648,27 +1636,27 @@ func ipSetComparer(x, y *netipx.IPSet) bool {
}
func TestNodeCanApproveRoute(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
{Model: gorm.Model{ID: 3}, Name: "user3"},
}
tu := GetTestUsers()
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: &users[0],
UserID: &users[0].ID,
User: tu.UserPtr("user1"),
UserID: tu.IDPtr("user1"),
},
{
IPv4: ap("100.64.0.2"),
User: &users[1],
UserID: &users[1].ID,
User: tu.UserPtr("user2"),
UserID: tu.IDPtr("user2"),
},
{
IPv4: ap("100.64.0.3"),
User: &users[2],
UserID: &users[2].ID,
User: tu.UserPtr("user3"),
UserID: tu.IDPtr("user3"),
},
{
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:testtag"},
},
}
@ -1772,7 +1760,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
b, err := json.Marshal(tt.policy)
require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes)
pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes)
require.NoErrorf(t, err, "NewPolicyManager() error = %v", err)
got := pm.NodeCanApproveRoute(tt.node, tt.route)
@ -1784,27 +1772,27 @@ func TestNodeCanApproveRoute(t *testing.T) {
}
func TestResolveTagOwners(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
{Model: gorm.Model{ID: 3}, Name: "user3"},
}
tu := GetTestUsers()
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: &users[0],
UserID: &users[0].ID,
User: tu.UserPtr("user1"),
UserID: tu.IDPtr("user1"),
},
{
IPv4: ap("100.64.0.2"),
User: &users[1],
UserID: &users[1].ID,
User: tu.UserPtr("user2"),
UserID: tu.IDPtr("user2"),
},
{
IPv4: ap("100.64.0.3"),
User: &users[2],
UserID: &users[2].ID,
User: tu.UserPtr("user3"),
UserID: tu.IDPtr("user3"),
},
{
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:testtag"},
},
}
@ -1859,7 +1847,7 @@ func TestResolveTagOwners(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := resolveTagOwners(tt.policy, users, nodes)
got, err := resolveTagOwners(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes)
if (err != nil) != tt.wantErr {
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return
@ -1872,27 +1860,27 @@ func TestResolveTagOwners(t *testing.T) {
}
func TestNodeCanHaveTag(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
{Model: gorm.Model{ID: 3}, Name: "user3"},
}
tu := GetTestUsers()
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: &users[0],
UserID: &users[0].ID,
User: tu.UserPtr("user1"),
UserID: tu.IDPtr("user1"),
},
{
IPv4: ap("100.64.0.2"),
User: &users[1],
UserID: &users[1].ID,
User: tu.UserPtr("user2"),
UserID: tu.IDPtr("user2"),
},
{
IPv4: ap("100.64.0.3"),
User: &users[2],
UserID: &users[2].ID,
User: tu.UserPtr("user3"),
UserID: tu.IDPtr("user3"),
},
{
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:testtag"},
},
}
@ -1973,7 +1961,7 @@ func TestNodeCanHaveTag(t *testing.T) {
b, err := json.Marshal(tt.policy)
require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes)
pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes)
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
return