diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index c92a4497..443cb765 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -752,3 +752,81 @@ func TestRenameNode(t *testing.T) { }) assert.ErrorContains(t, err, "name is not unique") } + +func TestListNodesSubset(t *testing.T) { + // Setup test database + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node1 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test1", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test2", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node1).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, node1, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, node2, nil, nil) + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + + // Empty node list should return empty list + nodes, err = db.ListNodesSubset(types.NodeIDs{}) + require.NoError(t, err) + assert.Equal(t, len(nodes), 0) + + // No match in IDs should return empty list and no error + nodes, err = db.ListNodesSubset(types.NodeIDs{3, 4, 5}) + require.NoError(t, err) + assert.Equal(t, len(nodes), 0) + + // Partial match in IDs + nodes, err = db.ListNodesSubset(types.NodeIDs{2, 3}) + require.NoError(t, err) + assert.Equal(t, len(nodes), 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Several matched IDs + nodes, err = db.ListNodesSubset(types.NodeIDs{1, 2, 3}) + require.NoError(t, err) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) +}