1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-24 13:46:53 +02:00

state/nodestore: in memory representation of nodes

Initial work on a nodestore which stores all of the nodes
and their relations in memory with relationship for peers
precalculated.

It is a copy-on-write structure, replacing the "snapshot"
when a change to the structure occurs. It is optimised for reads,
and while batches are not fast, they are grouped together
to do less of the expensive peer calculation if there are many
changes rapidly.

Writes will block until committed, while reads are never
blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-07-05 23:30:47 +02:00
parent 1f7ed7e4e5
commit 0c539d7993
No known key found for this signature in database
21 changed files with 1523 additions and 533 deletions

View File

@ -548,8 +548,8 @@ func (h *Headscale) Serve() error {
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
for _, node := range ephmNodes {
h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
}
if h.cfg.DNSConfig.ExtraRecordsPath != "" {

View File

@ -34,7 +34,7 @@ func (h *Headscale) handleRegister(
return nil, fmt.Errorf("looking up node in database: %w", err)
}
if node != nil {
if node.Valid() {
// If an existing node is trying to register with an auth key,
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@ -50,7 +50,7 @@ func (h *Headscale) handleRegister(
return resp, nil
}
resp, err := h.handleExistingNode(node, regReq, machineKey)
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
}
@ -107,7 +107,7 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
c, err := h.state.DeleteNode(node)
c, err := h.state.DeleteNode(node.View())
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode(
}
h.Change(c)
}
}
return nodeToRegisterResponse(node), nil
return nodeToRegisterResponse(node), nil
}
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@ -177,7 +177,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
)
@ -194,7 +194,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
}
// If node is nil, it means an ephemeral node was deleted during logout
if node == nil {
if node.Valid() {
h.Change(changed)
return nil, nil
}
@ -218,21 +218,23 @@ func (h *Headscale) handleRegisterWithAuthKey(
}
if routeChange && changed.Empty() {
changed = change.NodeAdded(node.ID)
changed = change.NodeAdded(node.ID())
}
h.Change(changed)
// If policy changed due to node registration, send a separate policy change
if policyChanged {
policyChange := change.PolicyChange()
h.Change(policyChange)
}
// TODO(kradalby): I think this is covered above, but we need to validate that.
// // If policy changed due to node registration, send a separate policy change
// if policyChanged {
// policyChange := change.PolicyChange()
// h.Change(policyChange)
// }
user := node.User()
return &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
User: *node.User.TailscaleUser(),
Login: *node.User.TailscaleLogin(),
User: *user.TailscaleUser(),
Login: *user.TailscaleLogin(),
}, nil
}

View File

@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
}
sshPol := make(map[string]*tailcfg.SSHPolicy)
for _, node := range nodes {
pol, err := h.state.SSHPolicy(node.View())
for _, node := range nodes.All() {
pol, err := h.state.SSHPolicy(node)
if err != nil {
httpError(w, err)
return
}
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
}
sshJSON, err := json.MarshalIndent(sshPol, "", " ")

View File

@ -25,6 +25,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/views"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/state"
@ -59,9 +60,10 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
c := change.UserAdded(types.UserID(user.ID))
if policyChanged {
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
if !policyChanged.Empty() {
c.Change = change.Policy
}
@ -79,15 +81,13 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err
}
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
api.h.Change(change.PolicyChange())
}
api.h.Change(c)
newUser, err := api.h.state.GetUserByName(request.GetNewName())
if err != nil {
@ -297,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
resp.Online = api.h.mapBatcher.IsConnected(node.ID())
return &v1.GetNodeResponse{Node: resp}, nil
}
@ -323,7 +323,7 @@ func (api headscaleV1APIServer) SetTags(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Str("node", node.Hostname()).
Strs("tags", request.GetTags()).
Msg("Changing tags of node")
@ -357,7 +357,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
return nil, status.Error(codes.InvalidArgument, err.Error())
}
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
// Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange)
@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
}
proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
}
@ -420,8 +420,8 @@ func (api headscaleV1APIServer) ExpireNode(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Time("expiry", *node.Expiry).
Str("node", node.Hostname()).
Time("expiry", *node.AsStruct().Expiry).
Msg("node expired")
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
@ -440,7 +440,7 @@ func (api headscaleV1APIServer) RenameNode(
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
Str("node", node.Hostname()).
Str("new_name", request.GetNewName()).
Msg("node renamed")
@ -477,36 +477,36 @@ func (api headscaleV1APIServer) ListNodes(
return nil, err
}
sort.Slice(nodes, func(i, j int) bool {
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID, bool], nodes views.Slice[types.NodeView]) []*v1.Node {
response := make([]*v1.Node, nodes.Len())
for index, node := range nodes.All() {
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
if val, ok := IsConnected.Load(node.ID); ok && val {
if val, ok := isLikelyConnected.Load(node.ID()); ok && val {
resp.Online = true
}
var tags []string
for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node.View(), tag) {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
})
return response
}
@ -683,8 +683,8 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, fmt.Errorf("setting policy: %w", err)
}
if len(nodes) > 0 {
_, err = api.h.state.SSHPolicy(nodes[0].View())
if nodes.Len() > 0 {
_, err = api.h.state.SSHPolicy(nodes.At(0))
if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err)
}

View File

@ -99,8 +99,17 @@ func (h *Headscale) handleVerifyRequest(
return fmt.Errorf("cannot list nodes: %w", err)
}
// Check if any node has the requested NodeKey
var nodeKeyFound bool
for _, node := range nodes.All() {
if node.NodeKey() == derpAdmitClientRequest.NodePublic {
nodeKeyFound = true
break
}
}
resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
Allow: nodeKeyFound,
}
return json.NewEncoder(writer).Encode(resp)

View File

