From cf70643818c9a71d83907cce361a6b89f1285a26 Mon Sep 17 00:00:00 2001 From: Nils Enkelmann Date: Mon, 7 Apr 2025 10:52:51 +0200 Subject: [PATCH] query for all nodes in ListNodes if no parameter is given --- hscontrol/db/node.go | 17 +++++++---------- hscontrol/db/node_test.go | 14 ++++++++------ hscontrol/mapper/mapper.go | 15 ++++++++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 75f55dac..6bfe657c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -58,26 +58,23 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { return ListNodes(rx, nodeIDs...) }) } -func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeIDs) (types.Nodes, error) { - if len(nodeIDs) > 0 && len(nodeIDs[0]) == 0 { - return types.Nodes{}, nil - } - var nodeFilter types.NodeIDs = nil - if len(nodeIDs) > 0 { - nodeFilter = nodeIDs[0] - } +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Where(nodeFilter).Find(&nodes).Error; err != nil { + Where(nodeIDs).Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 4a51e70b..73650663 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -814,24 +814,26 @@ func TestListNodes(t *testing.T) { assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) - // Empty node list should return empty list - nodes, err = db.ListNodes(types.NodeIDs{}) + // Empty node list should return all nodes + nodes, err = db.ListNodes(types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, len(nodes), 0) + assert.Equal(t, len(nodes), 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) // No match in IDs should return empty list and no error - nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}) + nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...) require.NoError(t, err) assert.Equal(t, len(nodes), 0) // Partial match in IDs - nodes, err = db.ListNodes(types.NodeIDs{2, 3}) + nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) require.NoError(t, err) assert.Equal(t, len(nodes), 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs - nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}) + nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) require.NoError(t, err) assert.Equal(t, len(nodes), 2) assert.Equal(t, "test1", nodes[0].Hostname) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index b24e9577..0b5d5d93 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -255,6 +255,7 @@ func (m *Mapper) PeerChangedResponse( patches []*tailcfg.PeerChange, messages ...string, ) ([]byte, error) { + var err error resp := m.baseMapResponse() var removedIDs []tailcfg.NodeID @@ -268,10 +269,12 @@ func (m *Mapper) PeerChangedResponse( removedIDs = append(removedIDs, nodeID.NodeID()) } } - - changedNodes, err := m.ListNodes(changedIDs) - if err != nil { - return nil, err + changedNodes := types.Nodes{} + if len(changedIDs) > 0 { + changedNodes, err = m.ListNodes(changedIDs...) + if err != nil { + return nil, err + } } err = appendPeerChanges( @@ -491,7 +494,9 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { return peers, nil } -func (m *Mapper) ListNodes(nodeIDs ...types.NodeIDs) (types.Nodes, error) { +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter +func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { nodes, err := m.db.ListNodes(nodeIDs...) if err != nil { return nil, err