mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
nodestore: return node when applied
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
1507bbf6d4
commit
8442220b80
@ -33,8 +33,8 @@ func TestNodeStoreDebugString(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
|
||||
store.PutNode(node1)
|
||||
store.PutNode(node2)
|
||||
_ = store.PutNode(node1)
|
||||
_ = store.PutNode(node2)
|
||||
|
||||
return store
|
||||
},
|
||||
|
@ -140,13 +140,15 @@ type work struct {
|
||||
node types.Node
|
||||
updateFn UpdateNodeFunc
|
||||
result chan struct{}
|
||||
nodeResult chan types.NodeView // Channel to return the resulting node after batch application
|
||||
}
|
||||
|
||||
// PutNode adds or updates a node in the store.
|
||||
// If the node already exists, it will be replaced.
|
||||
// If the node does not exist, it will be added.
|
||||
// This is a blocking operation that waits for the write to complete.
|
||||
func (s *NodeStore) PutNode(n types.Node) {
|
||||
// Returns the resulting node after all modifications in the batch have been applied.
|
||||
func (s *NodeStore) PutNode(n types.Node) types.NodeView {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
@ -155,6 +157,7 @@ func (s *NodeStore) PutNode(n types.Node) {
|
||||
nodeID: n.ID,
|
||||
node: n,
|
||||
result: make(chan struct{}),
|
||||
nodeResult: make(chan types.NodeView, 1),
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
@ -162,7 +165,10 @@ func (s *NodeStore) PutNode(n types.Node) {
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||
|
||||
return resultNode
|
||||
}
|
||||
|
||||
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
|
||||
@ -173,6 +179,7 @@ type UpdateNodeFunc func(n *types.Node)
|
||||
// This is analogous to a database "transaction", or, the caller should
|
||||
// rather collect all data they want to change, and then call this function.
|
||||
// Fewer calls are better.
|
||||
// Returns the resulting node after all modifications in the batch have been applied.
|
||||
//
|
||||
// TODO(kradalby): Technically we could have a version of this that modifies the node
|
||||
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
|
||||
@ -181,7 +188,7 @@ type UpdateNodeFunc func(n *types.Node)
|
||||
// a lock around the nodesByID map to ensure that no other writes are happening
|
||||
// while we are modifying the node. Which mean we would need to implement read-write locks
|
||||
// on all read operations.
|
||||
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
|
||||
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) {
|
||||
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
|
||||
defer timer.ObserveDuration()
|
||||
|
||||
@ -190,6 +197,7 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
|
||||
nodeID: nodeID,
|
||||
updateFn: updateFn,
|
||||
result: make(chan struct{}),
|
||||
nodeResult: make(chan types.NodeView, 1),
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
@ -197,7 +205,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||
|
||||
// Return the node and whether it exists (is valid)
|
||||
return resultNode, resultNode.Valid()
|
||||
}
|
||||
|
||||
// DeleteNode removes a node from the store by its ID.
|
||||
@ -282,18 +294,32 @@ func (s *NodeStore) applyBatch(batch []work) {
|
||||
nodes := make(map[types.NodeID]types.Node)
|
||||
maps.Copy(nodes, s.data.Load().nodesByID)
|
||||
|
||||
for _, w := range batch {
|
||||
// Track which work items need node results
|
||||
nodeResultRequests := make(map[types.NodeID][]*work)
|
||||
|
||||
for i := range batch {
|
||||
w := &batch[i]
|
||||
switch w.op {
|
||||
case put:
|
||||
nodes[w.nodeID] = w.node
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
case update:
|
||||
// Update the specific node identified by nodeID
|
||||
if n, exists := nodes[w.nodeID]; exists {
|
||||
w.updateFn(&n)
|
||||
nodes[w.nodeID] = n
|
||||
}
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
case del:
|
||||
delete(nodes, w.nodeID)
|
||||
// For delete operations, send an invalid NodeView if requested
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -303,6 +329,24 @@ func (s *NodeStore) applyBatch(batch []work) {
|
||||
// Update node count gauge
|
||||
nodeStoreNodesCount.Set(float64(len(nodes)))
|
||||
|
||||
// Send the resulting nodes to all work items that requested them
|
||||
for nodeID, workItems := range nodeResultRequests {
|
||||
if node, exists := nodes[nodeID]; exists {
|
||||
nodeView := node.View()
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- nodeView
|
||||
close(w.nodeResult)
|
||||
}
|
||||
} else {
|
||||
// Node was deleted or doesn't exist
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- types.NodeView{} // Send invalid view
|
||||
close(w.nodeResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Signal completion for all work items
|
||||
for _, w := range batch {
|
||||
close(w.result)
|
||||
}
|
||||
|
@ -249,7 +249,9 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
name: "add first node",
|
||||
action: func(store *NodeStore) {
|
||||
node := createTestNode(1, 1, "user1", "node1")
|
||||
store.PutNode(node)
|
||||
resultNode := store.PutNode(node)
|
||||
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
|
||||
assert.Equal(t, node.ID, resultNode.ID())
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
@ -288,7 +290,9 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
name: "add second node same user",
|
||||
action: func(store *NodeStore) {
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
store.PutNode(node2)
|
||||
resultNode := store.PutNode(node2)
|
||||
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
|
||||
assert.Equal(t, types.NodeID(2), resultNode.ID())
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
@ -308,7 +312,9 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
name: "add third node different user",
|
||||
action: func(store *NodeStore) {
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
store.PutNode(node3)
|
||||
resultNode := store.PutNode(node3)
|
||||
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
|
||||
assert.Equal(t, types.NodeID(3), resultNode.ID())
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
@ -409,10 +415,14 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
{
|
||||
name: "update node hostname",
|
||||
action: func(store *NodeStore) {
|
||||
store.UpdateNode(1, func(n *types.Node) {
|
||||
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "updated-node1"
|
||||
n.GivenName = "updated-node1"
|
||||
})
|
||||
assert.True(t, ok, "UpdateNode should return true for existing node")
|
||||
assert.True(t, resultNode.Valid(), "Result node should be valid")
|
||||
assert.Equal(t, "updated-node1", resultNode.Hostname())
|
||||
assert.Equal(t, "updated-node1", resultNode.GivenName())
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
|
||||
@ -436,10 +446,14 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
name: "add nodes with odd-even filtering",
|
||||
action: func(store *NodeStore) {
|
||||
// Add nodes in sequence
|
||||
store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
assert.True(t, n1.Valid())
|
||||
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
assert.True(t, n2.Valid())
|
||||
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
assert.True(t, n3.Valid())
|
||||
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
assert.True(t, n4.Valid())
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
@ -478,6 +492,328 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test batch modifications return correct node state",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial state",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "concurrent updates should reflect all batch changes",
|
||||
action: func(store *NodeStore) {
|
||||
// Start multiple updates that will be batched together
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2 types.NodeView
|
||||
var newNode3 types.NodeView
|
||||
var ok1, ok2 bool
|
||||
|
||||
// These should all be processed in the same batch
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "batch-updated-node1"
|
||||
n.GivenName = "batch-given-1"
|
||||
})
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) {
|
||||
n.Hostname = "batch-updated-node2"
|
||||
n.GivenName = "batch-given-2"
|
||||
})
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
node3 := createTestNode(3, 1, "user1", "node3")
|
||||
newNode3 = store.PutNode(node3)
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
// Wait for all operations to complete
|
||||
<-done1
|
||||
<-done2
|
||||
<-done3
|
||||
|
||||
// Verify the returned nodes reflect the batch state
|
||||
assert.True(t, ok1, "UpdateNode should succeed for node 1")
|
||||
assert.True(t, ok2, "UpdateNode should succeed for node 2")
|
||||
assert.True(t, resultNode1.Valid())
|
||||
assert.True(t, resultNode2.Valid())
|
||||
assert.True(t, newNode3.Valid())
|
||||
|
||||
// Check that returned nodes have the updated values
|
||||
assert.Equal(t, "batch-updated-node1", resultNode1.Hostname())
|
||||
assert.Equal(t, "batch-given-1", resultNode1.GivenName())
|
||||
assert.Equal(t, "batch-updated-node2", resultNode2.Hostname())
|
||||
assert.Equal(t, "batch-given-2", resultNode2.GivenName())
|
||||
assert.Equal(t, "node3", newNode3.Hostname())
|
||||
|
||||
// Verify the snapshot also reflects all changes
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname)
|
||||
assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname)
|
||||
|
||||
// Verify peer relationships are updated correctly with new node
|
||||
assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3
|
||||
assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3
|
||||
assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update non-existent node returns invalid view",
|
||||
action: func(store *NodeStore) {
|
||||
resultNode, ok := store.UpdateNode(999, func(n *types.Node) {
|
||||
n.Hostname = "should-not-exist"
|
||||
})
|
||||
|
||||
assert.False(t, ok, "UpdateNode should return false for non-existent node")
|
||||
assert.False(t, resultNode.Valid(), "Result should be invalid NodeView")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple updates to same node in batch all see final state",
|
||||
action: func(store *NodeStore) {
|
||||
// This test verifies that when multiple updates to the same node
|
||||
// are batched together, each returned node reflects ALL changes
|
||||
// in the batch, not just the individual update's changes.
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
|
||||
// These updates all modify node 1 and should be batched together
|
||||
// The final state should have all three modifications applied
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "multi-update-hostname"
|
||||
})
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "multi-update-givenname"
|
||||
})
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.ForcedTags = []string{"tag1", "tag2"}
|
||||
})
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
// Wait for all operations to complete
|
||||
<-done1
|
||||
<-done2
|
||||
<-done3
|
||||
|
||||
// All updates should succeed
|
||||
assert.True(t, ok1, "First update should succeed")
|
||||
assert.True(t, ok2, "Second update should succeed")
|
||||
assert.True(t, ok3, "Third update should succeed")
|
||||
|
||||
// CRITICAL: Each returned node should reflect ALL changes from the batch
|
||||
// not just the change from its specific update call
|
||||
|
||||
// resultNode1 (from hostname update) should also have the givenname and tags changes
|
||||
assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
|
||||
assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice())
|
||||
|
||||
// resultNode2 (from givenname update) should also have the hostname and tags changes
|
||||
assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
|
||||
assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice())
|
||||
|
||||
// resultNode3 (from tags update) should also have the hostname and givenname changes
|
||||
assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
|
||||
assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice())
|
||||
|
||||
// Verify the snapshot also has all changes
|
||||
snapshot := store.data.Load()
|
||||
finalNode := snapshot.nodesByID[1]
|
||||
assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
|
||||
assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test UpdateNode result is immutable for database save",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify returned node is complete and consistent",
|
||||
action: func(store *NodeStore) {
|
||||
// Update a node and verify the returned view is complete
|
||||
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "db-save-hostname"
|
||||
n.GivenName = "db-save-given"
|
||||
n.ForcedTags = []string{"db-tag1", "db-tag2"}
|
||||
})
|
||||
|
||||
assert.True(t, ok, "UpdateNode should succeed")
|
||||
assert.True(t, resultNode.Valid(), "Result should be valid")
|
||||
|
||||
// Verify the returned node has all expected values
|
||||
assert.Equal(t, "db-save-hostname", resultNode.Hostname())
|
||||
assert.Equal(t, "db-save-given", resultNode.GivenName())
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice())
|
||||
|
||||
// Convert to struct as would be done for database save
|
||||
nodePtr := resultNode.AsStruct()
|
||||
assert.NotNil(t, nodePtr)
|
||||
assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
|
||||
assert.Equal(t, "db-save-given", nodePtr.GivenName)
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags)
|
||||
|
||||
// Verify the snapshot also reflects the same state
|
||||
snapshot := store.data.Load()
|
||||
storedNode := snapshot.nodesByID[1]
|
||||
assert.Equal(t, "db-save-hostname", storedNode.Hostname)
|
||||
assert.Equal(t, "db-save-given", storedNode.GivenName)
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "concurrent updates all return consistent final state for DB save",
|
||||
action: func(store *NodeStore) {
|
||||
// Multiple goroutines updating the same node
|
||||
// All should receive the final batch state suitable for DB save
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var result1, result2, result3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
|
||||
// Start concurrent updates
|
||||
go func() {
|
||||
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "concurrent-db-hostname"
|
||||
})
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "concurrent-db-given"
|
||||
})
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.ForcedTags = []string{"concurrent-tag"}
|
||||
})
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
// Wait for all to complete
|
||||
<-done1
|
||||
<-done2
|
||||
<-done3
|
||||
|
||||
assert.True(t, ok1 && ok2 && ok3, "All updates should succeed")
|
||||
|
||||
// All results should be valid and suitable for database save
|
||||
assert.True(t, result1.Valid())
|
||||
assert.True(t, result2.Valid())
|
||||
assert.True(t, result3.Valid())
|
||||
|
||||
// Convert each to struct as would be done for DB save
|
||||
nodePtr1 := result1.AsStruct()
|
||||
nodePtr2 := result2.AsStruct()
|
||||
nodePtr3 := result3.AsStruct()
|
||||
|
||||
// All should have the complete final state
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags)
|
||||
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags)
|
||||
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags)
|
||||
|
||||
// Verify consistency with stored state
|
||||
snapshot := store.data.Load()
|
||||
storedNode := snapshot.nodesByID[1]
|
||||
assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
|
||||
assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
|
||||
assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify returned node preserves all fields for DB save",
|
||||
action: func(store *NodeStore) {
|
||||
// Get initial state
|
||||
snapshot := store.data.Load()
|
||||
originalNode := snapshot.nodesByID[2]
|
||||
originalIPv4 := originalNode.IPv4
|
||||
originalIPv6 := originalNode.IPv6
|
||||
originalCreatedAt := originalNode.CreatedAt
|
||||
originalUser := originalNode.User
|
||||
|
||||
// Update only hostname
|
||||
resultNode, ok := store.UpdateNode(2, func(n *types.Node) {
|
||||
n.Hostname = "preserve-test-hostname"
|
||||
})
|
||||
|
||||
assert.True(t, ok, "Update should succeed")
|
||||
|
||||
// Convert to struct for DB save
|
||||
nodeForDB := resultNode.AsStruct()
|
||||
|
||||
// Verify all fields are preserved
|
||||
assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname)
|
||||
assert.Equal(t, originalIPv4, nodeForDB.IPv4)
|
||||
assert.Equal(t, originalIPv6, nodeForDB.IPv6)
|
||||
assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt)
|
||||
assert.Equal(t, originalUser.Name, nodeForDB.User.Name)
|
||||
assert.Equal(t, types.NodeID(2), nodeForDB.ID)
|
||||
|
||||
// These fields should be suitable for direct database save
|
||||
assert.NotNil(t, nodeForDB.IPv4)
|
||||
assert.NotNil(t, nodeForDB.IPv6)
|
||||
assert.False(t, nodeForDB.CreatedAt.IsZero())
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -336,38 +336,11 @@ func (s *State) ListAllUsers() ([]types.User, error) {
|
||||
return s.db.ListUsers()
|
||||
}
|
||||
|
||||
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
|
||||
// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore
|
||||
// BEFORE calling this function with the EXACT same changes that the database update will make.
|
||||
// This ensures the NodeStore is the source of truth for the batcher and maintains consistency.
|
||||
// Returns error only; callers should get the updated NodeView from NodeStore to maintain consistency.
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) error {
|
||||
_, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := updateFn(tx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node, err := hsdb.GetNodeByID(tx, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Save(node).Error; err != nil {
|
||||
return nil, fmt.Errorf("updating node: %w", err)
|
||||
}
|
||||
|
||||
return node, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// persistNodeToDB saves the current state of a node from NodeStore to the database.
|
||||
// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency.
|
||||
func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) {
|
||||
// CRITICAL: Always get the latest node from NodeStore to ensure we save the current state
|
||||
node, found := s.nodeStore.GetNode(nodeID)
|
||||
if !found {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
// persistNodeToDB saves the given node state to the database.
|
||||
// CRITICAL: This function MUST receive the exact node state to save, ensuring consistency.
|
||||
func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
|
||||
if !node.Valid() {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided")
|
||||
}
|
||||
|
||||
nodePtr := node.AsStruct()
|
||||
@ -393,10 +366,10 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet,
|
||||
// Update NodeStore first
|
||||
nodePtr := node.AsStruct()
|
||||
|
||||
s.nodeStore.PutNode(*nodePtr)
|
||||
resultNode := s.nodeStore.PutNode(*nodePtr)
|
||||
|
||||
// Then save to database
|
||||
return s.persistNodeToDB(node.ID())
|
||||
// Then save to database using the result from PutNode
|
||||
return s.persistNodeToDB(resultNode)
|
||||
}
|
||||
|
||||
// DeleteNode permanently removes a node and cleans up associated resources.
|
||||
@ -430,17 +403,14 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
|
||||
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
|
||||
// the NodeStore already reflects the correct online status for full map generation.
|
||||
// now := time.Now()
|
||||
s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
n.IsOnline = ptr.To(true)
|
||||
// n.LastSeen = ptr.To(now)
|
||||
})
|
||||
c := []change.ChangeSet{change.NodeOnline(id)}
|
||||
|
||||
// Get fresh node data from NodeStore after the online status update
|
||||
node, found := s.GetNodeByID(id)
|
||||
if !found {
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
c := []change.ChangeSet{change.NodeOnline(id)}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected")
|
||||
|
||||
@ -460,39 +430,25 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Get node info before updating for logging
|
||||
node, found := s.GetNodeByID(id)
|
||||
var nodeName string
|
||||
if found {
|
||||
nodeName = node.Hostname()
|
||||
}
|
||||
|
||||
s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(now)
|
||||
// NodeStore is the source of truth for all node state including online status.
|
||||
n.IsOnline = ptr.To(false)
|
||||
})
|
||||
|
||||
if found {
|
||||
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Node disconnected")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node not found: %d", id)
|
||||
}
|
||||
|
||||
err := s.updateNodeTx(id, func(tx *gorm.DB) error {
|
||||
// Update last_seen in the database
|
||||
// Note: IsOnline is managed only in NodeStore (marked with gorm:"-"), not persisted to database
|
||||
return hsdb.SetLastSeen(tx, id, now)
|
||||
})
|
||||
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected")
|
||||
|
||||
// Special error handling for disconnect - we log errors but continue
|
||||
// because NodeStore is already updated and we need to notify peers
|
||||
_, c, err := s.persistNodeToDB(node)
|
||||
if err != nil {
|
||||
// Log error but don't fail the disconnection - NodeStore is already updated
|
||||
// and we need to send change notifications to peers
|
||||
log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update last seen in database")
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
// Log error but continue - disconnection must proceed
|
||||
log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update policy manager after node disconnect")
|
||||
log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database")
|
||||
c = change.EmptySet
|
||||
}
|
||||
|
||||
@ -610,35 +566,15 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
|
||||
// If the database update fails, the NodeStore change will remain, but since we return
|
||||
// an error, no change notification will be sent to the batcher.
|
||||
expiryPtr := expiry
|
||||
s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
node.Expiry = &expiryPtr
|
||||
})
|
||||
|
||||
err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Get the updated node from NodeStore to ensure consistency
|
||||
// TODO(kradalby): Validate if this NodeStore read makes sense after database update
|
||||
n, found := s.GetNodeByID(nodeID)
|
||||
if !found {
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.KeyExpiry(nodeID, expiry)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||
@ -646,35 +582,15 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView,
|
||||
// CRITICAL: Update NodeStore BEFORE database to ensure consistency.
|
||||
// The NodeStore update is blocking and will be the source of truth for the batcher.
|
||||
// The database update MUST make the EXACT same change.
|
||||
s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
node.ForcedTags = tags
|
||||
})
|
||||
|
||||
err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetTags(tx, nodeID, tags)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
}
|
||||
|
||||
// Get the updated node from NodeStore to ensure consistency
|
||||
// TODO(kradalby): Validate if this NodeStore read makes sense after database update
|
||||
n, found := s.GetNodeByID(nodeID)
|
||||
if !found {
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
@ -682,44 +598,15 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
|
||||
// TODO(kradalby): In principle we should call the AutoApprove logic here
|
||||
// because even if the CLI removes an auto-approved route, it will be added
|
||||
// back automatically.
|
||||
s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
node.ApprovedRoutes = routes
|
||||
})
|
||||
|
||||
err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
}
|
||||
|
||||
// Get the updated node from NodeStore to ensure consistency
|
||||
// TODO(kradalby): Validate if this NodeStore read makes sense after database update
|
||||
n, found := s.GetNodeByID(nodeID)
|
||||
if !found {
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
// Get the node from NodeStore to ensure we have the latest state
|
||||
nodeView, ok := s.GetNodeByID(nodeID)
|
||||
if !ok {
|
||||
return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID)
|
||||
}
|
||||
// Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set
|
||||
// primary routes for routes that are both announced AND approved
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...)
|
||||
|
||||
if routeChange || !c.IsFull() {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
@ -729,13 +616,11 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Check name uniqueness
|
||||
nodes, err := s.db.ListNodes()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err)
|
||||
}
|
||||
for _, node := range nodes {
|
||||
if node.ID != nodeID && node.GivenName == newName {
|
||||
// Check name uniqueness against NodeStore
|
||||
allNodes := s.nodeStore.ListNodes()
|
||||
for i := 0; i < allNodes.Len(); i++ {
|
||||
node := allNodes.At(i)
|
||||
if node.ID() != nodeID && node.AsStruct().GivenName == newName {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName)
|
||||
}
|
||||
}
|
||||
@ -743,35 +628,15 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
|
||||
// CRITICAL: Update NodeStore BEFORE database to ensure consistency.
|
||||
// The NodeStore update is blocking and will be the source of truth for the batcher.
|
||||
// The database update MUST make the EXACT same change.
|
||||
s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
node.GivenName = newName
|
||||
})
|
||||
|
||||
err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.RenameNode(tx, nodeID, newName)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Get the updated node from NodeStore to ensure consistency
|
||||
// TODO(kradalby): Validate if this NodeStore read makes sense after database update
|
||||
n, found := s.GetNodeByID(nodeID)
|
||||
if !found {
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
@ -790,36 +655,16 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (type
|
||||
// CRITICAL: Update NodeStore BEFORE database to ensure consistency.
|
||||
// The NodeStore update is blocking and will be the source of truth for the batcher.
|
||||
// The database update MUST make the EXACT same change.
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.User = *user
|
||||
n.UserID = uint(userID)
|
||||
})
|
||||
|
||||
err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Get the updated node from NodeStore to ensure consistency
|
||||
// TODO(kradalby): Validate if this NodeStore read makes sense after database update
|
||||
n, found := s.GetNodeByID(nodeID)
|
||||
if !found {
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||
@ -859,7 +704,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
}
|
||||
// TODO(kradalby): This should just update the IP addresses, nothing else in the node store.
|
||||
// We should avoid PutNode here.
|
||||
s.nodeStore.PutNode(*node)
|
||||
_ = s.nodeStore.PutNode(*node)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1075,7 +920,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
Msg("Refreshing existing node registration")
|
||||
|
||||
// Update NodeStore first with the new expiry
|
||||
s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) {
|
||||
updatedNode, ok := s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) {
|
||||
if expiry != nil {
|
||||
node.Expiry = expiry
|
||||
}
|
||||
@ -1084,22 +929,22 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
// Save to database
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeView.ID())
|
||||
}
|
||||
// Return the node to satisfy the Write signature
|
||||
return hsdb.GetNodeByID(tx, existingNodeView.ID())
|
||||
|
||||
// Use the node from UpdateNode to save to database
|
||||
nodePtr := updatedNode.AsStruct()
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(nodePtr).Error; err != nil {
|
||||
return nil, fmt.Errorf("saving node: %w", err)
|
||||
}
|
||||
return nodePtr, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Get updated node from NodeStore
|
||||
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
|
||||
|
||||
if expiry != nil {
|
||||
return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil
|
||||
}
|
||||
@ -1163,7 +1008,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
var savedNode *types.Node
|
||||
if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) {
|
||||
// Update existing node - NodeStore first, then database
|
||||
s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) {
|
||||
node.NodeKey = nodeToRegister.NodeKey
|
||||
node.DiscoKey = nodeToRegister.DiscoKey
|
||||
node.Hostname = nodeToRegister.Hostname
|
||||
@ -1194,12 +1039,17 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
// Save to database
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingMachineNode.ID)
|
||||
}
|
||||
|
||||
// Use the node from UpdateNode to save to database
|
||||
nodePtr := updatedNodeView.AsStruct()
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
if err := tx.Save(nodePtr).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
return &nodeToRegister, nil
|
||||
return nodePtr, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
@ -1217,7 +1067,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
|
||||
// Add to NodeStore after database creates the ID
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
_ = s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
// Delete from registration cache
|
||||
@ -1345,7 +1195,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
var savedNode *types.Node
|
||||
if existingNode != nil && existingNode.UserID == pak.User.ID {
|
||||
// Update existing node - NodeStore first, then database
|
||||
s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
|
||||
node.NodeKey = nodeToRegister.NodeKey
|
||||
node.Hostname = nodeToRegister.Hostname
|
||||
|
||||
@ -1378,6 +1228,10 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNode.ID)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node.name", nodeToRegister.Hostname).
|
||||
@ -1387,9 +1241,10 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
Str("user.name", pak.User.Username()).
|
||||
Msg("Node re-authorized")
|
||||
|
||||
// Save to database
|
||||
// Use the node from UpdateNode to save to database
|
||||
nodePtr := updatedNodeView.AsStruct()
|
||||
savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
if err := tx.Save(nodePtr).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
|
||||
@ -1400,7 +1255,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
}
|
||||
}
|
||||
|
||||
return &nodeToRegister, nil
|
||||
return nodePtr, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
@ -1426,7 +1281,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
}
|
||||
|
||||
// Add to NodeStore after database creates the ID
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
_ = s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
// Update policy managers
|
||||
@ -1570,7 +1425,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
var needsRouteApproval bool
|
||||
// We need to ensure we update the node as it is in the NodeStore at
|
||||
// the time of the request.
|
||||
s.nodeStore.UpdateNode(id, func(currentNode *types.Node) {
|
||||
updatedNode, ok := s.nodeStore.UpdateNode(id, func(currentNode *types.Node) {
|
||||
peerChange := currentNode.PeerChangeFromMapRequest(req)
|
||||
hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo)
|
||||
|
||||
@ -1673,6 +1528,10 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
}
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id)
|
||||
}
|
||||
|
||||
nodeRouteChange := change.EmptySet
|
||||
|
||||
// Handle route changes after NodeStore update
|
||||
@ -1702,12 +1561,6 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
}
|
||||
|
||||
if needsRouteUpdate {
|
||||
// Get the updated node to access its subnet routes
|
||||
updatedNode, exists := s.GetNodeByID(id)
|
||||
if !exists {
|
||||
return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id)
|
||||
}
|
||||
|
||||
// SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes()
|
||||
// which returns only the intersection of announced AND approved routes.
|
||||
// Using AnnouncedRoutes() would bypass the security model and auto-approve everything.
|
||||
@ -1721,7 +1574,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...)
|
||||
}
|
||||
|
||||
_, policyChange, err := s.persistNodeToDB(id)
|
||||
_, policyChange, err := s.persistNodeToDB(updatedNode)
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("saving to database: %w", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user