@ -1940,7 +1940,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
if err != nil {
t.Errorf("Error listing peers for node %d: %v", i, err)
} else {
t.Logf("Node %d should see %d peers from state", i, len(peers))
t.Logf("Node %d should see %d peers from state", i, peers.Len())
}
}
@ -2106,9 +2106,9 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
if err != nil {
t.Errorf("Error getting peers from state: %v", err)
} else {
t.Logf("State shows %d peers available for this node", len(peers))
if len(peers) > 0 && len(data.Peers) == 0 {
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
t.Logf("State shows %d peers available for this node", peers.Len())
if peers.Len() > 0 && len(data.Peers) == 0 {
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
}
}
} else {

View File

@ -63,9 +63,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
_, matchers := b.mapper.state.Filter()
tailnode, err := tailNode(
node.View(), b.capVer, b.mapper.state,
node, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
@ -112,7 +112,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
return b
}
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
sshPolicy, err := b.mapper.state.SSHPolicy(node)
if err != nil {
b.addError(err)
return b
@ -135,7 +135,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
}
// WithUserProfiles adds user profiles for the requesting node and given peers
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
@ -161,14 +161,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node.View(), filter),
"base": policy.ReduceFilterRules(node, filter),
}
return b
}
// WithPeers adds full peer list with policy filtering (for full map response)
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
@ -181,7 +181,7 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
}
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
@ -193,8 +193,8 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
return b
}
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
return nil, err
@ -206,15 +206,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
// access each-other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView]
if len(filter) > 0 {
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
changedViews = policy.ReduceNodes(node, peers, matchers)
} else {
changedViews = peers.ViewSlice()
changedViews = peers
}
tailPeers, err := tailNodes(
changedViews, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {

View File

@ -18,6 +18,7 @@ import (
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
)
const (
@ -68,16 +69,18 @@ func newMapper(
}
func generateUserProfiles(
node *types.Node,
peers types.Nodes,
node types.NodeView,
peers views.Slice[types.NodeView],
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User
ids = append(ids, node.User.ID)
for _, peer := range peers {
userMap[peer.User.ID] = &peer.User
ids = append(ids, peer.User.ID)
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
}
slices.Sort(ids)
@ -94,7 +97,7 @@ func generateUserProfiles(
func generateDNSConfig(
cfg *types.Config,
node *types.Node,
node types.NodeView,
) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil {
return nil
@ -114,12 +117,12 @@ func generateDNSConfig(
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname},
"device_model": []string{node.Hostinfo.OS},
"device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo().OS()},
}
if len(node.IPs()) > 0 {
@ -137,7 +140,7 @@ func (m *mapper) fullMapResponse(
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID)
peers, err := m.state.ListPeers(nodeID)
if err != nil {
return nil, err
}
@ -182,7 +185,7 @@ func (m *mapper) peerChangeResponse(
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID, changedNodeID)
peers, err := m.state.ListPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
@ -256,25 +259,6 @@ func writeDebugMapResponse(
}
}
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
// TODO(kradalby): Add back online via batcher. This was removed
// to avoid a circular dependency between the mapper and the notification.
for _, peer := range peers {
online := m.batcher.IsConnected(peer.ID)
peer.IsOnline = &online
}
return peers, nil
}
// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.

View File

@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1,
nodeInShared1.View(),
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {

View File

@ -296,7 +296,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
@ -304,8 +304,6 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
}
nv := node.View()
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
if ns.machineKey != nv.MachineKey() {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)

View File

@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
user, c, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
log.Error().
Err(err).
@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
// Send policy update notifications if needed
if policyChanged {
a.h.Change(change.PolicyChange())
}
a.h.Change(c)
// TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated,
@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, bool, error) {
) (*types.User, change.ChangeSet, error) {
var user *types.User
var err error
var newUser bool
var policyChanged bool
var c change.ChangeSet
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, false, fmt.Errorf("creating or updating user: %w", err)
return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user.FromClaim(claims)
if newUser {
user, policyChanged, err = a.h.state.CreateUser(*user)
user, c, err = a.h.state.CreateUser(*user)
if err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
}
} else {
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
*u = *user
return nil
})
if err != nil {
return nil, false, fmt.Errorf("updating user: %w", err)
return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
}
}
return user, policyChanged, nil
return user, c, nil
}
func (a *AuthProviderOIDC) handleRegistration(

View File

@ -235,6 +235,7 @@ func (m *mapSession) serveLongPoll() {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
m.resetKeepAlive()
}
}
}

View File

@ -0,0 +1,260 @@
package state
import (
"maps"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
const (
	// batchSize is the maximum number of queued writes folded into a
	// single snapshot rebuild before the batch is flushed.
	batchSize = 10
	// batchTimeout is the longest a queued write waits before the
	// pending batch is flushed, even if it is not yet full.
	batchTimeout = 500 * time.Millisecond
)
// Operation codes for work items handled by the write queue.
const (
	// put inserts or replaces a node.
	put = 1
	// del removes a node.
	del = 2
	// update mutates an existing node via its updateFn.
	update = 3
)
// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
// and while batches are not fast, they are grouped together
// to do less of the expensive peer calculation if there are many
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked.
type NodeStore struct {
	// data holds the current immutable snapshot; a new snapshot is
	// swapped in atomically for every committed write batch.
	data atomic.Pointer[Snapshot]

	// peersFunc recomputes the peers visible to each node whenever a
	// new snapshot is built.
	peersFunc PeersFunc
	// writeQueue carries pending write operations to the background
	// writer goroutine; created by Start (nil until then).
	writeQueue chan work
	// TODO: metrics
}
// NewNodeStore builds a NodeStore seeded with allNodes, using
// peersFunc to derive the peer relationships in every snapshot.
// Start must be called before issuing any writes.
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
	byID := make(map[types.NodeID]types.Node, len(allNodes))
	for _, node := range allNodes {
		byID[node.ID] = *node
	}

	store := NodeStore{
		peersFunc: peersFunc,
	}

	initial := snapshotFromNodes(byID, peersFunc)
	store.data.Store(&initial)

	return &store
}
// Snapshot is an immutable point-in-time view of the node set and the
// lookup structures derived from it. Readers always observe a complete,
// consistent snapshot; writers publish a fresh one per batch.
type Snapshot struct {
	// nodesByID is the main source of truth for nodes.
	nodesByID map[types.NodeID]types.Node

	// calculated from nodesByID

	// nodesByNodeKey indexes node views by their public node key.
	nodesByNodeKey map[key.NodePublic]types.NodeView
	// peersByNode holds, per node ID, the peers that node may see,
	// as computed by the store's PeersFunc.
	peersByNode map[types.NodeID][]types.NodeView
	// nodesByUser groups node views by their owning user ID.
	nodesByUser map[types.UserID][]types.NodeView
	// allNodes is the flat list of views over every node.
	allNodes []types.NodeView
}
// PeersFunc computes, for every node in nodes, the set of nodes it is
// allowed to see as peers, keyed by node ID.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
// work is a single queued write operation consumed by processWrite.
type work struct {
	// op selects the operation: put, del, or update.
	op int
	// nodeID identifies the node the operation targets.
	nodeID types.NodeID
	// node is the full node value, used by put operations.
	node types.Node
	// updateFn mutates the node in place, used by update operations.
	updateFn UpdateNodeFunc
	// result is closed once the write has been applied to a new
	// snapshot, unblocking the waiting caller.
	result chan struct{}
	immediate bool // For operations that need immediate processing
}
// PutNode adds or updates a node in the store.
// An existing node with the same ID is replaced; otherwise the node
// is added. Blocks until the write has been committed to a snapshot.
func (s *NodeStore) PutNode(n types.Node) {
	done := make(chan struct{})

	s.writeQueue <- work{
		op:     put,
		nodeID: n.ID,
		node:   n,
		result: done,
	}

	<-done
}
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
type UpdateNodeFunc func(n *types.Node)

