package state

import (
	"fmt"
	"maps"
	"strings"
	"sync/atomic"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"tailscale.com/types/key"
	"tailscale.com/types/views"
)

const (
	batchSize    = 10
	batchTimeout = 500 * time.Millisecond
)

const (
	put    = 1
	del    = 2
	update = 3
)

const prometheusNamespace = "headscale"

var (
	nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_operations_total",
		Help:      "Total number of NodeStore operations",
	}, []string{"operation"})
	nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_operation_duration_seconds",
		Help:      "Duration of NodeStore operations",
		Buckets:   prometheus.DefBuckets,
	}, []string{"operation"})
	nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_batch_size",
		Help:      "Size of NodeStore write batches",
		Buckets:   []float64{1, 2, 5, 10, 20, 50, 100},
	})
	nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_batch_duration_seconds",
		Help:      "Duration of NodeStore batch processing",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_snapshot_build_duration_seconds",
		Help:      "Duration of NodeStore snapshot building from nodes",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_nodes_total",
		Help:      "Total number of nodes in the NodeStore",
	})
	nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_peers_calculation_duration_seconds",
		Help:      "Duration of peers calculation in NodeStore",
		Buckets:   prometheus.DefBuckets,
	})
	nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "nodestore_queue_depth",
		Help:      "Current depth of NodeStore write queue",
	})
)

// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads;
// writes are not fast, but they are batched together so that the
// expensive peer calculation runs less often when many changes
// arrive in rapid succession.
//
// Writes block until committed, while reads are never
// blocked. This means that the caller of a write operation
// is responsible for ensuring that an update depending on the write
// is not issued before the write has completed.
type NodeStore struct {
	data atomic.Pointer[Snapshot]

	peersFunc  PeersFunc
	writeQueue chan work
}

func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
	nodes := make(map[types.NodeID]types.Node, len(allNodes))
	for _, n := range allNodes {
		nodes[n.ID] = *n
	}
	snap := snapshotFromNodes(nodes, peersFunc)

	store := &NodeStore{
		peersFunc: peersFunc,
	}
	store.data.Store(&snap)

	// Initialize node count gauge
	nodeStoreNodesCount.Set(float64(len(nodes)))

	return store
}
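
// Illustrative usage sketch (not part of the original source): the store must
// be started before any write is issued, because PutNode, UpdateNode and
// DeleteNode block on the write queue until their batch has been applied.
// The node value and peersFunc below are hypothetical placeholders.
//
//	store := NewNodeStore(allNodes, peersFunc)
//	store.Start()
//	defer store.Stop()
//
//	store.PutNode(node) // blocks until the containing batch is applied
//	if nv, ok := store.GetNode(node.ID); ok && nv.Valid() {
//		// nv reflects the write that just completed
//	}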

// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
	// nodesByID is the main source of truth for nodes.
	nodesByID map[types.NodeID]types.Node

	// calculated from nodesByID
	nodesByNodeKey map[key.NodePublic]types.NodeView
	peersByNode    map[types.NodeID][]types.NodeView
	nodesByUser    map[types.UserID][]types.NodeView
	allNodes       []types.NodeView
}

// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
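
// A minimal, hypothetical PeersFunc for illustration: a full mesh in which
// every node is a peer of every other node. A real implementation would
// filter the result by policy. This sketch assumes NodeView exposes a
// generated ID() accessor.
//
//	fullMesh := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
//		peers := make(map[types.NodeID][]types.NodeView, len(nodes))
//		for _, n := range nodes {
//			for _, other := range nodes {
//				if other.ID() != n.ID() {
//					peers[n.ID()] = append(peers[n.ID()], other)
//				}
//			}
//		}
//		return peers
//	}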

// work represents a single operation to be performed on the NodeStore.
type work struct {
	op       int
	nodeID   types.NodeID
	node     types.Node
	updateFn UpdateNodeFunc
	result   chan struct{}
}

// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
	defer timer.ObserveDuration()

	work := work{
		op:     put,
		nodeID: n.ID,
		node:   n,
		result: make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("put").Inc()
}

// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
type UpdateNodeFunc func(n *types.Node)

// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
// This is analogous to a database "transaction": the caller should collect
// all the data they want to change and then call this function once.
// Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node. Which means we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
	defer timer.ObserveDuration()

	work := work{
		op:       update,
		nodeID:   nodeID,
		updateFn: updateFn,
		result:   make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("update").Inc()
}
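
// Illustrative sketch (not part of the original source): apply several field
// changes in a single UpdateNode call so that only one batch entry, snapshot
// rebuild and peer recalculation is needed for the write. The fields touched
// here are only examples.
//
//	now := time.Now()
//	store.UpdateNode(nodeID, func(n *types.Node) {
//		n.Hostname = "new-hostname"
//		n.LastSeen = &now
//	})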

// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
	defer timer.ObserveDuration()

	work := work{
		op:     del,
		nodeID: id,
		result: make(chan struct{}),
	}

	nodeStoreQueueDepth.Inc()
	s.writeQueue <- work
	<-work.result
	nodeStoreQueueDepth.Dec()

	nodeStoreOperations.WithLabelValues("delete").Inc()
}

// Start initializes the NodeStore and starts processing the write queue.
func (s *NodeStore) Start() {
	s.writeQueue = make(chan work)
	go s.processWrite()
}

// Stop stops the NodeStore.
func (s *NodeStore) Stop() {
	close(s.writeQueue)
}

// processWrite processes the write queue in batches.
func (s *NodeStore) processWrite() {
	c := time.NewTicker(batchTimeout)
	defer c.Stop()
	batch := make([]work, 0, batchSize)

	for {
		select {
		case w, ok := <-s.writeQueue:
			if !ok {
				// Channel closed, apply any remaining batch and exit
				if len(batch) != 0 {
					s.applyBatch(batch)
				}
				return
			}
			batch = append(batch, w)
			if len(batch) >= batchSize {
				s.applyBatch(batch)
				batch = batch[:0]
				c.Reset(batchTimeout)
			}
		case <-c.C:
			if len(batch) != 0 {
				s.applyBatch(batch)
				batch = batch[:0]
			}
			c.Reset(batchTimeout)
		}
	}
}

// applyBatch applies a batch of work to the node store.
// This means that it takes a copy of the current nodes,
// then applies the batch of operations to that copy,
// runs any precomputation needed (like calculating peers),
// and finally replaces the snapshot in the store with the new one.
// The replacement of the snapshot is atomic, ensuring that reads
// are never blocked by writes.
// Each write item blocks until the batch has been applied, so the
// caller knows the operation is complete and does not send any
// update that depends on a write that has not yet been applied.
func (s *NodeStore) applyBatch(batch []work) {
	timer := prometheus.NewTimer(nodeStoreBatchDuration)
	defer timer.ObserveDuration()

	nodeStoreBatchSize.Observe(float64(len(batch)))

	nodes := make(map[types.NodeID]types.Node)
	maps.Copy(nodes, s.data.Load().nodesByID)

	for _, w := range batch {
		switch w.op {
		case put:
			nodes[w.nodeID] = w.node
		case update:
			// Update the specific node identified by nodeID
			if n, exists := nodes[w.nodeID]; exists {
				w.updateFn(&n)
				nodes[w.nodeID] = n
			}
		case del:
			delete(nodes, w.nodeID)
		}
	}

	newSnap := snapshotFromNodes(nodes, s.peersFunc)
	s.data.Store(&newSnap)

	// Update node count gauge
	nodeStoreNodesCount.Set(float64(len(nodes)))

	for _, w := range batch {
		close(w.result)
	}
}

// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds a number of "indexes" to make lookups fast for data that is
// used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
// This is not a fast operation; it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
	timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
	defer timer.ObserveDuration()

	allNodes := make([]types.NodeView, 0, len(nodes))
	for _, n := range nodes {
		allNodes = append(allNodes, n.View())
	}

	newSnap := Snapshot{
		nodesByID:      nodes,
		allNodes:       allNodes,
		nodesByNodeKey: make(map[key.NodePublic]types.NodeView),

		// peersByNode is most likely the most expensive operation:
		// it uses the list of all nodes, combined with the
		// current policy, to precalculate which nodes are peers and
		// can see each other.
		peersByNode: func() map[types.NodeID][]types.NodeView {
			peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
			defer peersTimer.ObserveDuration()
			return peersFunc(allNodes)
		}(),
		nodesByUser: make(map[types.UserID][]types.NodeView),
	}

	// Build nodesByUser and nodesByNodeKey maps
	for _, n := range nodes {
		nodeView := n.View()
		newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
		newSnap.nodesByNodeKey[n.NodeKey] = nodeView
	}

	return newSnap
}

// GetNode retrieves a node by its ID.
// The bool indicates whether the node exists (analogous to a "not found" error).
// The returned NodeView may still be invalid and must be checked with .Valid()
// before use; an invalid view indicates a broken or erroneous node.
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("get").Inc()

	n, exists := s.data.Load().nodesByID[id]
	if !exists {
		return types.NodeView{}, false
	}

	return n.View(), true
}
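
// Illustrative sketch (not part of the original source): callers are expected
// to check both the existence bool and the view's validity before using the
// result. Hostname() is assumed here to be a generated view accessor.
//
//	if nv, ok := store.GetNode(id); ok && nv.Valid() {
//		fmt.Println(nv.Hostname())
//	}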

// GetNodeByNodeKey retrieves a node by its NodeKey.
// The bool indicates whether the node exists (analogous to a "not found" error).
// The returned NodeView may still be invalid and must be checked with .Valid()
// before use; an invalid view indicates a broken or erroneous node.
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_key"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("get_by_key").Inc()

	nodeView, exists := s.data.Load().nodesByNodeKey[nodeKey]

	return nodeView, exists
}

// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()

	snapshot := s.data.Load()
	// We don't have a byMachineKey map, so we need to iterate.
	// This could be optimized by adding a byMachineKey map if this becomes a hot path.
	for _, node := range snapshot.nodesByID {
		if node.MachineKey == machineKey {
			return node.View(), true
		}
	}

	return types.NodeView{}, false
}

// DebugString returns debug information about the NodeStore.
func (s *NodeStore) DebugString() string {
	snapshot := s.data.Load()

	var sb strings.Builder

	sb.WriteString("=== NodeStore Debug Information ===\n\n")

	// Basic counts
	sb.WriteString(fmt.Sprintf("Total Nodes: %d\n", len(snapshot.nodesByID)))
	sb.WriteString(fmt.Sprintf("Users with Nodes: %d\n", len(snapshot.nodesByUser)))
	sb.WriteString("\n")

	// User distribution
	sb.WriteString("Nodes by User:\n")
	for userID, nodes := range snapshot.nodesByUser {
		if len(nodes) > 0 {
			userName := "unknown"
			if len(nodes) > 0 && nodes[0].Valid() {
				userName = nodes[0].User().Name
			}
			sb.WriteString(fmt.Sprintf("  - User %d (%s): %d nodes\n", userID, userName, len(nodes)))
		}
	}
	sb.WriteString("\n")

	// Peer relationships summary
	sb.WriteString("Peer Relationships:\n")
	totalPeers := 0
	for nodeID, peers := range snapshot.peersByNode {
		peerCount := len(peers)
		totalPeers += peerCount
		if node, exists := snapshot.nodesByID[nodeID]; exists {
			sb.WriteString(fmt.Sprintf("  - Node %d (%s): %d peers\n",
				nodeID, node.Hostname, peerCount))
		}
	}
	if len(snapshot.peersByNode) > 0 {
		avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
		sb.WriteString(fmt.Sprintf("  - Average peers per node: %.1f\n", avgPeers))
	}
	sb.WriteString("\n")

	// Node key index
	sb.WriteString(fmt.Sprintf("NodeKey Index: %d entries\n", len(snapshot.nodesByNodeKey)))
	sb.WriteString("\n")

	return sb.String()
}

// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list").Inc()

	return views.SliceOf(s.data.Load().allNodes)
}

// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list_peers").Inc()

	return views.SliceOf(s.data.Load().peersByNode[id])
}

// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
	timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
	defer timer.ObserveDuration()

	nodeStoreOperations.WithLabelValues("list_by_user").Inc()

	return views.SliceOf(s.data.Load().nodesByUser[uid])
}