diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go
index 4256a89b..64ee0406 100644
--- a/hscontrol/state/node_store_test.go
+++ b/hscontrol/state/node_store_test.go
@@ -1,7 +1,11 @@
 package state

 import (
+	"context"
+	"fmt"
 	"net/netip"
+	"runtime"
+	"sync"
 	"testing"
 	"time"

@@ -835,3 +839,302 @@ type testStep struct {
 	name   string
 	action func(store *NodeStore)
 }
+
+// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
+
+// createConcurrentTestNode builds a minimal node for the concurrency tests below.
+func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
+	machineKey := key.NewMachine()
+	nodeKey := key.NewNode()
+
+	return types.Node{
+		ID:         id,
+		Hostname:   hostname,
+		MachineKey: machineKey.Public(),
+		NodeKey:    nodeKey.Public(),
+		UserID:     1,
+		User: types.User{
+			Name: "concurrent-test-user",
+		},
+	}
+}
+
+// --- Concurrency: concurrent PutNode operations ---
+func TestNodeStoreConcurrentPutNode(t *testing.T) {
+	const concurrentOps = 20
+
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	var wg sync.WaitGroup
+	results := make(chan bool, concurrentOps)
+	for i := 0; i < concurrentOps; i++ {
+		wg.Add(1)
+		go func(nodeID int) {
+			defer wg.Done()
+			node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
+			resultNode := store.PutNode(node)
+			results <- resultNode.Valid()
+		}(i + 1)
+	}
+	wg.Wait()
+	close(results)
+
+	successCount := 0
+	for success := range results {
+		if success {
+			successCount++
+		}
+	}
+	require.Equal(t, concurrentOps, successCount, "all concurrent PutNode operations should succeed")
+}
+
+// --- Batching: more concurrent ops than fit in a single batch ---
+func TestNodeStoreBatchingEfficiency(t *testing.T) {
+	const batchSize = 10
+	const ops = batchSize + 5 // exceed one batch so the store must flush more than once
+
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	var wg sync.WaitGroup
+	results := make(chan bool, ops)
+	for i := 0; i < ops; i++ {
+		wg.Add(1)
+		go func(nodeID int) {
+			defer wg.Done()
+			node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
+			resultNode := store.PutNode(node)
+			results <- resultNode.Valid()
+		}(i + 1)
+	}
+	wg.Wait()
+	close(results)
+
+	successCount := 0
+	for success := range results {
+		if success {
+			successCount++
+		}
+	}
+	require.Equal(t, ops, successCount, "all batched PutNode operations should succeed")
+}
+
+// --- Race conditions: many goroutines hammering the same node ---
+func TestNodeStoreRaceConditions(t *testing.T) {
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	nodeID := types.NodeID(1)
+	node := createConcurrentTestNode(nodeID, "race-node")
+	resultNode := store.PutNode(node)
+	require.True(t, resultNode.Valid())
+
+	const numGoroutines = 30
+	const opsPerGoroutine = 10
+
+	var wg sync.WaitGroup
+	errCh := make(chan error, numGoroutines*opsPerGoroutine)
+
+	for i := 0; i < numGoroutines; i++ {
+		wg.Add(1)
+		go func(gid int) {
+			defer wg.Done()
+			for j := 0; j < opsPerGoroutine; j++ {
+				switch j % 3 {
+				case 0:
+					resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
+						n.Hostname = "race-updated"
+					})
+					if !resultNode.Valid() {
+						errCh <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
+					}
+				case 1:
+					retrieved, found := store.GetNode(nodeID)
+					if !found || !retrieved.Valid() {
+						errCh <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
+					}
+				case 2:
+					newNode := createConcurrentTestNode(nodeID, "race-put")
+					resultNode := store.PutNode(newNode)
+					if !resultNode.Valid() {
+						errCh <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
+					}
+				}
+			}
+		}(i)
+	}
+	wg.Wait()
+	close(errCh)
+
+	errorCount := 0
+	for err := range errCh {
+		t.Error(err)
+		errorCount++
+	}
+	if errorCount > 0 {
+		t.Fatalf("race condition test failed with %d errors", errorCount)
+	}
+}
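+
+// The assertions above only catch logical failures; actual data races are
+// surfaced by Go's race detector, so this test is most useful when run with
+// -race, for example:
+//
+//	go test -race -run TestNodeStoreRaceConditions ./hscontrol/state/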
+
+// --- Resource cleanup: goroutine leak detection ---
+func TestNodeStoreResourceCleanup(t *testing.T) {
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	// Let the store's background goroutines spin up before taking the baseline.
+	time.Sleep(50 * time.Millisecond)
+	afterStartGoroutines := runtime.NumGoroutine()
+
+	const ops = 100
+	for i := 0; i < ops; i++ {
+		nodeID := types.NodeID(i + 1)
+		node := createConcurrentTestNode(nodeID, "cleanup-node")
+		resultNode := store.PutNode(node)
+		assert.True(t, resultNode.Valid())
+
+		store.UpdateNode(nodeID, func(n *types.Node) {
+			n.Hostname = "cleanup-updated"
+		})
+
+		retrieved, found := store.GetNode(nodeID)
+		assert.True(t, found && retrieved.Valid())
+
+		if i%10 == 9 {
+			store.DeleteNode(nodeID)
+		}
+	}
+
+	// Give the runtime a moment to settle before comparing goroutine counts.
+	runtime.GC()
+	time.Sleep(100 * time.Millisecond)
+
+	finalGoroutines := runtime.NumGoroutine()
+	if finalGoroutines > afterStartGoroutines+2 {
+		t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", afterStartGoroutines, finalGoroutines)
+	}
+}
+
+// --- Timeout/deadlock: operations complete within a reasonable time ---
+func TestNodeStoreOperationTimeout(t *testing.T) {
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	const ops = 30
+
+	putResults := make([]error, ops)
+	updateResults := make([]error, ops)
+
+	// Phase 1: launch all PutNode operations concurrently. Progress is logged
+	// with fmt.Printf rather than t.Logf because these goroutines can outlive
+	// the test on timeout, and logging via t after the test returns panics.
+	var putWg sync.WaitGroup
+	for i := 1; i <= ops; i++ {
+		nodeID := types.NodeID(i)
+		putWg.Add(1)
+		go func(idx int, id types.NodeID) {
+			defer putWg.Done()
+			start := time.Now()
+			fmt.Printf("TestNodeStoreOperationTimeout: PutNode(%d) starting\n", id)
+			node := createConcurrentTestNode(id, "timeout-node")
+			resultNode := store.PutNode(node)
+			fmt.Printf("TestNodeStoreOperationTimeout: PutNode(%d) finished, valid=%v, duration=%v\n", id, resultNode.Valid(), time.Since(start))
+			if !resultNode.Valid() {
+				putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
+			}
+		}(i, nodeID)
+	}
+
+	// Guard the first phase with the deadline as well, so a deadlocked PutNode
+	// fails the test instead of hanging it.
+	putsDone := make(chan struct{})
+	go func() {
+		putWg.Wait()
+		close(putsDone)
+	}()
+	select {
+	case <-putsDone:
+	case <-ctx.Done():
+		t.Fatal("PutNode operations timed out - potential deadlock or resource issue")
+	}
+
+	// Phase 2: launch all UpdateNode operations concurrently.
+	var updateWg sync.WaitGroup
+	for i := 1; i <= ops; i++ {
+		nodeID := types.NodeID(i)
+		updateWg.Add(1)
+		go func(idx int, id types.NodeID) {
+			defer updateWg.Done()
+			start := time.Now()
+			fmt.Printf("TestNodeStoreOperationTimeout: UpdateNode(%d) starting\n", id)
+			resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
+				n.Hostname = "timeout-updated"
+			})
+			fmt.Printf("TestNodeStoreOperationTimeout: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", id, resultNode.Valid(), ok, time.Since(start))
+			if !ok || !resultNode.Valid() {
+				updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
+			}
+		}(i, nodeID)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		updateWg.Wait()
+		close(done)
+	}()
+	select {
+	case <-done:
+		errorCount := 0
+		for _, err := range putResults {
+			if err != nil {
+				t.Error(err)
+				errorCount++
+			}
+		}
+		for _, err := range updateResults {
+			if err != nil {
+				t.Error(err)
+				errorCount++
+			}
+		}
+		if errorCount == 0 {
+			t.Log("all concurrent operations completed successfully within the timeout")
+		} else {
+			t.Fatalf("some concurrent operations failed: %d errors", errorCount)
+		}
+	case <-ctx.Done():
+		t.Fatal("operations timed out - potential deadlock or resource issue")
+	}
+}
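+
+// A fixed 3s budget can be flaky on slow CI machines. A possible alternative,
+// sketched here but not wired in above, is to derive the budget from the test
+// binary's own deadline (set via go test -timeout), keeping a small margin:
+//
+//	if deadline, ok := t.Deadline(); ok {
+//		ctx, cancel = context.WithDeadline(context.Background(), deadline.Add(-time.Second))
+//		defer cancel()
+//	}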
+
+// --- Edge case: updating a non-existent node ---
+func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
+	// Repeat with a fresh store each time to shake out flaky behaviour.
+	for i := 0; i < 10; i++ {
+		store := NewNodeStore(nil, allowAllPeersFunc)
+		store.Start()
+
+		nonExistentID := types.NodeID(999 + i)
+		updateCallCount := 0
+		resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
+			updateCallCount++
+			n.Hostname = "should-never-be-called"
+		})
+		t.Logf("UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d", nonExistentID, resultNode.Valid(), ok, updateCallCount)
+
+		assert.False(t, ok, "UpdateNode should return false for a non-existent node")
+		assert.False(t, resultNode.Valid(), "UpdateNode should return an invalid node for a non-existent node")
+		assert.Equal(t, 0, updateCallCount, "the update function should not be called for a non-existent node")
+
+		store.Stop()
+	}
+}
+
+// --- Allocation benchmark ---
+func BenchmarkNodeStoreAllocations(b *testing.B) {
+	store := NewNodeStore(nil, allowAllPeersFunc)
+	store.Start()
+	defer store.Stop()
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		nodeID := types.NodeID(i + 1)
+		node := createConcurrentTestNode(nodeID, "bench-node")
+		store.PutNode(node)
+		store.UpdateNode(nodeID, func(n *types.Node) {
+			n.Hostname = "bench-updated"
+		})
+		store.GetNode(nodeID)
+		if i%10 == 9 {
+			store.DeleteNode(nodeID)
+		}
+	}
+}
+
+func TestNodeStoreAllocationStats(t *testing.T) {
+	res := testing.Benchmark(BenchmarkNodeStoreAllocations)
+	t.Logf("NodeStore allocations per op: %d", res.AllocsPerOp())
+}
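+
+// The same numbers can be inspected directly, without the wrapper test, by
+// running the benchmark with allocation reporting enabled:
+//
+//	go test -bench=BenchmarkNodeStoreAllocations -benchmem ./hscontrol/state/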