// UpdateNode applies updateFn to the node identified by nodeID.
// Blocks until the write has been committed to a snapshot.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
	done := make(chan struct{})

	s.writeQueue <- work{
		op:       update,
		nodeID:   nodeID,
		updateFn: updateFn,
		result:   done,
	}

	<-done
}
// UpdateNodeImmediate applies updateFn to the node identified by
// nodeID with immediate processing, bypassing the normal batching
// delay. Use this for time-sensitive updates such as online-status
// changes. Blocks until the write has been committed to a snapshot.
func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
	done := make(chan struct{})

	s.writeQueue <- work{
		op:        update,
		nodeID:    nodeID,
		updateFn:  updateFn,
		result:    done,
		immediate: true,
	}

	<-done
}
// DeleteNode removes the node with the given ID from the store.
// Deleting an ID that is not present is a no-op. Blocks until the
// write has been committed to a snapshot.
func (s *NodeStore) DeleteNode(id types.NodeID) {
	done := make(chan struct{})

	s.writeQueue <- work{
		op:     del,
		nodeID: id,
		result: done,
	}

	<-done
}
// Start allocates the write queue and launches the background writer
// goroutine. It must be called before PutNode, UpdateNode,
// UpdateNodeImmediate, or DeleteNode; a write issued before Start
// blocks forever on the nil channel.
func (s *NodeStore) Start() {
	s.writeQueue = make(chan work)
	go s.processWrite()
}
// Stop closes the write queue, terminating the background writer
// goroutine. Writes issued after Stop panic (send on closed channel).
// NOTE(review): callers must ensure no writers are active when Stop
// is called — confirm shutdown ordering at the call sites.
func (s *NodeStore) Stop() {
	close(s.writeQueue)
}
// processWrite is the single background writer goroutine. It drains
// the write queue and groups operations into batches so the expensive
// snapshot/peer recalculation runs once per batch rather than once per
// write. A batch is applied when it fills up (batchSize), when the
// batch timeout fires, or just before an immediate operation.
func (s *NodeStore) processWrite() {
	c := time.NewTicker(batchTimeout)
	defer c.Stop()

	batch := make([]work, 0, batchSize)

	// flush applies and clears any pending batched work.
	flush := func() {
		if len(batch) > 0 {
			s.applyBatch(batch)
			batch = batch[:0]
		}
	}

	for {
		select {
		case w, ok := <-s.writeQueue:
			if !ok {
				// Queue closed: apply any remaining queued work so
				// writers blocked on their result channels are
				// released instead of deadlocking.
				flush()
				return
			}
			if w.immediate {
				// Apply the pending batch first so operations are
				// observed in FIFO order; otherwise an older queued
				// put could later overwrite this immediate change.
				flush()
				s.applyBatch([]work{w})
				c.Reset(batchTimeout)
				continue
			}
			batch = append(batch, w)
			if len(batch) >= batchSize {
				flush()
				c.Reset(batchTimeout)
			}
		case <-c.C:
			flush()
		}
	}
}
// applyBatch builds one new snapshot reflecting every operation in
// batch, publishes it atomically, and then releases all writers that
// are blocked waiting on this batch.
func (s *NodeStore) applyBatch(batch []work) {
	// Work on a private copy of the authoritative map; the published
	// snapshot is never mutated in place.
	nodes := maps.Clone(s.data.Load().nodesByID)

	for _, item := range batch {
		switch item.op {
		case put:
			nodes[item.nodeID] = item.node
		case update:
			// Only mutate nodes that still exist; updates targeting a
			// deleted or unknown node are silently dropped.
			if existing, ok := nodes[item.nodeID]; ok {
				item.updateFn(&existing)
				nodes[item.nodeID] = existing
			}
		case del:
			delete(nodes, item.nodeID)
		}
	}

	snap := snapshotFromNodes(nodes, s.peersFunc)
	s.data.Store(&snap)

	// Wake every caller blocked on this batch.
	for _, item := range batch {
		close(item.result)
	}
}
// snapshotFromNodes recalculates the derived lookup structures
// (allNodes, nodesByNodeKey, peersByNode, nodesByUser) from the
// authoritative nodesByID map and returns them as a new Snapshot.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
	newSnap := Snapshot{
		nodesByID:      nodes,
		allNodes:       make([]types.NodeView, 0, len(nodes)),
		nodesByNodeKey: make(map[key.NodePublic]types.NodeView, len(nodes)),
		nodesByUser:    make(map[types.UserID][]types.NodeView),
	}

	// Single pass: build each NodeView once and index it everywhere it
	// is needed. (The previous version ran two loops and constructed a
	// second view per node.)
	for _, n := range nodes {
		nv := n.View()
		newSnap.allNodes = append(newSnap.allNodes, nv)
		newSnap.nodesByNodeKey[n.NodeKey] = nv

		uid := types.UserID(n.UserID)
		newSnap.nodesByUser[uid] = append(newSnap.nodesByUser[uid], nv)
	}

	// Peer relations are derived from the complete node list, so they
	// are computed after all views exist.
	newSnap.peersByNode = peersFunc(newSnap.allNodes)

	return newSnap
}
// GetNode retrieves a node by its ID.
// It returns the zero NodeView when the ID is unknown, so callers can
// distinguish a missing node via Valid(). (Previously a miss produced
// a valid view of a zero Node, indistinguishable from a real node.)
func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
	n, ok := s.data.Load().nodesByID[id]
	if !ok {
		return types.NodeView{}
	}

	return n.View()
}
// GetNodeByNodeKey retrieves a node by its public node key.
// Returns the zero NodeView when no node has that key (map miss).
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
	snap := s.data.Load()

	return snap.nodesByNodeKey[nodeKey]
}
// ListNodes returns a read-only view over every node in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
	snap := s.data.Load()

	return views.SliceOf(snap.allNodes)
}
// ListPeers returns a read-only view over the peers of the node with
// the given ID, as precomputed by the store's PeersFunc.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
	snap := s.data.Load()

	return views.SliceOf(snap.peersByNode[id])
}
// ListNodesByUser returns a read-only view over every node owned by
// the given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
	snap := s.data.Load()

	return views.SliceOf(snap.nodesByUser[uid])
}

View File

