1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-05-14 01:17:07 +02:00

Rework to ensure transactional consistency in PeerChangedResponse again

This commit is contained in:
Nils Enkelmann 2025-03-28 13:48:06 +01:00
parent af04eb5ffd
commit 67bdd19560
2 changed files with 34 additions and 22 deletions

View File

@ -77,6 +77,25 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) {
return nodes, nil
}
func (hsdb *HSDatabase) ListNodesSubset(nodeIDs types.NodeIDs) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodesSubset(rx, nodeIDs)
})
}
func ListNodesSubset(tx *gorm.DB, nodeIDs types.NodeIDs) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Where(nodeIDs).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
}
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}

View File

@ -3,9 +3,7 @@ package mapper
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"gorm.io/gorm"
"io/fs"
"net/url"
"os"
@ -263,28 +261,20 @@ func (m *Mapper) PeerChangedResponse(
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
changedIDs = append(changedIDs, nodeID)
if nodeID != node.ID {
changedIDs = append(changedIDs, nodeID)
}
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := make(types.Nodes, 0, len(changedIDs))
for _, changedID := range changedIDs {
if changedID != node.ID {
changedNode, err := m.GetNode(changedID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
} else {
return nil, err
}
}
changedNodes = append(changedNodes, changedNode)
}
changedNodes, err := m.ListNodesSubset(changedIDs)
if err != nil {
return nil, err
}
err := appendPeerChanges(
err = appendPeerChanges(
&resp,
false, // partial change
m.polMan,
@ -501,15 +491,18 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
return peers, nil
}
func (m *Mapper) GetNode(nodeID types.NodeID) (*types.Node, error) {
node, err := m.db.GetNodeByID(nodeID)
func (m *Mapper) ListNodesSubset(nodeIDs []types.NodeID) (types.Nodes, error) {
nodes, err := m.db.ListNodesSubset(nodeIDs)
if err != nil {
return nil, err
}
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
return node, nil
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
return nodes, nil
}
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {