From 8442220b80d987aad69e5c513fbfd9ec102d4212 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 17 Sep 2025 16:04:47 +0200 Subject: [PATCH] nodestore: return node when applied Signed-off-by: Kristoffer Dalby --- hscontrol/state/debug_test.go | 4 +- hscontrol/state/node_store.go | 76 +++++-- hscontrol/state/node_store_test.go | 352 ++++++++++++++++++++++++++++- hscontrol/state/state.go | 301 +++++++----------------- 4 files changed, 483 insertions(+), 250 deletions(-) diff --git a/hscontrol/state/debug_test.go b/hscontrol/state/debug_test.go index ae6c340b..60d77245 100644 --- a/hscontrol/state/debug_test.go +++ b/hscontrol/state/debug_test.go @@ -33,8 +33,8 @@ func TestNodeStoreDebugString(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc) store.Start() - store.PutNode(node1) - store.PutNode(node2) + _ = store.PutNode(node1) + _ = store.PutNode(node2) return store }, diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index b27a2945..5328d852 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -135,26 +135,29 @@ type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView // work represents a single operation to be performed on the NodeStore. type work struct { - op int - nodeID types.NodeID - node types.Node - updateFn UpdateNodeFunc - result chan struct{} + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} + nodeResult chan types.NodeView // Channel to return the resulting node after batch application } // PutNode adds or updates a node in the store. // If the node already exists, it will be replaced. // If the node does not exist, it will be added. // This is a blocking operation that waits for the write to complete. -func (s *NodeStore) PutNode(n types.Node) { +// Returns the resulting node after all modifications in the batch have been applied. +func (s *NodeStore) PutNode(n types.Node) types.NodeView { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) defer timer.ObserveDuration() work := work{ - op: put, - nodeID: n.ID, - node: n, - result: make(chan struct{}), + op: put, + nodeID: n.ID, + node: n, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), } nodeStoreQueueDepth.Inc() @@ -162,7 +165,10 @@ func (s *NodeStore) PutNode(n types.Node) { <-work.result nodeStoreQueueDepth.Dec() + resultNode := <-work.nodeResult nodeStoreOperations.WithLabelValues("put").Inc() + + return resultNode } // UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. @@ -173,6 +179,7 @@ type UpdateNodeFunc func(n *types.Node) // This is analogous to a database "transaction", or, the caller should // rather collect all data they want to change, and then call this function. // Fewer calls are better. +// Returns the resulting node after all modifications in the batch have been applied. // // TODO(kradalby): Technically we could have a version of this that modifies the node // in the current snapshot if _we know_ that the change will not affect the peer relationships. @@ -181,15 +188,16 @@ type UpdateNodeFunc func(n *types.Node) // a lock around the nodesByID map to ensure that no other writes are happening // while we are modifying the node. Which mean we would need to implement read-write locks // on all read operations. -func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) { +func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) defer timer.ObserveDuration() work := work{ - op: update, - nodeID: nodeID, - updateFn: updateFn, - result: make(chan struct{}), + op: update, + nodeID: nodeID, + updateFn: updateFn, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), } nodeStoreQueueDepth.Inc() @@ -197,7 +205,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) <-work.result nodeStoreQueueDepth.Dec() + resultNode := <-work.nodeResult nodeStoreOperations.WithLabelValues("update").Inc() + + // Return the node and whether it exists (is valid) + return resultNode, resultNode.Valid() } // DeleteNode removes a node from the store by its ID. @@ -282,18 +294,32 @@ func (s *NodeStore) applyBatch(batch []work) { nodes := make(map[types.NodeID]types.Node) maps.Copy(nodes, s.data.Load().nodesByID) - for _, w := range batch { + // Track which work items need node results + nodeResultRequests := make(map[types.NodeID][]*work) + + for i := range batch { + w := &batch[i] switch w.op { case put: nodes[w.nodeID] = w.node + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } case update: // Update the specific node identified by nodeID if n, exists := nodes[w.nodeID]; exists { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } case del: delete(nodes, w.nodeID) + // For delete operations, send an invalid NodeView if requested + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } } } @@ -303,6 +329,24 @@ func (s *NodeStore) applyBatch(batch []work) { // Update node count gauge nodeStoreNodesCount.Set(float64(len(nodes))) + // Send the resulting nodes to all work items that requested them + for nodeID, workItems := range nodeResultRequests { + if node, exists := nodes[nodeID]; exists { + nodeView := node.View() + for _, w := range workItems { + w.nodeResult <- nodeView + close(w.nodeResult) + } + } else { + // Node was deleted or doesn't exist + for _, w := range workItems { + w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) + } + } + } + + // Signal completion for all work items for _, w := range batch { close(w.result) } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 9666e5db..4256a89b 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -249,7 +249,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add first node", action: func(store *NodeStore) { node := createTestNode(1, 1, "user1", "node1") - store.PutNode(node) + resultNode := store.PutNode(node) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, node.ID, resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 1) @@ -288,7 +290,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add second node same user", action: func(store *NodeStore) { node2 := createTestNode(2, 1, "user1", "node2") - store.PutNode(node2) + resultNode := store.PutNode(node2) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(2), resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 2) @@ -308,7 +312,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add third node different user", action: func(store *NodeStore) { node3 := createTestNode(3, 2, "user2", "node3") - store.PutNode(node3) + resultNode := store.PutNode(node3) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(3), resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 3) @@ -409,10 +415,14 @@ func TestNodeStoreOperations(t *testing.T) { { name: "update node hostname", action: func(store *NodeStore) { - store.UpdateNode(1, func(n *types.Node) { + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { n.Hostname = "updated-node1" n.GivenName = "updated-node1" }) + assert.True(t, ok, "UpdateNode should return true for existing node") + assert.True(t, resultNode.Valid(), "Result node should be valid") + assert.Equal(t, "updated-node1", resultNode.Hostname()) + assert.Equal(t, "updated-node1", resultNode.GivenName()) snapshot := store.data.Load() assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname) @@ -436,10 +446,14 @@ func TestNodeStoreOperations(t *testing.T) { name: "add nodes with odd-even filtering", action: func(store *NodeStore) { // Add nodes in sequence - store.PutNode(createTestNode(1, 1, "user1", "node1")) - store.PutNode(createTestNode(2, 2, "user2", "node2")) - store.PutNode(createTestNode(3, 3, "user3", "node3")) - store.PutNode(createTestNode(4, 4, "user4", "node4")) + n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) + assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) + assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) + assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) + assert.True(t, n4.Valid()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 4) @@ -478,6 +492,328 @@ func TestNodeStoreOperations(t *testing.T) { }, }, }, + { + name: "test batch modifications return correct node state", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "concurrent updates should reflect all batch changes", + action: func(store *NodeStore) { + // Start multiple updates that will be batched together + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2 types.NodeView + var newNode3 types.NodeView + var ok1, ok2 bool + + // These should all be processed in the same batch + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "batch-updated-node1" + n.GivenName = "batch-given-1" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "batch-updated-node2" + n.GivenName = "batch-given-2" + }) + close(done2) + }() + + go func() { + node3 := createTestNode(3, 1, "user1", "node3") + newNode3 = store.PutNode(node3) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // Verify the returned nodes reflect the batch state + assert.True(t, ok1, "UpdateNode should succeed for node 1") + assert.True(t, ok2, "UpdateNode should succeed for node 2") + assert.True(t, resultNode1.Valid()) + assert.True(t, resultNode2.Valid()) + assert.True(t, newNode3.Valid()) + + // Check that returned nodes have the updated values + assert.Equal(t, "batch-updated-node1", resultNode1.Hostname()) + assert.Equal(t, "batch-given-1", resultNode1.GivenName()) + assert.Equal(t, "batch-updated-node2", resultNode2.Hostname()) + assert.Equal(t, "batch-given-2", resultNode2.GivenName()) + assert.Equal(t, "node3", newNode3.Hostname()) + + // Verify the snapshot also reflects all changes + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname) + assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname) + + // Verify peer relationships are updated correctly with new node + assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3 + assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3 + assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2 + }, + }, + { + name: "update non-existent node returns invalid view", + action: func(store *NodeStore) { + resultNode, ok := store.UpdateNode(999, func(n *types.Node) { + n.Hostname = "should-not-exist" + }) + + assert.False(t, ok, "UpdateNode should return false for non-existent node") + assert.False(t, resultNode.Valid(), "Result should be invalid NodeView") + }, + }, + { + name: "multiple updates to same node in batch all see final state", + action: func(store *NodeStore) { + // This test verifies that when multiple updates to the same node + // are batched together, each returned node reflects ALL changes + // in the batch, not just the individual update's changes. + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2, resultNode3 types.NodeView + var ok1, ok2, ok3 bool + + // These updates all modify node 1 and should be batched together + // The final state should have all three modifications applied + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "multi-update-hostname" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "multi-update-givenname" + }) + close(done2) + }() + + go func() { + resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.ForcedTags = []string{"tag1", "tag2"} + }) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // All updates should succeed + assert.True(t, ok1, "First update should succeed") + assert.True(t, ok2, "Second update should succeed") + assert.True(t, ok3, "Third update should succeed") + + // CRITICAL: Each returned node should reflect ALL changes from the batch + // not just the change from its specific update call + + // resultNode1 (from hostname update) should also have the givenname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode1.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode1.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice()) + + // resultNode2 (from givenname update) should also have the hostname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode2.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode2.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice()) + + // resultNode3 (from tags update) should also have the hostname and givenname changes + assert.Equal(t, "multi-update-hostname", resultNode3.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode3.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice()) + + // Verify the snapshot also has all changes + snapshot := store.data.Load() + finalNode := snapshot.nodesByID[1] + assert.Equal(t, "multi-update-hostname", finalNode.Hostname) + assert.Equal(t, "multi-update-givenname", finalNode.GivenName) + assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags) + }, + }, + }, + }, + { + name: "test UpdateNode result is immutable for database save", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify returned node is complete and consistent", + action: func(store *NodeStore) { + // Update a node and verify the returned view is complete + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "db-save-hostname" + n.GivenName = "db-save-given" + n.ForcedTags = []string{"db-tag1", "db-tag2"} + }) + + assert.True(t, ok, "UpdateNode should succeed") + assert.True(t, resultNode.Valid(), "Result should be valid") + + // Verify the returned node has all expected values + assert.Equal(t, "db-save-hostname", resultNode.Hostname()) + assert.Equal(t, "db-save-given", resultNode.GivenName()) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice()) + + // Convert to struct as would be done for database save + nodePtr := resultNode.AsStruct() + assert.NotNil(t, nodePtr) + assert.Equal(t, "db-save-hostname", nodePtr.Hostname) + assert.Equal(t, "db-save-given", nodePtr.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags) + + // Verify the snapshot also reflects the same state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, "db-save-hostname", storedNode.Hostname) + assert.Equal(t, "db-save-given", storedNode.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags) + }, + }, + { + name: "concurrent updates all return consistent final state for DB save", + action: func(store *NodeStore) { + // Multiple goroutines updating the same node + // All should receive the final batch state suitable for DB save + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var result1, result2, result3 types.NodeView + var ok1, ok2, ok3 bool + + // Start concurrent updates + go func() { + result1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "concurrent-db-hostname" + }) + close(done1) + }() + + go func() { + result2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "concurrent-db-given" + }) + close(done2) + }() + + go func() { + result3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.ForcedTags = []string{"concurrent-tag"} + }) + close(done3) + }() + + // Wait for all to complete + <-done1 + <-done2 + <-done3 + + assert.True(t, ok1 && ok2 && ok3, "All updates should succeed") + + // All results should be valid and suitable for database save + assert.True(t, result1.Valid()) + assert.True(t, result2.Valid()) + assert.True(t, result3.Valid()) + + // Convert each to struct as would be done for DB save + nodePtr1 := result1.AsStruct() + nodePtr2 := result2.AsStruct() + nodePtr3 := result3.AsStruct() + + // All should have the complete final state + assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags) + + // Verify consistency with stored state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname) + assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName) + assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags) + }, + }, + { + name: "verify returned node preserves all fields for DB save", + action: func(store *NodeStore) { + // Get initial state + snapshot := store.data.Load() + originalNode := snapshot.nodesByID[2] + originalIPv4 := originalNode.IPv4 + originalIPv6 := originalNode.IPv6 + originalCreatedAt := originalNode.CreatedAt + originalUser := originalNode.User + + // Update only hostname + resultNode, ok := store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "preserve-test-hostname" + }) + + assert.True(t, ok, "Update should succeed") + + // Convert to struct for DB save + nodeForDB := resultNode.AsStruct() + + // Verify all fields are preserved + assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname) + assert.Equal(t, originalIPv4, nodeForDB.IPv4) + assert.Equal(t, originalIPv6, nodeForDB.IPv6) + assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt) + assert.Equal(t, originalUser.Name, nodeForDB.User.Name) + assert.Equal(t, types.NodeID(2), nodeForDB.ID) + + // These fields should be suitable for direct database save + assert.NotNil(t, nodeForDB.IPv4) + assert.NotNil(t, nodeForDB.IPv6) + assert.False(t, nodeForDB.CreatedAt.IsZero()) + }, + }, + }, + }, } for _, tt := range tests { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 47f78fd3..43f54c0e 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -336,38 +336,11 @@ func (s *State) ListAllUsers() ([]types.User, error) { return s.db.ListUsers() } -// updateNodeTx performs a database transaction to update a node and refresh the policy manager. -// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore -// BEFORE calling this function with the EXACT same changes that the database update will make. -// This ensures the NodeStore is the source of truth for the batcher and maintains consistency. -// Returns error only; callers should get the updated NodeView from NodeStore to maintain consistency. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) error { - _, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := updateFn(tx); err != nil { - return nil, err - } - - node, err := hsdb.GetNodeByID(tx, nodeID) - if err != nil { - return nil, err - } - - if err := tx.Save(node).Error; err != nil { - return nil, fmt.Errorf("updating node: %w", err) - } - - return node, nil - }) - return err -} - -// persistNodeToDB saves the current state of a node from NodeStore to the database. -// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency. -func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) { - // CRITICAL: Always get the latest node from NodeStore to ensure we save the current state - node, found := s.nodeStore.GetNode(nodeID) - if !found { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) +// persistNodeToDB saves the given node state to the database. +// CRITICAL: This function MUST receive the exact node state to save, ensuring consistency. +func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) { + if !node.Valid() { + return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided") } nodePtr := node.AsStruct() @@ -393,10 +366,10 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, // Update NodeStore first nodePtr := node.AsStruct() - s.nodeStore.PutNode(*nodePtr) + resultNode := s.nodeStore.PutNode(*nodePtr) - // Then save to database - return s.persistNodeToDB(node.ID()) + // Then save to database using the result from PutNode + return s.persistNodeToDB(resultNode) } // DeleteNode permanently removes a node and cleans up associated resources. @@ -430,17 +403,14 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. // now := time.Now() - s.nodeStore.UpdateNode(id, func(n *types.Node) { + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { n.IsOnline = ptr.To(true) // n.LastSeen = ptr.To(now) }) - c := []change.ChangeSet{change.NodeOnline(id)} - - // Get fresh node data from NodeStore after the online status update - node, found := s.GetNodeByID(id) - if !found { + if !ok { return nil } + c := []change.ChangeSet{change.NodeOnline(id)} log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") @@ -460,39 +430,25 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { now := time.Now() - // Get node info before updating for logging - node, found := s.GetNodeByID(id) - var nodeName string - if found { - nodeName = node.Hostname() - } - - s.nodeStore.UpdateNode(id, func(n *types.Node) { + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { n.LastSeen = ptr.To(now) // NodeStore is the source of truth for all node state including online status. n.IsOnline = ptr.To(false) }) - if found { - log.Info().Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Node disconnected") + if !ok { + return nil, fmt.Errorf("node not found: %d", id) } - err := s.updateNodeTx(id, func(tx *gorm.DB) error { - // Update last_seen in the database - // Note: IsOnline is managed only in NodeStore (marked with gorm:"-"), not persisted to database - return hsdb.SetLastSeen(tx, id, now) - }) + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected") + + // Special error handling for disconnect - we log errors but continue + // because NodeStore is already updated and we need to notify peers + _, c, err := s.persistNodeToDB(node) if err != nil { // Log error but don't fail the disconnection - NodeStore is already updated // and we need to send change notifications to peers - log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update last seen in database") - } - - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - // Log error but continue - disconnection must proceed - log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update policy manager after node disconnect") + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database") c = change.EmptySet } @@ -610,35 +566,15 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node // If the database update fails, the NodeStore change will remain, but since we return // an error, no change notification will be sent to the batcher. expiryPtr := expiry - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.Expiry = &expiryPtr }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.NodeSetExpiry(tx, nodeID, expiry) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.KeyExpiry(nodeID, expiry) - } - - return n, c, nil + return s.persistNodeToDB(n) } // SetNodeTags assigns tags to a node for use in access control policies. @@ -646,35 +582,15 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, // CRITICAL: Update NodeStore BEFORE database to ensure consistency. // The NodeStore update is blocking and will be the source of truth for the batcher. // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.ForcedTags = tags }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetTags(tx, nodeID, tags) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // SetApprovedRoutes sets the network routes that a node is approved to advertise. @@ -682,44 +598,15 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // TODO(kradalby): In principle we should call the AutoApprove logic here // because even if the CLI removes an auto-approved route, it will be added // back automatically. - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.ApprovedRoutes = routes }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetApprovedRoutes(tx, nodeID, routes) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - // Get the node from NodeStore to ensure we have the latest state - nodeView, ok := s.GetNodeByID(nodeID) - if !ok { - return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID) - } - // Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set - // primary routes for routes that are both announced AND approved - routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) - - if routeChange || !c.IsFull() { - c = change.PolicyChange() - } - - return n, c, nil + return s.persistNodeToDB(n) } // RenameNode changes the display name of a node. @@ -729,13 +616,11 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) } - // Check name uniqueness - nodes, err := s.db.ListNodes() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err) - } - for _, node := range nodes { - if node.ID != nodeID && node.GivenName == newName { + // Check name uniqueness against NodeStore + allNodes := s.nodeStore.ListNodes() + for i := 0; i < allNodes.Len(); i++ { + node := allNodes.At(i) + if node.ID() != nodeID && node.AsStruct().GivenName == newName { return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) } } @@ -743,35 +628,15 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, // CRITICAL: Update NodeStore BEFORE database to ensure consistency. // The NodeStore update is blocking and will be the source of truth for the batcher. // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.GivenName = newName }) - err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.RenameNode(tx, nodeID, newName) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // AssignNodeToUser transfers a node to a different user. @@ -790,36 +655,16 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (type // CRITICAL: Update NodeStore BEFORE database to ensure consistency. // The NodeStore update is blocking and will be the source of truth for the batcher. // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { n.User = *user n.UserID = uint(userID) }) - err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.AssignNodeToUser(tx, nodeID, userID) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, err - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // BackfillNodeIPs assigns IP addresses to nodes that don't have them. @@ -859,7 +704,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { } // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. // We should avoid PutNode here. - s.nodeStore.PutNode(*node) + _ = s.nodeStore.PutNode(*node) } } @@ -1075,7 +920,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Refreshing existing node registration") // Update NodeStore first with the new expiry - s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { + updatedNode, ok := s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { if expiry != nil { node.Expiry = expiry } @@ -1084,22 +929,22 @@ func (s *State) HandleNodeFromAuthPath( node.LastSeen = ptr.To(time.Now()) }) - // Save to database + if !ok { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeView.ID()) + } + + // Use the node from UpdateNode to save to database + nodePtr := updatedNode.AsStruct() _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry) - if err != nil { - return nil, err + if err := tx.Save(nodePtr).Error; err != nil { + return nil, fmt.Errorf("saving node: %w", err) } - // Return the node to satisfy the Write signature - return hsdb.GetNodeByID(tx, existingNodeView.ID()) + return nodePtr, nil }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err) } - // Get updated node from NodeStore - updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) - if expiry != nil { return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil } @@ -1163,7 +1008,7 @@ func (s *State) HandleNodeFromAuthPath( var savedNode *types.Node if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { + updatedNodeView, ok := s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { node.NodeKey = nodeToRegister.NodeKey node.DiscoKey = nodeToRegister.DiscoKey node.Hostname = nodeToRegister.Hostname @@ -1194,12 +1039,17 @@ func (s *State) HandleNodeFromAuthPath( node.LastSeen = ptr.To(time.Now()) }) - // Save to database + if !ok { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingMachineNode.ID) + } + + // Use the node from UpdateNode to save to database + nodePtr := updatedNodeView.AsStruct() savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { + if err := tx.Save(nodePtr).Error; err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } - return &nodeToRegister, nil + return nodePtr, nil }) if err != nil { return types.NodeView{}, change.EmptySet, err @@ -1217,7 +1067,7 @@ func (s *State) HandleNodeFromAuthPath( } // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) + _ = s.nodeStore.PutNode(*savedNode) } // Delete from registration cache @@ -1345,7 +1195,7 @@ func (s *State) HandleNodeFromPreAuthKey( var savedNode *types.Node if existingNode != nil && existingNode.UserID == pak.User.ID { // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { node.NodeKey = nodeToRegister.NodeKey node.Hostname = nodeToRegister.Hostname @@ -1378,6 +1228,10 @@ func (s *State) HandleNodeFromPreAuthKey( node.LastSeen = ptr.To(time.Now()) }) + if !ok { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNode.ID) + } + log.Trace(). Caller(). Str("node.name", nodeToRegister.Hostname). @@ -1387,9 +1241,10 @@ func (s *State) HandleNodeFromPreAuthKey( Str("user.name", pak.User.Username()). Msg("Node re-authorized") - // Save to database + // Use the node from UpdateNode to save to database + nodePtr := updatedNodeView.AsStruct() savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { + if err := tx.Save(nodePtr).Error; err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } @@ -1400,7 +1255,7 @@ func (s *State) HandleNodeFromPreAuthKey( } } - return &nodeToRegister, nil + return nodePtr, nil }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) @@ -1426,7 +1281,7 @@ func (s *State) HandleNodeFromPreAuthKey( } // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) + _ = s.nodeStore.PutNode(*savedNode) } // Update policy managers @@ -1570,7 +1425,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest var needsRouteApproval bool // We need to ensure we update the node as it is in the NodeStore at // the time of the request. - s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + updatedNode, ok := s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { peerChange := currentNode.PeerChangeFromMapRequest(req) hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) @@ -1673,6 +1528,10 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest } }) + if !ok { + return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id) + } + nodeRouteChange := change.EmptySet // Handle route changes after NodeStore update @@ -1702,12 +1561,6 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest } if needsRouteUpdate { - // Get the updated node to access its subnet routes - updatedNode, exists := s.GetNodeByID(id) - if !exists { - return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id) - } - // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() // which returns only the intersection of announced AND approved routes. // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. @@ -1721,7 +1574,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) } - _, policyChange, err := s.persistNodeToDB(id) + _, policyChange, err := s.persistNodeToDB(updatedNode) if err != nil { return change.EmptySet, fmt.Errorf("saving to database: %w", err) }