@ -0,0 +1,494 @@
package state
import (
"net/netip"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
)
// TestSnapshotFromNodes verifies that snapshotFromNodes derives the
// expected lookup structures (nodesByID, allNodes, peersByNode,
// nodesByUser) for several node sets and peer-policy functions.
func TestSnapshotFromNodes(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
		validate  func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
	}{
		{
			// An empty input map must yield empty derived structures.
			name: "empty nodes",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := make(map[types.NodeID]types.Node)
				peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
					return make(map[types.NodeID][]types.NodeView)
				}
				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Empty(t, snapshot.nodesByID)
				assert.Empty(t, snapshot.allNodes)
				assert.Empty(t, snapshot.peersByNode)
				assert.Empty(t, snapshot.nodesByUser)
			},
		},
		{
			// A lone node has an entry in every structure but no peers.
			name: "single node",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
				}
				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 1)
				assert.Len(t, snapshot.allNodes, 1)
				assert.Len(t, snapshot.peersByNode, 1)
				assert.Len(t, snapshot.nodesByUser, 1)
				require.Contains(t, snapshot.nodesByID, types.NodeID(1))
				assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
				assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
				assert.Len(t, snapshot.nodesByUser[1], 1)
				assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
			},
		},
		{
			// Two nodes of one user see each other but not themselves.
			name: "multiple nodes same user",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 1, "user1", "node2"),
				}
				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 2)
				assert.Len(t, snapshot.allNodes, 2)
				assert.Len(t, snapshot.peersByNode, 2)
				assert.Len(t, snapshot.nodesByUser, 1)
				// Each node sees the other as peer (but not itself)
				assert.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
				assert.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
				assert.Len(t, snapshot.nodesByUser[1], 2)
			},
		},
		{
			// Peer visibility is independent of user grouping under
			// the allow-all policy.
			name: "multiple nodes different users",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 1, "user1", "node3"),
				}
				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 3)
				assert.Len(t, snapshot.allNodes, 3)
				assert.Len(t, snapshot.peersByNode, 3)
				assert.Len(t, snapshot.nodesByUser, 2)
				// Each node should have 2 peers (all others, but not itself)
				assert.Len(t, snapshot.peersByNode[1], 2)
				assert.Len(t, snapshot.peersByNode[2], 2)
				assert.Len(t, snapshot.peersByNode[3], 2)
				// User groupings
				assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
				assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
			},
		},
		{
			// A restrictive PeersFunc must be honoured: nodes are
			// partitioned by ID parity.
			name: "odd-even peers filtering",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 3, "user3", "node3"),
					4: createTestNode(4, 4, "user4", "node4"),
				}
				peersFunc := oddEvenPeersFunc
				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				assert.Len(t, snapshot.nodesByID, 4)
				assert.Len(t, snapshot.allNodes, 4)
				assert.Len(t, snapshot.peersByNode, 4)
				assert.Len(t, snapshot.nodesByUser, 4)
				// Odd nodes should only see other odd nodes as peers
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
				require.Len(t, snapshot.peersByNode[3], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
				// Even nodes should only see other even nodes as peers
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
				require.Len(t, snapshot.peersByNode[4], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nodes, peersFunc := tt.setupFunc()
			snapshot := snapshotFromNodes(nodes, peersFunc)
			tt.validate(t, nodes, snapshot)
		})
	}
}
// Helper functions
// createTestNode builds a fully-populated Node value for tests, with
// freshly generated machine/node/disco keys and fixed test addresses.
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
	timestamp := time.Now()

	v4 := netip.MustParseAddr("100.64.0.1")
	v6 := netip.MustParseAddr("fd7a:115c:a1e0::1")

	return types.Node{
		ID:         nodeID,
		MachineKey: key.NewMachine().Public(),
		NodeKey:    key.NewNode().Public(),
		DiscoKey:   key.NewDisco().Public(),
		Hostname:   hostname,
		GivenName:  hostname,
		UserID:     userID,
		User: types.User{
			Name:        username,
			DisplayName: username,
		},
		RegisterMethod: "test",
		IPv4:           &v4,
		IPv6:           &v6,
		CreatedAt:      timestamp,
		UpdatedAt:      timestamp,
	}
}
// Peer functions
// allowAllPeersFunc is a PeersFunc modelling a policy with no filtering:
// every node is a peer of every other node (full mesh), and no node is a
// peer of itself.
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))
	for _, self := range nodes {
		var others []types.NodeView
		for _, candidate := range nodes {
			// A node never lists itself as a peer.
			if candidate.ID() == self.ID() {
				continue
			}
			others = append(others, candidate)
		}
		peerMap[self.ID()] = others
	}
	return peerMap
}
// oddEvenPeersFunc is a PeersFunc that partitions nodes by ID parity:
// odd-ID nodes only peer with other odd-ID nodes, and even-ID nodes only
// with other even-ID nodes. It exercises non-trivial peer filtering in the
// NodeStore tests.
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	isOdd := func(id types.NodeID) bool { return id%2 == 1 }

	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))
	for _, self := range nodes {
		var sameParity []types.NodeView
		for _, candidate := range nodes {
			// Skip self; a node is never its own peer.
			if candidate.ID() == self.ID() {
				continue
			}
			// Keep only peers whose ID parity matches this node's.
			if isOdd(candidate.ID()) == isOdd(self.ID()) {
				sameParity = append(sameParity, candidate)
			}
		}
		peerMap[self.ID()] = sameParity
	}
	return peerMap
}
// TestNodeStoreOperations exercises the NodeStore lifecycle end to end:
// construction (both empty and pre-seeded), PutNode, UpdateNode, and
// DeleteNode. After every mutation it inspects the copy-on-write snapshot
// directly (store.data.Load()) and checks that all four indexes —
// nodesByID, allNodes, peersByNode, nodesByUser — stay consistent and that
// the configured PeersFunc has been re-applied. Steps within a test case
// run sequentially against the same store, so later steps observe the
// effects of earlier ones.
func TestNodeStoreOperations(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func(t *testing.T) *NodeStore
		steps     []testStep
	}{
		{
			// Start from a completely empty store and verify a single
			// PutNode populates every index.
			name: "create empty store and add single node",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify empty store",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
				{
					name: "add first node",
					action: func(store *NodeStore) {
						node := createTestNode(1, 1, "user1", "node1")
						store.PutNode(node)
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						require.Contains(t, snapshot.nodesByID, types.NodeID(1))
						assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
						assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
						assert.Len(t, snapshot.nodesByUser[1], 1)
					},
				},
			},
		},
		{
			// Seed the store with one node at construction time, then grow
			// it with PutNode for the same user and for a second user,
			// checking peer and per-user groupings at each step.
			name: "create store with initial node and add more",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				initialNodes := types.Nodes{&node1}
				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial state",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						assert.Empty(t, snapshot.peersByNode[1])
					},
				},
				{
					name: "add second node same user",
					action: func(store *NodeStore) {
						node2 := createTestNode(2, 1, "user1", "node2")
						store.PutNode(node2)
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 1)
						// Now both nodes should see each other as peers
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
						assert.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
						assert.Len(t, snapshot.nodesByUser[1], 2)
					},
				},
				{
					name: "add third node different user",
					action: func(store *NodeStore) {
						node3 := createTestNode(3, 2, "user2", "node3")
						store.PutNode(node3)
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)
						// All nodes should see the other 2 as peers
						assert.Len(t, snapshot.peersByNode[1], 2)
						assert.Len(t, snapshot.peersByNode[2], 2)
						assert.Len(t, snapshot.peersByNode[3], 2)
						// User groupings
						assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
					},
				},
			},
		},
		{
			// Deleting nodes must recompute peers and user groupings, and
			// deleting everything must return the store to the empty state.
			name: "test node deletion",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				node3 := createTestNode(3, 2, "user2", "node3")
				initialNodes := types.Nodes{&node1, &node2, &node3}
				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial 3 nodes",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)
					},
				},
				{
					name: "delete middle node",
					action: func(store *NodeStore) {
						store.DeleteNode(2)
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 2)
						// Node 2 should be gone
						assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
						// Remaining nodes should see each other as peers
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						assert.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
						// User groupings updated
						assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
					},
				},
				{
					name: "delete all remaining nodes",
					action: func(store *NodeStore) {
						store.DeleteNode(1)
						store.DeleteNode(3)
						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
			},
		},
		{
			// UpdateNode must mutate only the targeted node and leave the
			// peer relationships intact.
			name: "test node updates",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}
				return NewNodeStore(initialNodes, allowAllPeersFunc)
			},
			steps: []testStep{
				{
					name: "verify initial hostnames",
					action: func(store *NodeStore) {
						snapshot := store.data.Load()
						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
					},
				},
				{
					name: "update node hostname",
					action: func(store *NodeStore) {
						store.UpdateNode(1, func(n *types.Node) {
							n.Hostname = "updated-node1"
							n.GivenName = "updated-node1"
						})
						snapshot := store.data.Load()
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
						// Peers should still work correctly
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Len(t, snapshot.peersByNode[2], 1)
					},
				},
			},
		},
		{
			// With a filtering PeersFunc, additions and deletions must only
			// affect the peer lists of nodes in the same partition.
			name: "test with odd-even peers filtering",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, oddEvenPeersFunc)
			},
			steps: []testStep{
				{
					name: "add nodes with odd-even filtering",
					action: func(store *NodeStore) {
						// Add nodes in sequence
						store.PutNode(createTestNode(1, 1, "user1", "node1"))
						store.PutNode(createTestNode(2, 2, "user2", "node2"))
						store.PutNode(createTestNode(3, 3, "user3", "node3"))
						store.PutNode(createTestNode(4, 4, "user4", "node4"))
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 4)
						// Verify odd-even peer relationships
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
				{
					name: "delete odd node and verify even nodes unaffected",
					action: func(store *NodeStore) {
						store.DeleteNode(1)
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						// Node 3 (odd) should now have no peers
						assert.Empty(t, snapshot.peersByNode[3])
						// Even nodes should still see each other
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := tt.setupFunc(t)
			// Start the store's batching goroutine; Stop on subtest exit so
			// no background worker leaks between cases.
			store.Start()
			defer store.Stop()
			for _, step := range tt.steps {
				t.Run(step.name, func(t *testing.T) {
					step.action(store)
				})
			}
		})
	}
}
// testStep is a single named stage of a NodeStore test scenario. Steps in a
// case run in order against the same store instance, so each step builds on
// the state left behind by the previous one.
type testStep struct {
	// name becomes the subtest name under t.Run.
	name string
	// action performs mutations and assertions against the shared store.
	action func(store *NodeStore)
}

