mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-24 13:46:53 +02:00
combine ListNodes and ListNodesSubset into one function
This commit is contained in:
parent
e3d44d3b44
commit
e0df114d5a
@ -58,41 +58,26 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) {
|
||||
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx)
|
||||
return ListNodes(rx, nodeIDs...)
|
||||
})
|
||||
}
|
||||
|
||||
func ListNodes(tx *gorm.DB) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListNodesSubset(nodeIDs types.NodeIDs) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodesSubset(rx, nodeIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func ListNodesSubset(tx *gorm.DB, nodeIDs types.NodeIDs) (types.Nodes, error) {
|
||||
if len(nodeIDs) < 1 {
|
||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
||||
if len(nodeIDs) > 0 && len(nodeIDs[0]) == 0 {
|
||||
return types.Nodes{}, nil
|
||||
}
|
||||
var nodeFilter types.NodeIDs = nil
|
||||
if len(nodeIDs) > 0 {
|
||||
nodeFilter = nodeIDs[0]
|
||||
}
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||
Where(nodeFilter).Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -753,7 +753,7 @@ func TestRenameNode(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
}
|
||||
|
||||
func TestListNodesSubset(t *testing.T) {
|
||||
func TestListNodes(t *testing.T) {
|
||||
// Setup test database
|
||||
db, err := newSQLiteTestDB()
|
||||
if err != nil {
|
||||
@ -807,24 +807,31 @@ func TestListNodesSubset(t *testing.T) {
|
||||
|
||||
assert.Len(t, nodes, 2)
|
||||
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(nodes), 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return empty list
|
||||
nodes, err = db.ListNodesSubset(types.NodeIDs{})
|
||||
nodes, err = db.ListNodes(types.NodeIDs{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(nodes), 0)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
nodes, err = db.ListNodesSubset(types.NodeIDs{3, 4, 5})
|
||||
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(nodes), 0)
|
||||
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodesSubset(types.NodeIDs{2, 3})
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(nodes), 1)
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodesSubset(types.NodeIDs{1, 2, 3})
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(nodes), 2)
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
|
@ -269,7 +269,7 @@ func (m *Mapper) PeerChangedResponse(
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes, err := m.ListNodesSubset(changedIDs)
|
||||
changedNodes, err := m.ListNodes(changedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -491,8 +491,8 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (m *Mapper) ListNodesSubset(nodeIDs []types.NodeID) (types.Nodes, error) {
|
||||
nodes, err := m.db.ListNodesSubset(nodeIDs)
|
||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
||||
nodes, err := m.db.ListNodes(nodeIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user