mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-05 13:49:57 +02:00
query for all nodes in ListNodes if no parameter is given
This commit is contained in:
parent
e0df114d5a
commit
cf70643818
@ -58,26 +58,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
|
|||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
|
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return ListNodes(rx, nodeIDs...)
|
return ListNodes(rx, nodeIDs...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
if len(nodeIDs) > 0 && len(nodeIDs[0]) == 0 {
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
return types.Nodes{}, nil
|
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
}
|
|
||||||
var nodeFilter types.NodeIDs = nil
|
|
||||||
if len(nodeIDs) > 0 {
|
|
||||||
nodeFilter = nodeIDs[0]
|
|
||||||
}
|
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
if err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Where(nodeFilter).Find(&nodes).Error; err != nil {
|
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -814,24 +814,26 @@ func TestListNodes(t *testing.T) {
|
|||||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||||
|
|
||||||
// Empty node list should return empty list
|
// Empty node list should return all nodes
|
||||||
nodes, err = db.ListNodes(types.NodeIDs{})
|
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(nodes), 0)
|
assert.Equal(t, len(nodes), 2)
|
||||||
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
|
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||||
|
|
||||||
// No match in IDs should return empty list and no error
|
// No match in IDs should return empty list and no error
|
||||||
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5})
|
nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(nodes), 0)
|
assert.Equal(t, len(nodes), 0)
|
||||||
|
|
||||||
// Partial match in IDs
|
// Partial match in IDs
|
||||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3})
|
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(nodes), 1)
|
assert.Equal(t, len(nodes), 1)
|
||||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||||
|
|
||||||
// Several matched IDs
|
// Several matched IDs
|
||||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3})
|
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(nodes), 2)
|
assert.Equal(t, len(nodes), 2)
|
||||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||||
|
@ -255,6 +255,7 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
patches []*tailcfg.PeerChange,
|
patches []*tailcfg.PeerChange,
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
|
var err error
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
var removedIDs []tailcfg.NodeID
|
var removedIDs []tailcfg.NodeID
|
||||||
@ -268,11 +269,13 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
changedNodes := types.Nodes{}
|
||||||
changedNodes, err := m.ListNodes(changedIDs)
|
if len(changedIDs) > 0 {
|
||||||
|
changedNodes, err = m.ListNodes(changedIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = appendPeerChanges(
|
err = appendPeerChanges(
|
||||||
&resp,
|
&resp,
|
||||||
@ -491,7 +494,9 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
|||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) {
|
// ListNodes queries the database for either all nodes if no parameters are given
|
||||||
|
// or for the given nodes if at least one node ID is given as parameter
|
||||||
|
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
nodes, err := m.db.ListNodes(nodeIDs...)
|
nodes, err := m.db.ListNodes(nodeIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user