View File

@ -27,6 +27,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
zcache "zgo.at/zcache/v2"
)
@ -49,6 +50,9 @@ type State struct {
// cfg holds the current Headscale configuration
cfg *types.Config
// nodeStore provides an in-memory cache for nodes.
nodeStore *NodeStore
// subsystem keeping state
// db provides persistent storage and database operations
db *hsdb.HSDatabase
@ -107,6 +111,12 @@ func NewState(cfg *types.Config) (*State, error) {
return nil, fmt.Errorf("init policy manager: %w", err)
}
nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
_, matchers := polMan.Filter()
return policy.BuildPeerMap(views.SliceOf(nodes), matchers)
})
nodeStore.Start()
return &State{
cfg: cfg,
@ -117,11 +127,14 @@ func NewState(cfg *types.Config) (*State, error) {
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
}, nil
}
// Close gracefully shuts down the State instance and releases all resources.
func (s *State) Close() error {
s.nodeStore.Stop()
if err := s.db.Close(); err != nil {
return fmt.Errorf("closing database: %w", err)
}
@ -204,43 +217,42 @@ func (s *State) AutoApproveNodes() error {
}
// CreateUser creates a new user and updates the policy manager.
// Returns the created user, whether policies changed, and any error.
func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
// Returns the created user, change set, and any error.
func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerUsers()
c, err := s.updatePolicyManagerUsers()
if err != nil {
// Log the error but don't fail the user creation
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// Even if the policy manager doesn't detect a filter change, SSH policies
// might now be resolvable when they weren't before. If there are existing
// nodes, we should send a policy change to ensure they get updated SSH policies.
if !policyChanged {
if c.Empty() {
nodes, err := s.ListNodes()
if err == nil && len(nodes) > 0 {
policyChanged = true
if err == nil && nodes.Len() > 0 {
c = change.PolicyChange()
}
}
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
log.Info().Str("user", user.Name).Bool("policyChanged", !c.Empty()).Msg("User created, policy manager updated")
// TODO(kradalby): implement the user in-memory cache
return &user, policyChanged, nil
return &user, c, nil
}
// UpdateUser modifies an existing user using the provided update function within a transaction.
// Returns the updated user, whether policies changed, and any error.
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) {
// Returns the updated user, change set, and any error.
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -261,18 +273,18 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
return user, nil
})
if err != nil {
return nil, false, err
return nil, change.EmptySet, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerUsers()
c, err := s.updatePolicyManagerUsers()
if err != nil {
return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err)
return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err)
}
// TODO(kradalby): implement the user in-memory cache
return user, policyChanged, nil
return user, c, nil
}
// DeleteUser permanently removes a user and all associated data (nodes, API keys, etc).
@ -282,7 +294,7 @@ func (s *State) DeleteUser(userID types.UserID) error {
}
// RenameUser changes a user's name. The new name must be unique.
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) {
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) {
return s.UpdateUser(userID, func(user *types.User) error {
user.Name = newName
return nil
@ -315,28 +327,31 @@ func (s *State) ListAllUsers() ([]types.User, error) {
}
// CreateNode creates a new node and updates the policy manager.
// Returns the created node, whether policies changed, and any error.
func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Returns the created node, change set, and any error.
func (s *State) CreateNode(node *types.Node) (types.NodeView, change.ChangeSet, error) {
s.nodeStore.PutNode(*node)
if err := s.db.DB.Save(node).Error; err != nil {
return nil, false, fmt.Errorf("creating node: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("creating node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
c, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err)
return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node creation: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
if c.Empty() {
c = change.NodeAdded(node.ID)
}
return node.View(), c, nil
}
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (types.NodeView, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -357,70 +372,72 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
return node, nil
})
if err != nil {
return nil, change.EmptySet, err
return types.NodeView{}, change.EmptySet, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
c, err := s.updatePolicyManagerNodes()
if err != nil {
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
var c change.ChangeSet
if policyChanged {
c = change.PolicyChange()
} else {
if c.Empty() {
// Basic node change without specific details since this is a generic update
c = change.NodeAdded(node.ID)
}
return node.View(), c, nil
}
// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
nodePtr := node.AsStruct()
s.nodeStore.PutNode(*nodePtr)
if err := s.db.DB.Save(nodePtr).Error; err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerNodes()
if err != nil {
return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
if c.Empty() {
c = change.NodeAdded(node.ID())
}
return node, c, nil
}
// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(node).Error; err != nil {
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
if policyChanged {
return node, change.PolicyChange(), nil
}
return node, change.EmptySet, nil
}
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
err := s.db.DeleteNode(node)
func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
s.nodeStore.DeleteNode(node.ID())
err := s.db.DeleteNode(node.AsStruct())
if err != nil {
return change.EmptySet, err
}
c := change.NodeRemoved(node.ID)
c := change.NodeRemoved(node.ID())
// Check if policy manager needs updating after node deletion
policyChanged, err := s.updatePolicyManagerNodes()
policyChange, err := s.updatePolicyManagerNodes()
if err != nil {
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
if policyChanged {
c = change.PolicyChange()
if !policyChange.Empty() {
c = policyChange
}
return c, nil
@ -434,12 +451,26 @@ func (s *State) Connect(node *types.Node) change.ChangeSet {
c = change.NodeAdded(node.ID)
}
// Update nodestore with online status - node is connecting so it's online
// Use immediate update to ensure online status changes are not delayed by batching
s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
// Set the online status in the node's ephemeral field
n.IsOnline = ptr.To(true)
})
return c
}
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
c := change.NodeOffline(node.ID)
// Update nodestore with offline status
// Use immediate update to ensure online status changes are not delayed by batching
s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
// Set the online status to false in the node's ephemeral field
n.IsOnline = ptr.To(false)
})
_, _, err := s.SetLastSeen(node.ID, time.Now())
if err != nil {
return c, fmt.Errorf("disconnecting node: %w", err)
@ -454,70 +485,95 @@ func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
}
// GetNodeByID retrieves a node by ID.
func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) {
return s.db.GetNodeByID(nodeID)
}
// GetNodeViewByID retrieves a node view by ID.
func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
node, err := s.db.GetNodeByID(nodeID)
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, error) {
return s.nodeStore.GetNode(nodeID), nil
}
// GetNodeByNodeKey retrieves a node by its Tailscale public key.
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return s.db.GetNodeByNodeKey(nodeKey)
}
// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key.
func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
node, err := s.db.GetNodeByNodeKey(nodeKey)
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
return s.nodeStore.GetNodeByNodeKey(nodeKey), nil
}
// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
func (s *State) ListNodes(nodeIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
if len(nodeIDs) == 0 {
return s.db.ListNodes()
return s.nodeStore.ListNodes(), nil
}
return s.db.ListNodes(nodeIDs...)
// Filter nodes by the requested IDs
allNodes := s.nodeStore.ListNodes()
nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs))
for _, id := range nodeIDs {
nodeIDSet[id] = struct{}{}
}
var filteredNodes []types.NodeView
for _, node := range allNodes.All() {
if _, exists := nodeIDSet[node.ID()]; exists {
filteredNodes = append(filteredNodes, node)
}
}
return views.SliceOf(filteredNodes), nil
}
// ListNodesByUser retrieves all nodes belonging to a specific user.
func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) {
return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return hsdb.ListNodesByUser(rx, userID)
})
func (s *State) ListNodesByUser(userID types.UserID) (views.Slice[types.NodeView], error) {
return s.nodeStore.ListNodesByUser(userID), nil
}
// ListPeers retrieves nodes that can communicate with the specified node based on policy.
func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return s.db.ListPeers(nodeID, peerIDs...)
func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
if len(peerIDs) == 0 {
return s.nodeStore.ListPeers(nodeID), nil
}
// For specific peerIDs, filter from all nodes
allNodes := s.nodeStore.ListNodes()
nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs))
for _, id := range peerIDs {
nodeIDSet[id] = struct{}{}
}
var filteredNodes []types.NodeView
for _, node := range allNodes.All() {
if _, exists := nodeIDSet[node.ID()]; exists {
filteredNodes = append(filteredNodes, node)
}
}
return views.SliceOf(filteredNodes), nil
}
// ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system.
func (s *State) ListEphemeralNodes() (types.Nodes, error) {
return s.db.ListEphemeralNodes()
func (s *State) ListEphemeralNodes() (views.Slice[types.NodeView], error) {
allNodes := s.nodeStore.ListNodes()
var ephemeralNodes []types.NodeView
for _, node := range allNodes.All() {
// Check if node is ephemeral by checking its AuthKey
if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
ephemeralNodes = append(ephemeralNodes, node)
}
}
return views.SliceOf(ephemeralNodes), nil
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
}
// Update nodestore with the same change
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.Expiry = &expiry
})
if !c.IsFull() {
c = change.KeyExpiry(nodeID)
}
@ -526,14 +582,19 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Nod
}
// SetNodeTags assigns tags to a node for use in access control policies.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetTags(tx, nodeID, tags)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
}
// Update nodestore with the same change
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.ForcedTags = tags
})
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
@ -542,16 +603,21 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, ch
}
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
}
// Update nodestore with the same change
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.ApprovedRoutes = routes
})
// Update primary routes after changing approved routes
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.AsStruct().SubnetRoutes()...)
if routeChange || !c.IsFull() {
c = change.PolicyChange()
@ -561,14 +627,19 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*
}
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.RenameNode(tx, nodeID, newName)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
}
// Update nodestore with the same change
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.GivenName = newName
})
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
@ -577,21 +648,19 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, ch
}
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (types.NodeView, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
})
}
// AssignNodeToUser transfers a node to a different user.
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.AssignNodeToUser(tx, nodeID, userID)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting last seen: %w", err)
}
// Update nodestore with the same change
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.LastSeen = &lastSeen
})
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
@ -599,6 +668,32 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*typ
return n, c, nil
}
// AssignNodeToUser transfers a node to a different user.
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) {
node, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.AssignNodeToUser(tx, nodeID, userID)
})
if err != nil {
return types.NodeView{}, change.EmptySet, err
}
// Update nodestore with the same change
// Get the updated user information from the database
user, err := s.GetUserByID(userID)
if err == nil {
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
n.UserID = uint(userID)
n.User = *user
})
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return node, c, nil
}
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
func (s *State) BackfillNodeIPs() ([]string, error) {
return s.db.BackfillNodeIPs(s.ipAlloc)
@ -631,8 +726,16 @@ func (s *State) SetPolicy(pol []byte) (bool, error) {
}
// AutoApproveRoutes checks if a node's routes should be auto-approved.
func (s *State) AutoApproveRoutes(node *types.Node) bool {
return policy.AutoApproveRoutes(s.polMan, node)
func (s *State) AutoApproveRoutes(node types.NodeView) bool {
nodePtr := node.AsStruct()
changed := policy.AutoApproveRoutes(s.polMan, nodePtr)
if changed {
s.nodeStore.PutNode(*nodePtr)
// Update primaryRoutes manager with the newly approved routes
// This is essential for actual packet forwarding to work
s.primaryRoutes.SetRoutes(nodePtr.ID, nodePtr.SubnetRoutes()...)
}
return changed
}
// PolicyDebugString returns a debug representation of the current policy.
@ -742,34 +845,42 @@ func (s *State) HandleNodeFromAuthPath(
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (*types.Node, change.ChangeSet, error) {
) (types.NodeView, change.ChangeSet, error) {
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, change.EmptySet, err
return types.NodeView{}, change.EmptySet, err
}
return s.db.HandleNodeFromAuthPath(
node, nodeChange, err := s.db.HandleNodeFromAuthPath(
registrationID,
userID,
expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
)
if err != nil {
return types.NodeView{}, change.EmptySet, err
}
// Update nodestore with the newly registered/updated node
s.nodeStore.PutNode(*node)
return node.View(), nodeChange, nil
}
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
func (s *State) HandleNodeFromPreAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*types.Node, change.ChangeSet, bool, error) {
) (types.NodeView, change.ChangeSet, error) {
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, change.EmptySet, false, err
return types.NodeView{}, change.EmptySet, err
}
err = pak.Validate()
if err != nil {
return nil, change.EmptySet, false, err
return types.NodeView{}, change.EmptySet, err
}
nodeToRegister := types.Node{
@ -796,7 +907,7 @@ func (s *State) HandleNodeFromPreAuthKey(
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
}
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@ -818,38 +929,44 @@ func (s *State) HandleNodeFromPreAuthKey(
return node, nil
})
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
}
// Check if this is a logout request for an ephemeral node
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
// This is a logout request for an ephemeral node, delete it immediately
c, err := s.DeleteNode(node)
c, err := s.DeleteNode(node.View())
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
}
return nil, c, false, nil
return types.NodeView{}, c, nil
}
// Check if policy manager needs updating
// This is necessary because we just created a new node.
// We need to ensure that the policy manager is aware of this new node.
// Also update users to ensure all users are known when evaluating policies.
usersChanged, err := s.updatePolicyManagerUsers()
usersChange, err := s.updatePolicyManagerUsers()
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
}
nodesChanged, err := s.updatePolicyManagerNodes()
nodesChange, err := s.updatePolicyManagerNodes()
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
}
policyChanged := usersChanged || nodesChanged
var c change.ChangeSet
if !usersChange.Empty() || !nodesChange.Empty() {
c = change.PolicyChange()
} else {
c = change.NodeAdded(node.ID)
}
c := change.NodeAdded(node.ID)
// Update nodestore with the newly registered node
s.nodeStore.PutNode(*node)
return node, c, policyChanged, nil
return node.View(), c, nil
}
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
@ -863,22 +980,25 @@ func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) {
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for users.
// updatePolicyManagerUsers refreshes the policy manager with current user data.
func (s *State) updatePolicyManagerUsers() (bool, error) {
func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
users, err := s.ListAllUsers()
if err != nil {
return false, fmt.Errorf("listing users for policy update: %w", err)
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
}
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
changed, err := s.polMan.SetUsers(users)
if err != nil {
return false, fmt.Errorf("updating policy manager users: %w", err)
return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err)
}
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
return changed, nil
if changed {
return change.PolicyChange(), nil
}
return change.EmptySet, nil
}
// updatePolicyManagerNodes updates the policy manager with current nodes.
@ -887,18 +1007,21 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for nodes.
// updatePolicyManagerNodes refreshes the policy manager with current node data.
func (s *State) updatePolicyManagerNodes() (bool, error) {
func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) {
nodes, err := s.ListNodes()
if err != nil {
return false, fmt.Errorf("listing nodes for policy update: %w", err)
return change.EmptySet, fmt.Errorf("listing nodes for policy update: %w", err)
}
changed, err := s.polMan.SetNodes(nodes.ViewSlice())
changed, err := s.polMan.SetNodes(nodes)
if err != nil {
return false, fmt.Errorf("updating policy manager nodes: %w", err)
return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err)
}
return changed, nil
if changed {
return change.PolicyChange(), nil
}
return change.EmptySet, nil
}
// PingDB checks if the database connection is healthy.
@ -923,6 +1046,9 @@ func (s *State) autoApproveNodes() error {
// TODO(kradalby): This change should probably be sent to the rest of the system.
changed := policy.AutoApproveRoutes(s.polMan, node)
if changed {
// Update nodestore first if available
s.nodeStore.PutNode(*node)
err = tx.Save(node).Error
if err != nil {
return err
@ -985,7 +1111,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques
if routesChanged {
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
_ = s.AutoApproveRoutes(node)
_ = s.AutoApproveRoutes(node.View())
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
@ -998,7 +1124,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques
// the hostname change.
node.ApplyHostnameFromHostInfo(req.Hostinfo)
_, policyChange, err := s.SaveNode(node)
_, policyChange, err := s.SaveNode(node.View())
if err != nil {
return change.EmptySet, err
}

View File

@ -13,6 +13,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/tsaddr"
@ -511,11 +512,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
}
if node.Hostname != hostInfo.Hostname {
log.Trace().
Str("node_id", node.ID.String()).
Str("old_hostname", node.Hostname).
Str("new_hostname", hostInfo.Hostname).
Str("old_given_name", node.GivenName).
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
Msg("Updating hostname from hostinfo")
if node.GivenNameHasBeenChanged() {
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
}
node.Hostname = hostInfo.Hostname
log.Trace().
Str("node_id", node.ID.String()).
Str("new_hostname", node.Hostname).
Str("new_given_name", node.GivenName).
Msg("Hostname updated")
}
}
@ -759,6 +774,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix {
return v.ж.ExitRoutes()
}
// RequestTags returns the ACL tags that the node is requesting.
func (v NodeView) RequestTags() []string {
if !v.Valid() || !v.Hostinfo().Valid() {
return []string{}
}
return v.Hostinfo().RequestTags().AsSlice()
}
// Proto converts the NodeView to a protobuf representation.
func (v NodeView) Proto() *v1.Node {
if !v.Valid() {
return nil
}
return v.ж.Proto()
}
// HasIP reports if a node has a given IP address.
func (v NodeView) HasIP(i netip.Addr) bool {
if !v.Valid() {

View File

@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
// Parse each hop line
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])

View File

@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}
assertTailscaleNodesLogout(t, allClients)
}, shortAccessTTL+10*time.Second, 5*time.Second)
}

