diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index ec1bddb0..7095f666 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -5,14 +5,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" - "gorm.io/gorm" "tailscale.com/tailcfg" ) func TestParsing(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "testuser"}, - } + tu := GetTestUsers() tests := []struct { name string format string @@ -352,18 +349,19 @@ func TestParsing(t *testing.T) { } rules, err := pol.compileFilterRules( - users, + tu.FilteredSlice("testuser"), types.Nodes{ &types.Node{ IPv4: ap("100.100.100.100"), }, &types.Node{ IPv4: ap("200.200.200.200"), - User: &users[0], - UserID: &users[0].ID, + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{}, }, - }) + }, + ) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index e5d13b2b..26735ae3 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -8,7 +8,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/require" - "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -25,10 +24,7 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) } func TestPolicyManager(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, - {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, - } + tu := GetTestUsers() tests := []struct { name string @@ -48,7 +44,7 @@ func TestPolicyManager(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + pm, err := NewPolicyManager([]byte(tt.pol), tu.FilteredSlice("testuser", "otheruser"), tt.nodes) require.NoError(t, err) filter, matchers := pm.Filter() diff --git a/hscontrol/policy/v2/testusers_helper.go b/hscontrol/policy/v2/testusers_helper.go new file mode 100644 index 00000000..63d692f6 --- /dev/null +++ b/hscontrol/policy/v2/testusers_helper.go @@ -0,0 +1,96 @@ +package v2 + +import ( + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "gorm.io/gorm" +) + +// TestUsers provides a convenient way to manage test users across tests +type TestUsers struct { + users map[string]*types.User + once sync.Once +} + +var defaultTestUsers TestUsers + +// GetTestUsers returns a singleton instance of TestUsers with predefined test users +func GetTestUsers() *TestUsers { + defaultTestUsers.once.Do(func() { + defaultTestUsers.users = map[string]*types.User{ + "testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"}, + "groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"}, + "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, + "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, + "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, + "user1": {Model: gorm.Model{ID: 6}, Name: "user1"}, + "user2": {Model: gorm.Model{ID: 7}, Name: "user2"}, + "user3": {Model: gorm.Model{ID: 8}, Name: "user3"}, + "otheruser": {Model: gorm.Model{ID: 9}, Name: "otheruser", Email: "otheruser@headscale.net"}, + "mickael": {Model: gorm.Model{ID: 10}, Name: "mickael"}, + "user100": {Model: gorm.Model{ID: 11}, Name: "user100"}, + } + }) + return &defaultTestUsers +} + +// User returns a copy of the User with the given name +func (tu *TestUsers) User(name string) types.User { + if user, ok := tu.users[name]; ok { + return *user + } + // Return empty user if not found + return types.User{} +} + +// UserPtr returns a pointer to the User with the given name +func (tu *TestUsers) UserPtr(name string) *types.User { + return tu.users[name] +} + +// ID returns the ID for the given user name +func (tu *TestUsers) ID(name string) uint { + if user, ok := tu.users[name]; ok { + return user.ID + } + return 0 +} + +// IDPtr returns a pointer to the ID for the given user name +func (tu *TestUsers) IDPtr(name string) *uint { + if user, ok := tu.users[name]; ok { + id := user.ID + return &id + } + return nil +} + +// AsMap returns all users as a map +func (tu *TestUsers) AsMap() map[string]types.User { + result := make(map[string]types.User, len(tu.users)) + for name, user := range tu.users { + result[name] = *user + } + return result +} + +// AsSlice returns all users as a slice +func (tu *TestUsers) AsSlice() types.Users { + result := make(types.Users, 0, len(tu.users)) + for _, user := range tu.users { + result = append(result, *user) + } + return result +} + +// FilteredSlice returns a slice with only the specified user names +func (tu *TestUsers) FilteredSlice(names ...string) types.Users { + result := make(types.Users, 0, len(names)) + for _, name := range names { + if user, ok := tu.users[name]; ok { + result = append(result, *user) + } + } + return result +} \ No newline at end of file diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 83230605..1255b12b 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -16,8 +16,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" - xmaps "golang.org/x/exp/maps" - "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/ptr" @@ -1029,13 +1027,7 @@ func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) } func p(pref string) Prefix { return Prefix(mp(pref)) } func TestResolvePolicy(t *testing.T) { - users := map[string]types.User{ - "testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"}, - "groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"}, - "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, - "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, - "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, - } + tu := GetTestUsers() tests := []struct { name string nodes types.Nodes @@ -1065,14 +1057,14 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: ptr.To(users["notme"]), - UserID: ptr.To(users["notme"].ID), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), IPv4: ap("100.100.101.1"), }, // Not matching tags, usernames are ignored if a node is tagged { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.2"), }, @@ -1080,21 +1072,21 @@ func TestResolvePolicy(t *testing.T) { // since 0.27.0, tags are only considered when // set directly on the node, not via pak. { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), AuthKey: &types.PreAuthKey{ Tags: []string{"alsotagged"}, }, IPv4: ap("100.100.101.3"), }, { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), IPv4: ap("100.100.101.103"), }, { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), IPv4: ap("100.100.101.104"), }, }, @@ -1106,14 +1098,14 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: ptr.To(users["notme"]), - UserID: ptr.To(users["notme"].ID), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: ptr.To(users["groupuser"]), - UserID: ptr.To(users["groupuser"].ID), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.5"), }, @@ -1121,30 +1113,30 @@ func TestResolvePolicy(t *testing.T) { // since 0.27.0, tags are only considered when // set directly on the node, not via pak. { - User: ptr.To(users["groupuser"]), - UserID: ptr.To(users["groupuser"].ID), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), AuthKey: &types.PreAuthKey{ Tags: []string{"tag:alsotagged"}, }, IPv4: ap("100.100.101.6"), }, { - User: ptr.To(users["groupuser"]), - UserID: ptr.To(users["groupuser"].ID), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), IPv4: ap("100.100.101.203"), }, // not matchin username because tagged // since 0.27.0, tags are only considered when // set directly on the node, not via pak. { - User: ptr.To(users["groupuser"]), - UserID: ptr.To(users["groupuser"].ID), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), Tags: []string{"tag:taggg"}, IPv4: ap("100.100.101.209"), }, { - User: ptr.To(users["groupuser1"]), - UserID: ptr.To(users["groupuser1"].ID), + User: tu.UserPtr("groupuser1"), + UserID: tu.IDPtr("groupuser1"), IPv4: ap("100.100.101.204"), }, }, @@ -1162,8 +1154,8 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: ptr.To(users["notme"]), - UserID: ptr.To(users["notme"].ID), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), IPv4: ap("100.100.101.9"), }, // Not matching forced tags @@ -1185,8 +1177,8 @@ func TestResolvePolicy(t *testing.T) { }, // matching tag with user (user is ignored) { - User: ptr.To(users["notme"]), - UserID: ptr.To(users["notme"].ID), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.109"), }, @@ -1225,13 +1217,13 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Group("group:testgroup")), nodes: types.Nodes{ { - User: ptr.To(users["groupuser1"]), - UserID: ptr.To(users["groupuser1"].ID), + User: tu.UserPtr("groupuser1"), + UserID: tu.IDPtr("groupuser1"), IPv4: ap("100.100.101.203"), }, { - User: ptr.To(users["groupuser2"]), - UserID: ptr.To(users["groupuser2"].ID), + User: tu.UserPtr("groupuser2"), + UserID: tu.IDPtr("groupuser2"), IPv4: ap("100.100.101.204"), }, }, @@ -1252,8 +1244,8 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Username("invaliduser@")), nodes: types.Nodes{ { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), IPv4: ap("100.100.101.103"), }, }, @@ -1285,21 +1277,21 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), IPv4: ap("100.100.101.1"), }, // Node with forced tags (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1307,8 +1299,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1316,8 +1308,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1325,8 +1317,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1350,21 +1342,21 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), IPv4: ap("100.100.101.1"), }, // Node with forced tag (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1372,8 +1364,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1381,8 +1373,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1390,8 +1382,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be excluded) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1399,8 +1391,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple forced tags (should be included) { - User: ptr.To(users["testuser"]), - UserID: ptr.To(users["testuser"].ID), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.7"), }, @@ -1426,7 +1418,7 @@ func TestResolvePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ips, err := tt.toResolve.Resolve(tt.pol, - xmaps.Values(users), + tu.AsSlice(), tt.nodes) if tt.wantErr == "" { if err != nil { @@ -1455,27 +1447,23 @@ func TestResolvePolicy(t *testing.T) { } func TestResolveAutoApprovers(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: &users[0], - UserID: &users[0].ID, + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { IPv4: ap("100.64.0.2"), - User: &users[1], - UserID: &users[1].ID, + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { IPv4: ap("100.64.0.3"), - User: &users[2], - UserID: &users[2].ID, + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), }, { IPv4: ap("100.64.0.4"), @@ -1610,7 +1598,7 @@ func TestResolveAutoApprovers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes) + got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes) if (err != nil) != tt.wantErr { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return @@ -1648,27 +1636,27 @@ func ipSetComparer(x, y *netipx.IPSet) bool { } func TestNodeCanApproveRoute(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: &users[0], - UserID: &users[0].ID, + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { IPv4: ap("100.64.0.2"), - User: &users[1], - UserID: &users[1].ID, + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { IPv4: ap("100.64.0.3"), - User: &users[2], - UserID: &users[2].ID, + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1772,7 +1760,7 @@ func TestNodeCanApproveRoute(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes) require.NoErrorf(t, err, "NewPolicyManager() error = %v", err) got := pm.NodeCanApproveRoute(tt.node, tt.route) @@ -1784,27 +1772,27 @@ func TestNodeCanApproveRoute(t *testing.T) { } func TestResolveTagOwners(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: &users[0], - UserID: &users[0].ID, + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { IPv4: ap("100.64.0.2"), - User: &users[1], - UserID: &users[1].ID, + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { IPv4: ap("100.64.0.3"), - User: &users[2], - UserID: &users[2].ID, + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1859,7 +1847,7 @@ func TestResolveTagOwners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := resolveTagOwners(tt.policy, users, nodes) + got, err := resolveTagOwners(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes) if (err != nil) != tt.wantErr { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return @@ -1872,27 +1860,27 @@ func TestResolveTagOwners(t *testing.T) { } func TestNodeCanHaveTag(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: &users[0], - UserID: &users[0].ID, + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { IPv4: ap("100.64.0.2"), - User: &users[1], - UserID: &users[1].ID, + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { IPv4: ap("100.64.0.3"), - User: &users[2], - UserID: &users[2].ID, + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1973,7 +1961,7 @@ func TestNodeCanHaveTag(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes) if tt.wantErr != "" { require.ErrorContains(t, err, tt.wantErr) return