1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-05 13:49:57 +02:00

query for all nodes in ListNodes if no parameter is given

This commit is contained in:
Nils Enkelmann 2025-04-07 10:52:51 +02:00
parent e0df114d5a
commit cf70643818
3 changed files with 25 additions and 21 deletions

View File

@ -58,26 +58,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { // ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx, nodeIDs...) return ListNodes(rx, nodeIDs...)
}) })
} }
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeIDs) (types.Nodes, error) { // ListNodes queries the database for either all nodes if no parameters are given
if len(nodeIDs) > 0 && len(nodeIDs[0]) == 0 { // or for the given nodes if at least one node ID is given as parameter
return types.Nodes{}, nil func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
}
var nodeFilter types.NodeIDs = nil
if len(nodeIDs) > 0 {
nodeFilter = nodeIDs[0]
}
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
Where(nodeFilter).Find(&nodes).Error; err != nil { Where(nodeIDs).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }

View File

@ -814,24 +814,26 @@ func TestListNodes(t *testing.T) {
assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname) assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return empty list // Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}) nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(nodes), 0) assert.Equal(t, len(nodes), 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// No match in IDs should return empty list and no error // No match in IDs should return empty list and no error
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}) nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(nodes), 0) assert.Equal(t, len(nodes), 0)
// Partial match in IDs // Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}) nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(nodes), 1) assert.Equal(t, len(nodes), 1)
assert.Equal(t, "test2", nodes[0].Hostname) assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs // Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}) nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(nodes), 2) assert.Equal(t, len(nodes), 2)
assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test1", nodes[0].Hostname)

View File

@ -255,6 +255,7 @@ func (m *Mapper) PeerChangedResponse(
patches []*tailcfg.PeerChange, patches []*tailcfg.PeerChange,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
var err error
resp := m.baseMapResponse() resp := m.baseMapResponse()
var removedIDs []tailcfg.NodeID var removedIDs []tailcfg.NodeID
@ -268,11 +269,13 @@ func (m *Mapper) PeerChangedResponse(
removedIDs = append(removedIDs, nodeID.NodeID()) removedIDs = append(removedIDs, nodeID.NodeID())
} }
} }
changedNodes := types.Nodes{}
changedNodes, err := m.ListNodes(changedIDs) if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
err = appendPeerChanges( err = appendPeerChanges(
&resp, &resp,
@ -491,7 +494,9 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
return peers, nil return peers, nil
} }
func (m *Mapper) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { // ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.db.ListNodes(nodeIDs...) nodes, err := m.db.ListNodes(nodeIDs...)
if err != nil { if err != nil {
return nil, err return nil, err