View File

@ -535,6 +535,8 @@ func TestUpdateHostnameFromClient(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
var nodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err := executeAndUnmarshal(
@ -630,27 +632,34 @@ func TestUpdateHostnameFromClient(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"node",
"list",
"--output",
"json",
},
&nodes,
)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
assert.Eventually(t, func() bool {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"node",
"list",
"--output",
"json",
},
&nodes,
)
assertNoErr(t, err)
assert.Len(t, nodes, 3)
if err != nil || len(nodes) != 3 {
return false
}
for _, node := range nodes {
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
givenName := fmt.Sprintf("%d-givenname", node.GetId())
assert.Equal(t, hostname+"NEW", node.GetName())
assert.Equal(t, givenName, node.GetGivenName())
}
for _, node := range nodes {
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
givenName := fmt.Sprintf("%d-givenname", node.GetId())
if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
return false
}
}
return true
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
}
func TestExpireNode(t *testing.T) {

View File

@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) {
assert.Len(t, node.GetSubnetRoutes(), 1)
}
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assert.NoError(c, err)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.NotNil(t, peerStatus.PrimaryRoutes)
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
assert.NotNil(c, peerStatus.PrimaryRoutes)
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
}
}
}
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
_, err = headscale.ApproveRoutes(
1,
@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
for _, node := range nodes {
if node.GetId() == 1 {
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
assert.Empty(t, node.GetSubnetRoutes())
} else if node.GetId() == 2 {
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
assert.Empty(t, node.GetApprovedRoutes())
assert.Empty(t, node.GetSubnetRoutes())
} else {
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
for _, node := range nodes {
if node.GetId() == 1 {
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
assert.Empty(c, node.GetSubnetRoutes())
} else if node.GetId() == 2 {
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
assert.Empty(c, node.GetApprovedRoutes())
assert.Empty(c, node.GetSubnetRoutes())
} else {
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
}
}
}
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
// Verify that the clients can see the new routes
for _, client := range allClients {
@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
time.Sleep(3 * time.Second)
// Wait for route configuration changes after advertising routes
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err := headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved")
// Verify that no routes has been sent to the client,
// they are not yet enabled.
@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(3 * time.Second)
// Wait for route approval on first subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route")
// Verify that the client has routes from the primary machine and can access
// the webservice.
@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(3 * time.Second)
// Wait for route approval on second subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
}, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
// Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus()
@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(3 * time.Second)
// Wait for route approval on third subnet router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
}, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
// Verify that the client has routes from the primary machine
srs1 = subRouter1.MustStatus()
@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, result, 13)
tr, err = client.Traceroute(webip)
require.NoError(t, err)
assertTracerouteViaIP(t, tr, subRouter1.MustIPv4())
// Wait for traceroute to work correctly through the expected router
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
// Get the expected router IP - use a more robust approach to handle temporary disconnections
ips, err := subRouter1.IPs()
assert.NoError(c, err)
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
var expectedIP netip.Addr
for _, ip := range ips {
if ip.Is4() {
expectedIP = ip
break
}
}
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
}, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
// Take down the current primary
t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter1.Down()
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for router status changes after r1 goes down
assert.EventuallyWithT(t, func(c *assert.CollectT) {
srs2 = subRouter2.MustStatus()
clientStatus = client.MustStatus()
srs2 = subRouter2.MustStatus()
clientStatus = client.MustStatus()
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter2.Down()
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for router status changes after r2 goes down
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// TODO(kradalby): Check client status
// Both are expected to be down
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
// Verify that the route is not presented from either router
clientStatus, err = client.Status()
require.NoError(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter1.Up()
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for router status changes after r1 comes back up
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
require.NoError(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
err = subRouter2.Up()
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for nodestore batch processing to complete and online status to be updated
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientStatus, err = client.Status()
assert.NoError(c, err)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
require.NoError(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
assert.True(c, srs1PeerStatus.Online, "r1 should be online")
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
}, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
// Wait for nodestore batch processing and route state changes to complete
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// After disabling route on r3, r1 should become primary with 1 subnet route
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
// Wait for nodestore batch processing and route state changes to complete
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// After disabling route on r1, r2 should become primary with 1 subnet route
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
)
time.Sleep(5 * time.Second)
// Wait for route state changes after re-enabling r1
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
}, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
require.Len(t, nodes, 2)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 0, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
// Verify that the client has routes from the primary machine
srs1, _ := subRouter1.Status()
@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
requireNodeRouteCount(t, nodes[1], 2, 2, 2)
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assert.NoError(c, err)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
require.NotNil(t, peerStatus.AllowedIPs)
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4)
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
assert.NotNil(c, peerStatus.AllowedIPs)
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
}
}
}
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
}
// TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to nodes and clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
// Verify that the routes have been sent to the client
status, err = user2c.Status()
assert.NoError(c, err)
// Verify that the routes have been sent to the client.
status, err = user2c.Status()
require.NoError(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
}
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
}
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
usernet1, err := scenario.Network("usernet1")
require.NoError(t, err)
@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err)
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate to nodes and clients
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
// Verify that the routes have been sent to the client
status, err = user2c.Status()
assert.NoError(c, err)
// Verify that the routes have been sent to the client.
status, err = user2c.Status()
require.NoError(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
}
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
// Tell user2c to use user1c as an exit node.
command = []string{
@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
var nodes []*v1.Node
opts := []hsic.Option{
hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(),
@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to advertise route: %s", err)
}
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err := headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
status, err := client.Status()
@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
status, err = client.Status()
@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
status, err = client.Status()
@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol)
require.NoError(t, err)
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
status, err = client.Status()
@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
// Verify that the routes have been sent to the client.
@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second)
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 0, 0, 0)
@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerExitNode.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)
time.Sleep(5 * time.Second)
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 2, 2, 2)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
status, err = client.Status()
@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
require.Equal(t, tr.Route[0].IP, ip)
}
// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
assert.NotNil(c, tr)
assert.True(c, tr.Success)
assert.NoError(c, tr.Err)
assert.NotEmpty(c, tr.Route)
assert.Equal(c, tr.Route[0].IP, ip)
}
// requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
t.Helper()
@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
}
}
func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
return
}
if len(expected) == 0 {
expected = []netip.Prefix{}
}
got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
if tsaddr.IsExitRoute(p) {
return true
}
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
})
if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
}
}
func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
t.Helper()
require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}
func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}
// TestSubnetRouteACLFiltering tests that a node can only access subnet routes
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
)
require.NoError(t, err)
// Give some time for the routes to propagate
time.Sleep(5 * time.Second)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
require.NoError(t, err)
require.Len(t, nodes, 2)
// Find the router node
routerNode = nodes[routerUser][0]
// Find the router node
routerNode = nodes[routerUser][0]
// Check that the router has 3 routes now approved and available
requireNodeRouteCount(t, routerNode, 3, 3, 3)
// Check that the router has 3 routes now approved and available
requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Now check the client node status
nodeStatus, err := nodeClient.Status()

View File

@ -14,7 +14,6 @@ import (
"net/netip"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"
@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
return nil, fmt.Errorf("no network named: %s", name)
}
for _, ipam := range net.Network.IPAM.Config {
pref, err := netip.ParsePrefix(ipam.Subnet)
if err != nil {
return nil, err
}
return &pref, nil
if len(net.Network.IPAM.Config) == 0 {
return nil, fmt.Errorf("no IPAM config found in network: %s", name)
}
return nil, fmt.Errorf("no prefix found in network: %s", name)
pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
if err != nil {
return nil, err
}
return &pref, nil
}
func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
return err
}
sort.Strings(s.spec.Users)
for _, user := range s.spec.Users {
u, err := s.CreateUser(user)
if err != nil {