diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 49e24fb8..75f55dac 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -58,41 +58,26 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { +func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListNodes(rx) + return ListNodes(rx, nodeIDs...) }) } -func ListNodes(tx *gorm.DB) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - -func (hsdb *HSDatabase) ListNodesSubset(nodeIDs types.NodeIDs) (types.Nodes, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListNodesSubset(rx, nodeIDs) - }) -} - -func ListNodesSubset(tx *gorm.DB, nodeIDs types.NodeIDs) (types.Nodes, error) { - if len(nodeIDs) < 1 { +func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeIDs) (types.Nodes, error) { + if len(nodeIDs) > 0 && len(nodeIDs[0]) == 0 { return types.Nodes{}, nil } + var nodeFilter types.NodeIDs = nil + if len(nodeIDs) > 0 { + nodeFilter = nodeIDs[0] + } nodes := types.Nodes{} if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Where(nodeIDs).Find(&nodes).Error; err != nil { + Where(nodeFilter).Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 443cb765..4a51e70b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -753,7 +753,7 @@ func TestRenameNode(t *testing.T) { assert.ErrorContains(t, err, "name is not unique") } -func TestListNodesSubset(t *testing.T) { +func TestListNodes(t *testing.T) { // Setup test database db, err := newSQLiteTestDB() if err != nil { @@ -807,24 +807,31 @@ func TestListNodesSubset(t *testing.T) { assert.Len(t, nodes, 2) + // No parameter means no filter, should return all nodes + nodes, err = db.ListNodes() + require.NoError(t, err) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) + // Empty node list should return empty list - nodes, err = db.ListNodesSubset(types.NodeIDs{}) + nodes, err = db.ListNodes(types.NodeIDs{}) require.NoError(t, err) assert.Equal(t, len(nodes), 0) // No match in IDs should return empty list and no error - nodes, err = db.ListNodesSubset(types.NodeIDs{3, 4, 5}) + nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}) require.NoError(t, err) assert.Equal(t, len(nodes), 0) // Partial match in IDs - nodes, err = db.ListNodesSubset(types.NodeIDs{2, 3}) + nodes, err = db.ListNodes(types.NodeIDs{2, 3}) require.NoError(t, err) assert.Equal(t, len(nodes), 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs - nodes, err = db.ListNodesSubset(types.NodeIDs{1, 2, 3}) + nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}) require.NoError(t, err) assert.Equal(t, len(nodes), 2) assert.Equal(t, "test1", nodes[0].Hostname) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 80d6c721..b24e9577 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -269,7 +269,7 @@ func (m *Mapper) PeerChangedResponse( } } - changedNodes, err := m.ListNodesSubset(changedIDs) + changedNodes, err := m.ListNodes(changedIDs) if err != nil { return nil, err } @@ -491,8 +491,8 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { return peers, nil } -func (m *Mapper) ListNodesSubset(nodeIDs []types.NodeID) (types.Nodes, error) { - nodes, err := m.db.ListNodesSubset(nodeIDs) +func (m *Mapper) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { + nodes, err := m.db.ListNodes(nodeIDs...) if err != nil { return nil, err }