diff --git a/hscontrol/app.go b/hscontrol/app.go
index ec8e2550..d7d97fc9 100644
--- a/hscontrol/app.go
+++ b/hscontrol/app.go
@@ -548,8 +548,8 @@ func (h *Headscale) Serve() error {
 	if err != nil {
 		return fmt.Errorf("failed to list ephemeral nodes: %w", err)
 	}
-	for _, node := range ephmNodes {
-		h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
+	for _, node := range ephmNodes.All() {
+		h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
 	}
 
 	if h.cfg.DNSConfig.ExtraRecordsPath != "" {
diff --git a/hscontrol/auth.go b/hscontrol/auth.go
index cb284173..f06864e7 100644
--- a/hscontrol/auth.go
+++ b/hscontrol/auth.go
@@ -34,7 +34,7 @@ func (h *Headscale) handleRegister(
 		return nil, fmt.Errorf("looking up node in database: %w", err)
 	}
 
-	if node != nil {
+	if node.Valid() {
 		// If an existing node is trying to register with an auth key,
 		// we need to validate the auth key even for existing nodes
 		if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
@@ -50,7 +50,7 @@ func (h *Headscale) handleRegister(
 			return resp, nil
 		}
 
-		resp, err := h.handleExistingNode(node, regReq, machineKey)
+		resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
 		if err != nil {
 			return nil, fmt.Errorf("handling existing node: %w", err)
 		}
@@ -107,7 +107,7 @@ func (h *Headscale) handleExistingNode(
 		// If the request expiry is in the past, we consider it a logout.
 		if requestExpiry.Before(time.Now()) {
 			if node.IsEphemeral() {
-				c, err := h.state.DeleteNode(node)
+				c, err := h.state.DeleteNode(node.View())
 				if err != nil {
 					return nil, fmt.Errorf("deleting ephemeral node: %w", err)
 				}
@@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode(
 		}
 
 		h.Change(c)
-	}
+	}
 
-	return nodeToRegisterResponse(node), nil
+	return nodeToRegisterResponse(node), nil
 }
 
 func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@@ -177,7 +177,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-	node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
+	node, changed, err := h.state.HandleNodeFromPreAuthKey(
 		regReq,
 		machineKey,
 	)
@@ -194,7 +194,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	}
 
 	// If node is nil, it means an ephemeral node was deleted during logout
-	if node == nil {
+	if !node.Valid() {
 		h.Change(changed)
 		return nil, nil
 	}
@@ -218,21 +218,23 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	}
 
 	if routeChange && changed.Empty() {
-		changed = change.NodeAdded(node.ID)
+		changed = change.NodeAdded(node.ID())
 	}
 	h.Change(changed)
 
-	// If policy changed due to node registration, send a separate policy change
-	if policyChanged {
-		policyChange := change.PolicyChange()
-		h.Change(policyChange)
-	}
+	// TODO(kradalby): I think this is covered above, but we need to validate that.
+	// // If policy changed due to node registration, send a separate policy change
+	// if policyChanged {
+	// 	policyChange := change.PolicyChange()
+	// 	h.Change(policyChange)
+	// }
 
+	user := node.User()
 	return &tailcfg.RegisterResponse{
 		MachineAuthorized: true,
 		NodeKeyExpired:    node.IsExpired(),
-		User:              *node.User.TailscaleUser(),
-		Login:             *node.User.TailscaleLogin(),
+		User:              *user.TailscaleUser(),
+		Login:             *user.TailscaleLogin(),
 	}, nil
 }
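The switch from *types.Node to types.NodeView in auth.go changes how "no node" is expressed: lookups return the zero view instead of nil, and callers test presence with Valid(). A minimal sketch of the convention, using only accessors that appear in this diff (the function itself is illustrative, not part of the change):

	// lookupNode reports whether a node exists for the given key and, if so,
	// returns a mutable copy for code paths that still need *types.Node.
	func lookupNode(h *Headscale, nk key.NodePublic) (*types.Node, bool) {
		nv, err := h.state.GetNodeByNodeKey(nk)
		if err != nil || !nv.Valid() {
			return nil, false // the zero NodeView stands in for the old nil pointer
		}
		return nv.AsStruct(), true // AsStruct materialises a copy of the snapshot
	}
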
diff --git a/hscontrol/debug.go b/hscontrol/debug.go
index 481ce589..7583c5ad 100644
--- a/hscontrol/debug.go
+++ b/hscontrol/debug.go
@@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 			}
 
 			sshPol := make(map[string]*tailcfg.SSHPolicy)
-			for _, node := range nodes {
-				pol, err := h.state.SSHPolicy(node.View())
+			for _, node := range nodes.All() {
+				pol, err := h.state.SSHPolicy(node)
 				if err != nil {
 					httpError(w, err)
 					return
 				}
 
-				sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol
+				sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
 			}
 
 			sshJSON, err := json.MarshalIndent(sshPol, "", " ")
diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go
index 722f8421..4ef52106 100644
--- a/hscontrol/grpcv1.go
+++ b/hscontrol/grpcv1.go
@@ -25,6 +25,7 @@ import (
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
+	"tailscale.com/types/views"
 
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 	"github.com/juanfont/headscale/hscontrol/state"
@@ -59,9 +60,10 @@ func (api headscaleV1APIServer) CreateUser(
 		return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
 	}
 
 	c := change.UserAdded(types.UserID(user.ID))
-	if policyChanged {
+
+	// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
+	if !policyChanged.Empty() {
 		c.Change = change.Policy
 	}
@@ -79,15 +81,13 @@ func (api headscaleV1APIServer) RenameUser(
 		return nil, err
 	}
 
-	_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
+	_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
 	if err != nil {
 		return nil, err
 	}
 
 	// Send policy update notifications if needed
-	if policyChanged {
-		api.h.Change(change.PolicyChange())
-	}
+	api.h.Change(c)
 
 	newUser, err := api.h.state.GetUserByName(request.GetNewName())
 	if err != nil {
@@ -297,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
 
 	// Populate the online field based on
 	// currently connected nodes.
-	resp.Online = api.h.mapBatcher.IsConnected(node.ID)
+	resp.Online = api.h.mapBatcher.IsConnected(node.ID())
 
 	return &v1.GetNodeResponse{Node: resp}, nil
 }
@@ -323,7 +323,7 @@ func (api headscaleV1APIServer) SetTags(
 	api.h.Change(nodeChange)
 
 	log.Trace().
-		Str("node", node.Hostname).
+		Str("node", node.Hostname()).
 		Strs("tags", request.GetTags()).
 		Msg("Changing tags of node")
 
@@ -357,7 +357,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 		return nil, status.Error(codes.InvalidArgument, err.Error())
 	}
 
-	routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
+	routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
 
 	// Always propagate node changes from SetApprovedRoutes
 	api.h.Change(nodeChange)
@@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 	}
 
 	proto := node.Proto()
-	proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
+	proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
 
 	return &v1.SetApprovedRoutesResponse{Node: proto}, nil
 }
@@ -420,8 +420,8 @@ func (api headscaleV1APIServer) ExpireNode(
 	api.h.Change(nodeChange)
 
 	log.Trace().
-		Str("node", node.Hostname).
-		Time("expiry", *node.Expiry).
+		Str("node", node.Hostname()).
+		Time("expiry", *node.AsStruct().Expiry).
 		Msg("node expired")
 
 	return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
@@ -440,7 +440,7 @@ func (api headscaleV1APIServer) RenameNode(
 	api.h.Change(nodeChange)
 
 	log.Trace().
-		Str("node", node.Hostname).
+		Str("node", node.Hostname()).
 		Str("new_name", request.GetNewName()).
 		Msg("node renamed")
 
@@ -477,36 +477,36 @@ func (api headscaleV1APIServer) ListNodes(
 		return nil, err
 	}
 
-	sort.Slice(nodes, func(i, j int) bool {
-		return nodes[i].ID < nodes[j].ID
-	})
-
 	response := nodesToProto(api.h.state, IsConnected, nodes)
 	return &v1.ListNodesResponse{Nodes: response}, nil
 }
 
-func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
-	response := make([]*v1.Node, len(nodes))
-	for index, node := range nodes {
+func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID, bool], nodes views.Slice[types.NodeView]) []*v1.Node {
+	response := make([]*v1.Node, nodes.Len())
+	for index, node := range nodes.All() {
 		resp := node.Proto()
 
 		// Populate the online field based on
 		// currently connected nodes.
-		if val, ok := IsConnected.Load(node.ID); ok && val {
+		if val, ok := isLikelyConnected.Load(node.ID()); ok && val {
 			resp.Online = true
 		}
 
 		var tags []string
 		for _, tag := range node.RequestTags() {
-			if state.NodeCanHaveTag(node.View(), tag) {
+			if state.NodeCanHaveTag(node, tag) {
 				tags = append(tags, tag)
 			}
 		}
-		resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
-		resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
+		resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
+		resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
 		response[index] = resp
 	}
 
+	sort.Slice(response, func(i, j int) bool {
+		return response[i].Id < response[j].Id
+	})
+
 	return response
 }
 
@@ -683,8 +683,8 @@ func (api headscaleV1APIServer) SetPolicy(
 		return nil, fmt.Errorf("setting policy: %w", err)
 	}
 
-	if len(nodes) > 0 {
-		_, err = api.h.state.SSHPolicy(nodes[0].View())
+	if nodes.Len() > 0 {
+		_, err = api.h.state.SSHPolicy(nodes.At(0))
 		if err != nil {
 			return nil, fmt.Errorf("verifying SSH rules: %w", err)
 		}
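The TODO in CreateUser above points at a recurring situation in this diff: a handler can hold two ChangeSets (here, the user-added change and a possible policy change) with no way to combine them. A hypothetical helper, built only from the ChangeSet methods this diff already uses (Empty, IsFull), could make the merge explicit:

	// mergeChange returns the more significant of two change sets, preferring
	// a full (policy) change over a narrower one. Purely illustrative, not
	// part of this PR.
	func mergeChange(a, b change.ChangeSet) change.ChangeSet {
		if a.IsFull() || b.Empty() {
			return a
		}
		if b.IsFull() || a.Empty() {
			return b
		}
		return a // both narrow: callers today just pick one, as CreateUser does
	}
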
diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go
index d6d32e6d..0e1e0cb4 100644
--- a/hscontrol/handlers.go
+++ b/hscontrol/handlers.go
@@ -99,8 +99,17 @@ func (h *Headscale) handleVerifyRequest(
 		return fmt.Errorf("cannot list nodes: %w", err)
 	}
 
+	// Check if any node has the requested NodeKey
+	var nodeKeyFound bool
+	for _, node := range nodes.All() {
+		if node.NodeKey() == derpAdmitClientRequest.NodePublic {
+			nodeKeyFound = true
+			break
+		}
+	}
+
 	resp := &tailcfg.DERPAdmitClientResponse{
-		Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
+		Allow: nodeKeyFound,
 	}
 
 	return json.NewEncoder(writer).Encode(resp)
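The inline loop replaces types.Nodes.ContainsNodeKey, which is not available on views.Slice. If the pattern recurs, it could be hoisted into a small helper; a sketch using only methods shown in this diff (the helper name is made up):

	// containsNodeKey reports whether any node in the slice holds the given
	// Tailscale node key.
	func containsNodeKey(nodes views.Slice[types.NodeView], nk key.NodePublic) bool {
		for _, node := range nodes.All() {
			if node.NodeKey() == nk {
				return true
			}
		}
		return false
	}
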
diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go
index 0a8b544a..9419a008 100644
--- a/hscontrol/mapper/batcher_test.go
+++ b/hscontrol/mapper/batcher_test.go
@@ -1940,7 +1940,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
 		if err != nil {
 			t.Errorf("Error listing peers for node %d: %v", i, err)
 		} else {
-			t.Logf("Node %d should see %d peers from state", i, len(peers))
+			t.Logf("Node %d should see %d peers from state", i, peers.Len())
 		}
 	}
 
@@ -2106,9 +2106,9 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
 			if err != nil {
 				t.Errorf("Error getting peers from state: %v", err)
 			} else {
-				t.Logf("State shows %d peers available for this node", len(peers))
-				if len(peers) > 0 && len(data.Peers) == 0 {
-					t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
+				t.Logf("State shows %d peers available for this node", peers.Len())
+				if peers.Len() > 0 && len(data.Peers) == 0 {
+					t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
 				}
 			}
 		} else {
diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go
index b6102c01..28bca095 100644
--- a/hscontrol/mapper/builder.go
+++ b/hscontrol/mapper/builder.go
@@ -63,9 +63,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
 	_, matchers := b.mapper.state.Filter()
 
 	tailnode, err := tailNode(
-		node.View(), b.capVer, b.mapper.state,
+		node, b.capVer, b.mapper.state,
 		func(id types.NodeID) []netip.Prefix {
-			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
+			return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
 		}, b.mapper.cfg)
 	if err != nil {
@@ -112,7 +112,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
 		return b
 	}
 
-	sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
+	sshPolicy, err := b.mapper.state.SSHPolicy(node)
 	if err != nil {
 		b.addError(err)
 		return b
@@ -135,7 +135,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
 }
 
 // WithUserProfiles adds user profiles for the requesting node and given peers
-func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
+func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	node, err := b.mapper.state.GetNodeByID(b.nodeID)
 	if err != nil {
 		b.addError(err)
@@ -161,14 +161,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
 	// new PacketFilters field and "base" allows us to send a full update when we
 	// have to send an empty list, avoiding the hack in the else block.
 	b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
-		"base": policy.ReduceFilterRules(node.View(), filter),
+		"base": policy.ReduceFilterRules(node, filter),
 	}
 
 	return b
 }
 
 // WithPeers adds full peer list with policy filtering (for full map response)
-func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
+func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
@@ -181,7 +181,7 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
 }
 
 // WithPeerChanges adds changed peers with policy filtering (for incremental updates)
-func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
+func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
 	tailPeers, err := b.buildTailPeers(peers)
 	if err != nil {
@@ -193,8 +193,8 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
 	return b
 }
 
-// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
-func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
+// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
+func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
 	node, err := b.mapper.state.GetNodeByID(b.nodeID)
 	if err != nil {
 		return nil, err
@@ -206,15 +206,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
 	// access each-other at all and remove them from the peers.
 	var changedViews views.Slice[types.NodeView]
 	if len(filter) > 0 {
-		changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
+		changedViews = policy.ReduceNodes(node, peers, matchers)
 	} else {
-		changedViews = peers.ViewSlice()
+		changedViews = peers
 	}
 
 	tailPeers, err := tailNodes(
 		changedViews, b.capVer, b.mapper.state,
 		func(id types.NodeID) []netip.Prefix {
-			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
+			return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
 		}, b.mapper.cfg)
 	if err != nil {
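The builder now speaks views end to end, so a call site can hand the same views.Slice to several steps without conversion. A sketch of a typical chain; the constructor and finaliser names are assumed from the builder's shape and are not shown in this hunk:

	peers, err := m.state.ListPeers(nodeID) // views.Slice[types.NodeView]
	if err != nil {
		return nil, err
	}
	resp, err := m.NewMapResponseBuilder(nodeID). // hypothetical entry point
		WithSelfNode().
		WithDNSConfig().
		WithUserProfiles(peers).
		WithPacketFilters().
		WithPeers(peers).
		Build() // hypothetical finaliser collecting b.addError results
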
diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go
index 43764457..7ffe2ede 100644
--- a/hscontrol/mapper/mapper.go
+++ b/hscontrol/mapper/mapper.go
@@ -18,6 +18,7 @@ import (
 	"tailscale.com/envknob"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/dnstype"
+	"tailscale.com/types/views"
 )
 
 const (
@@ -68,16 +69,18 @@ func newMapper(
 }
 
 func generateUserProfiles(
-	node *types.Node,
-	peers types.Nodes,
+	node types.NodeView,
+	peers views.Slice[types.NodeView],
 ) []tailcfg.UserProfile {
 	userMap := make(map[uint]*types.User)
 	ids := make([]uint, 0, len(userMap))
-	userMap[node.User.ID] = &node.User
-	ids = append(ids, node.User.ID)
-	for _, peer := range peers {
-		userMap[peer.User.ID] = &peer.User
-		ids = append(ids, peer.User.ID)
+	user := node.User()
+	userMap[user.ID] = &user
+	ids = append(ids, user.ID)
+	for _, peer := range peers.All() {
+		peerUser := peer.User()
+		userMap[peerUser.ID] = &peerUser
+		ids = append(ids, peerUser.ID)
 	}
 	slices.Sort(ids)
@@ -94,7 +97,7 @@ func generateUserProfiles(
 
 func generateDNSConfig(
 	cfg *types.Config,
-	node *types.Node,
+	node types.NodeView,
 ) *tailcfg.DNSConfig {
 	if cfg.TailcfgDNSConfig == nil {
 		return nil
@@ -114,12 +117,12 @@ func generateDNSConfig(
 //
 // This will produce a resolver like:
 // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
-func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
+func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
 	for _, resolver := range resolvers {
 		if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
 			attrs := url.Values{
-				"device_name":  []string{node.Hostname},
-				"device_model": []string{node.Hostinfo.OS},
+				"device_name":  []string{node.Hostname()},
+				"device_model": []string{node.Hostinfo().OS()},
 			}
 
 			if len(node.IPs()) > 0 {
@@ -137,7 +140,7 @@ func (m *mapper) fullMapResponse(
 	capVer tailcfg.CapabilityVersion,
 	messages ...string,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.listPeers(nodeID)
+	peers, err := m.state.ListPeers(nodeID)
 	if err != nil {
 		return nil, err
 	}
@@ -182,7 +185,7 @@ func (m *mapper) peerChangeResponse(
 	capVer tailcfg.CapabilityVersion,
 	changedNodeID types.NodeID,
 ) (*tailcfg.MapResponse, error) {
-	peers, err := m.listPeers(nodeID, changedNodeID)
+	peers, err := m.state.ListPeers(nodeID, changedNodeID)
 	if err != nil {
 		return nil, err
 	}
@@ -256,25 +259,6 @@ func writeDebugMapResponse(
 	}
 }
 
-// listPeers returns peers of node, regardless of any Policy or if the node is expired.
-// If no peer IDs are given, all peers are returned.
-// If at least one peer ID is given, only these peer nodes will be returned.
-func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
-	peers, err := m.state.ListPeers(nodeID, peerIDs...)
-	if err != nil {
-		return nil, err
-	}
-
-	// TODO(kradalby): Add back online via batcher. This was removed
-	// to avoid a circular dependency between the mapper and the notification.
-	for _, peer := range peers {
-		online := m.batcher.IsConnected(peer.ID)
-		peer.IsOnline = &online
-	}
-
-	return peers, nil
-}
-
 // routeFilterFunc is a function that takes a node ID and returns a list of
 // netip.Prefixes that are allowed for that node. It is used to filter routes
 // from the primary route manager to the node.
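Removing the mapper's listPeers wrapper also removes the spot where IsOnline was patched onto each peer. In this diff that responsibility moves to the state layer: Connect and Disconnect (in state.go further down) flip the flag in the NodeStore via UpdateNodeImmediate, so the snapshots the mapper reads are already current. A sketch of the resulting flow; the wrapper function names here are illustrative only:

	func onSessionOpen(h *Headscale, node *types.Node) {
		// state.Connect marks the node online in the NodeStore immediately,
		// bypassing write batching, then returns the change to broadcast.
		c := h.state.Connect(node)
		h.Change(c)
	}

	func onSessionClose(h *Headscale, node *types.Node) {
		// state.Disconnect marks the node offline and records LastSeen.
		c, err := h.state.Disconnect(node)
		if err == nil {
			h.Change(c)
		}
	}
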
diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go
index 198ba6c4..b801f7dd 100644
--- a/hscontrol/mapper/mapper_test.go
+++ b/hscontrol/mapper/mapper_test.go
@@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
 				&types.Config{
 					TailcfgDNSConfig: &dnsConfigOrig,
 				},
-				nodeInShared1,
+				nodeInShared1.View(),
 			)
 
 			if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
diff --git a/hscontrol/noise.go b/hscontrol/noise.go
index db39992e..9dd42468 100644
--- a/hscontrol/noise.go
+++ b/hscontrol/noise.go
@@ -296,7 +296,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 // getAndValidateNode retrieves the node from the database using the NodeKey
 // and validates that it matches the MachineKey from the Noise session.
 func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
-	node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
+	nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
 	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
 			return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
@@ -304,8 +304,6 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
 		return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
 	}
 
-	nv := node.View()
-
 	// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
 	if ns.machineKey != nv.MachineKey() {
 		return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go
index 68361cae..2bfd6342 100644
--- a/hscontrol/oidc.go
+++ b/hscontrol/oidc.go
@@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 		return
 	}
 
-	user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
+	user, c, err := a.createOrUpdateUserFromClaim(&claims)
 	if err != nil {
 		log.Error().
 			Err(err).
@@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	}
 
 	// Send policy update notifications if needed
-	if policyChanged {
-		a.h.Change(change.PolicyChange())
-	}
+	a.h.Change(c)
 
 	// TODO(kradalby): Is this comment right?
 	// If the node exists, then the node should be reauthenticated,
@@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
 
 func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	claims *types.OIDCClaims,
-) (*types.User, bool, error) {
+) (*types.User, change.ChangeSet, error) {
 	var user *types.User
 	var err error
 	var newUser bool
-	var policyChanged bool
+	var c change.ChangeSet
 	user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
 	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
-		return nil, false, fmt.Errorf("creating or updating user: %w", err)
+		return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
 	}
 
 	// if the user is still not found, create a new empty user.
@@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	user.FromClaim(claims)
 
 	if newUser {
-		user, policyChanged, err = a.h.state.CreateUser(*user)
+		user, c, err = a.h.state.CreateUser(*user)
 		if err != nil {
-			return nil, false, fmt.Errorf("creating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
 		}
 	} else {
-		_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
+		_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
 			*u = *user
 			return nil
 		})
 		if err != nil {
-			return nil, false, fmt.Errorf("updating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
 		}
 	}
 
-	return user, policyChanged, nil
+	return user, c, nil
 }
 
 func (a *AuthProviderOIDC) handleRegistration(
diff --git a/hscontrol/poll.go b/hscontrol/poll.go
index 1833f060..15de78d3 100644
--- a/hscontrol/poll.go
+++ b/hscontrol/poll.go
@@ -235,6 +235,7 @@ func (m *mapSession) serveLongPoll() {
 				mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
 			}
 			mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
+			m.resetKeepAlive()
 		}
 	}
 }
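The one-line poll.go change makes a successfully written keepalive push back the next scheduled one, avoiding a redundant keepalive right after real traffic. resetKeepAlive itself is not shown in this diff; a plausible shape, assuming the session owns a ticker and an interval (both field names hypothetical):

	func (m *mapSession) resetKeepAlive() {
		// Restart the countdown so the next keepalive fires a full interval
		// after the one we just sent.
		m.keepAliveTicker.Reset(m.keepAliveInterval)
	}
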
diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go
new file mode 100644
index 00000000..ea0f3b61
--- /dev/null
+++ b/hscontrol/state/node_store.go
@@ -0,0 +1,260 @@
+package state
+
+import (
+	"maps"
+	"sync/atomic"
+	"time"
+
+	"github.com/juanfont/headscale/hscontrol/types"
+	"tailscale.com/types/key"
+	"tailscale.com/types/views"
+)
+
+const (
+	batchSize    = 10
+	batchTimeout = 500 * time.Millisecond
+)
+
+const (
+	put    = 1
+	del    = 2
+	update = 3
+)
+
+// NodeStore is a thread-safe store for nodes.
+// It is a copy-on-write structure, replacing the "snapshot"
+// when a change to the structure occurs. It is optimised for reads,
+// and while batches are not fast, they are grouped together
+// to do less of the expensive peer calculation if there are many
+// changes rapidly.
+//
+// Writes will block until committed, while reads are never
+// blocked.
+type NodeStore struct {
+	data atomic.Pointer[Snapshot]
+
+	peersFunc  PeersFunc
+	writeQueue chan work
+	// TODO: metrics
+}
+
+func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
+	nodes := make(map[types.NodeID]types.Node, len(allNodes))
+	for _, n := range allNodes {
+		nodes[n.ID] = *n
+	}
+	snap := snapshotFromNodes(nodes, peersFunc)
+
+	store := &NodeStore{
+		peersFunc: peersFunc,
+	}
+	store.data.Store(&snap)
+
+	return store
+}
+
+type Snapshot struct {
+	// nodesByID is the main source of truth for nodes.
+	nodesByID map[types.NodeID]types.Node
+
+	// calculated from nodesByID
+	nodesByNodeKey map[key.NodePublic]types.NodeView
+	peersByNode    map[types.NodeID][]types.NodeView
+	nodesByUser    map[types.UserID][]types.NodeView
+	allNodes       []types.NodeView
+}
+
+type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
+
+type work struct {
+	op        int
+	nodeID    types.NodeID
+	node      types.Node
+	updateFn  UpdateNodeFunc
+	result    chan struct{}
+	immediate bool // For operations that need immediate processing
+}
+
+// PutNode adds or updates a node in the store.
+// If the node already exists, it will be replaced.
+// If the node does not exist, it will be added.
+// This is a blocking operation that waits for the write to complete.
+func (s *NodeStore) PutNode(n types.Node) {
+	work := work{
+		op:     put,
+		nodeID: n.ID,
+		node:   n,
+		result: make(chan struct{}),
+	}
+
+	s.writeQueue <- work
+	<-work.result
+}
+
+// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
+type UpdateNodeFunc func(n *types.Node)
+
+// UpdateNode applies a function to modify a specific node in the store.
+// This is a blocking operation that waits for the write to complete.
+func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
+	work := work{
+		op:       update,
+		nodeID:   nodeID,
+		updateFn: updateFn,
+		result:   make(chan struct{}),
+	}
+
+	s.writeQueue <- work
+	<-work.result
+}
+
+// UpdateNodeImmediate applies a function to modify a specific node in the store
+// with immediate processing (bypassing normal batching delays).
+// Use this for time-sensitive updates like online status changes.
+func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
+	work := work{
+		op:        update,
+		nodeID:    nodeID,
+		updateFn:  updateFn,
+		result:    make(chan struct{}),
+		immediate: true,
+	}
+
+	s.writeQueue <- work
+	<-work.result
+}
+
+// DeleteNode removes a node from the store by its ID.
+// This is a blocking operation that waits for the write to complete.
+func (s *NodeStore) DeleteNode(id types.NodeID) {
+	work := work{
+		op:     del,
+		nodeID: id,
+		result: make(chan struct{}),
+	}
+
+	s.writeQueue <- work
+	<-work.result
+}
+
+func (s *NodeStore) Start() {
+	s.writeQueue = make(chan work)
+	go s.processWrite()
+}
+
+func (s *NodeStore) Stop() {
+	close(s.writeQueue)
+}
+
+func (s *NodeStore) processWrite() {
+	c := time.NewTicker(batchTimeout)
+	batch := make([]work, 0, batchSize)
+
+	for {
+		select {
+		case w, ok := <-s.writeQueue:
+			if !ok {
+				c.Stop()
+				return
+			}
+
+			// Handle immediate operations right away
+			if w.immediate {
+				s.applyBatch([]work{w})
+				continue
+			}
+
+			batch = append(batch, w)
+			if len(batch) >= batchSize {
+				s.applyBatch(batch)
+				batch = batch[:0]
+				c.Reset(batchTimeout)
+			}
+
+		case <-c.C:
+			if len(batch) != 0 {
+				s.applyBatch(batch)
+				batch = batch[:0]
+			}
+			c.Reset(batchTimeout)
+		}
+	}
+}
+
+func (s *NodeStore) applyBatch(batch []work) {
+	nodes := make(map[types.NodeID]types.Node)
+	maps.Copy(nodes, s.data.Load().nodesByID)
+
+	for _, w := range batch {
+		switch w.op {
+		case put:
+			nodes[w.nodeID] = w.node
+		case update:
+			// Update the specific node identified by nodeID
+			if n, exists := nodes[w.nodeID]; exists {
+				w.updateFn(&n)
+				nodes[w.nodeID] = n
+			}
+		case del:
+			delete(nodes, w.nodeID)
+		}
+	}
+
+	newSnap := snapshotFromNodes(nodes, s.peersFunc)
+
+	s.data.Store(&newSnap)
+
+	for _, w := range batch {
+		close(w.result)
+	}
+}
+
+func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
+	allNodes := make([]types.NodeView, 0, len(nodes))
+	for _, n := range nodes {
+		allNodes = append(allNodes, n.View())
+	}
+
+	newSnap := Snapshot{
+		nodesByID:      nodes,
+		allNodes:       allNodes,
+		nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
+		peersByNode:    peersFunc(allNodes),
+		nodesByUser:    make(map[types.UserID][]types.NodeView),
+	}
+
+	// Build nodesByUser and nodesByNodeKey maps
+	for _, n := range nodes {
+		nodeView := n.View()
+		newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
+		newSnap.nodesByNodeKey[n.NodeKey] = nodeView
+	}
+
+	return newSnap
+}
+
+// GetNode retrieves a node by its ID.
+func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
+	n := s.data.Load().nodesByID[id]
+	return n.View()
+}
+
+// GetNodeByNodeKey retrieves a node by its NodeKey.
+func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
+	return s.data.Load().nodesByNodeKey[nodeKey]
+}
+
+// ListNodes returns a slice of all nodes in the store.
+func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
+	return views.SliceOf(s.data.Load().allNodes)
+}
+
+// ListPeers returns a slice of all peers for a given node ID.
+func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
+	return views.SliceOf(s.data.Load().peersByNode[id])
+}
+
+// ListNodesByUser returns a slice of all nodes for a given user ID.
+func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
+	return views.SliceOf(s.data.Load().nodesByUser[uid])
+}
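Putting node_store.go together: reads are a single atomic pointer load of an immutable snapshot, while every write round-trips through the queue and blocks until its batch is applied. A small usage sketch against the API above; allowAllPeersFunc is the helper defined in the tests below:

	store := NewNodeStore(nil, allowAllPeersFunc)
	store.Start()
	defer store.Stop()

	// Writes block until the snapshot containing them is published.
	store.PutNode(types.Node{ID: 1, Hostname: "node1"})
	store.UpdateNode(1, func(n *types.Node) { n.GivenName = "renamed-node1" })

	// Reads never block; they walk whatever snapshot is current.
	for _, nv := range store.ListNodes().All() {
		fmt.Println(nv.ID(), nv.GivenName())
	}
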
diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go
new file mode 100644
index 00000000..7af07b38
--- /dev/null
+++ b/hscontrol/state/node_store_test.go
@@ -0,0 +1,494 @@
+package state
+
+import (
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"tailscale.com/types/key"
+)
+
+func TestSnapshotFromNodes(t *testing.T) {
+	tests := []struct {
+		name      string
+		setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
+		validate  func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
+	}{
+		{
+			name: "empty nodes",
+			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
+				nodes := make(map[types.NodeID]types.Node)
+				peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
+					return make(map[types.NodeID][]types.NodeView)
+				}
+				return nodes, peersFunc
+			},
+			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
+				assert.Empty(t, snapshot.nodesByID)
+				assert.Empty(t, snapshot.allNodes)
+				assert.Empty(t, snapshot.peersByNode)
+				assert.Empty(t, snapshot.nodesByUser)
+			},
+		},
+		{
+			name: "single node",
+			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
+				nodes := map[types.NodeID]types.Node{
+					1: createTestNode(1, 1, "user1", "node1"),
+				}
+				return nodes, allowAllPeersFunc
+			},
+			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
+				assert.Len(t, snapshot.nodesByID, 1)
+				assert.Len(t, snapshot.allNodes, 1)
+				assert.Len(t, snapshot.peersByNode, 1)
+				assert.Len(t, snapshot.nodesByUser, 1)
+
+				require.Contains(t, snapshot.nodesByID, types.NodeID(1))
+				assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
+				assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
+				assert.Len(t, snapshot.nodesByUser[1], 1)
+				assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
+			},
+		},
+		{
+			name: "multiple nodes same user",
+			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
+				nodes := map[types.NodeID]types.Node{
+					1: createTestNode(1, 1, "user1", "node1"),
+					2: createTestNode(2, 1, "user1", "node2"),
+				}
+				return nodes, allowAllPeersFunc
+			},
+			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
+				assert.Len(t, snapshot.nodesByID, 2)
+				assert.Len(t, snapshot.allNodes, 2)
+				assert.Len(t, snapshot.peersByNode, 2)
+				assert.Len(t, snapshot.nodesByUser, 1)
+
+				// Each node sees the other as peer (but not itself)
+				assert.Len(t, snapshot.peersByNode[1], 1)
+				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
+				assert.Len(t, snapshot.peersByNode[2], 1)
+				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
+				assert.Len(t, snapshot.nodesByUser[1], 2)
+			},
+		},
+		{
+			name: "multiple nodes different users",
+			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
+				nodes := map[types.NodeID]types.Node{
+					1: createTestNode(1, 1, "user1", "node1"),
+					2: createTestNode(2, 2, "user2", "node2"),
+					3: createTestNode(3, 1, "user1", "node3"),
+				}
+				return nodes, allowAllPeersFunc
+			},
+			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
+				assert.Len(t, snapshot.nodesByID, 3)
+				assert.Len(t, snapshot.allNodes, 3)
+				assert.Len(t, snapshot.peersByNode, 3)
+				assert.Len(t, snapshot.nodesByUser, 2)
+
+				// Each node should have 2 peers (all others, but not itself)
+				assert.Len(t, snapshot.peersByNode[1], 2)
+				assert.Len(t, snapshot.peersByNode[2], 2)
+				assert.Len(t, snapshot.peersByNode[3], 2)
+
+				// User groupings
+				assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
+				assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
+			},
+		},
+		{
+			name: "odd-even peers filtering",
+			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
+				nodes := map[types.NodeID]types.Node{
+					1: createTestNode(1, 1, "user1", "node1"),
+					2: createTestNode(2, 2, "user2", "node2"),
+					3: createTestNode(3, 3, "user3", "node3"),
+					4: createTestNode(4, 4, "user4", "node4"),
+				}
+				peersFunc := oddEvenPeersFunc
+				return nodes, peersFunc
+			},
+			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
+				assert.Len(t, snapshot.nodesByID, 4)
+				assert.Len(t, snapshot.allNodes, 4)
+				assert.Len(t, snapshot.peersByNode, 4)
+				assert.Len(t, snapshot.nodesByUser, 4)
+
+				// Odd nodes should only see other odd nodes as peers
+				require.Len(t, snapshot.peersByNode[1], 1)
+				assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
+
+				require.Len(t, snapshot.peersByNode[3], 1)
+				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
+
+				// Even nodes should only see other even nodes as peers
+				require.Len(t, snapshot.peersByNode[2], 1)
+				assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
+
+				require.Len(t, snapshot.peersByNode[4], 1)
+				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			nodes, peersFunc := tt.setupFunc()
+			snapshot := snapshotFromNodes(nodes, peersFunc)
+			tt.validate(t, nodes, snapshot)
+		})
+	}
+}
+
+// Helper functions
+
+func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
+	now := time.Now()
+	machineKey := key.NewMachine()
+	nodeKey := key.NewNode()
+	discoKey := key.NewDisco()
+
+	ipv4 := netip.MustParseAddr("100.64.0.1")
+	ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
+
+	return types.Node{
+		ID:         nodeID,
+		MachineKey: machineKey.Public(),
+		NodeKey:    nodeKey.Public(),
+		DiscoKey:   discoKey.Public(),
+		Hostname:   hostname,
+		GivenName:  hostname,
+		UserID:     userID,
+		User: types.User{
+			Name:        username,
+			DisplayName: username,
+		},
+		RegisterMethod: "test",
+		IPv4:           &ipv4,
+		IPv6:           &ipv6,
+		CreatedAt:      now,
+		UpdatedAt:      now,
+	}
+}
+
+// Peer functions
+
+func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
+	ret := make(map[types.NodeID][]types.NodeView, len(nodes))
+	for _, node := range nodes {
+		var peers []types.NodeView
+		for _, n := range nodes {
+			if n.ID() != node.ID() {
+				peers = append(peers, n)
+			}
+		}
+		ret[node.ID()] = peers
+	}
+	return ret
+}
+
+func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
+	ret := make(map[types.NodeID][]types.NodeView, len(nodes))
+	for _, node := range nodes {
+		var peers []types.NodeView
+		nodeIsOdd := node.ID()%2 == 1
+
+		for _, n := range nodes {
+			if n.ID() == node.ID() {
+				continue
+			}
+
+			peerIsOdd := n.ID()%2 == 1
+
+			// Only add peer if both are odd or both are even
+			if nodeIsOdd == peerIsOdd {
+				peers = append(peers, n)
+			}
+		}
+		ret[node.ID()] = peers
+	}
+	return ret
+}
+
+func TestNodeStoreOperations(t *testing.T) {
+	tests := []struct {
+		name      string
+		setupFunc func(t *testing.T) *NodeStore
+		steps     []testStep
+	}{
+		{
+			name: "create empty store and add single node",
+			setupFunc: func(t *testing.T) *NodeStore {
+				return NewNodeStore(nil, allowAllPeersFunc)
+			},
+			steps: []testStep{
+				{
+					name: "verify empty store",
+					action: func(store *NodeStore) {
+						snapshot := store.data.Load()
+						assert.Empty(t, snapshot.nodesByID)
+						assert.Empty(t, snapshot.allNodes)
+						assert.Empty(t, snapshot.peersByNode)
+						assert.Empty(t, snapshot.nodesByUser)
+					},
+				},
+				{
+					name: "add first node",
+					action: func(store *NodeStore) {
+						node := createTestNode(1, 1, "user1", "node1")
+						store.PutNode(node)
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 1)
+						assert.Len(t, snapshot.allNodes, 1)
+						assert.Len(t, snapshot.peersByNode, 1)
+						assert.Len(t, snapshot.nodesByUser, 1)
+
+						require.Contains(t, snapshot.nodesByID, types.NodeID(1))
+						assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
+						assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
+						assert.Len(t, snapshot.nodesByUser[1], 1)
+					},
+				},
+			},
+		},
+		{
+			name: "create store with initial node and add more",
+			setupFunc: func(t *testing.T) *NodeStore {
+				node1 := createTestNode(1, 1, "user1", "node1")
+				initialNodes := types.Nodes{&node1}
+				return NewNodeStore(initialNodes, allowAllPeersFunc)
+			},
+			steps: []testStep{
+				{
+					name: "verify initial state",
+					action: func(store *NodeStore) {
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 1)
+						assert.Len(t, snapshot.allNodes, 1)
+						assert.Len(t, snapshot.peersByNode, 1)
+						assert.Len(t, snapshot.nodesByUser, 1)
+						assert.Empty(t, snapshot.peersByNode[1])
+					},
+				},
+				{
+					name: "add second node same user",
+					action: func(store *NodeStore) {
+						node2 := createTestNode(2, 1, "user1", "node2")
+						store.PutNode(node2)
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 2)
+						assert.Len(t, snapshot.allNodes, 2)
+						assert.Len(t, snapshot.peersByNode, 2)
+						assert.Len(t, snapshot.nodesByUser, 1)
+
+						// Now both nodes should see each other as peers
+						assert.Len(t, snapshot.peersByNode[1], 1)
+						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
+						assert.Len(t, snapshot.peersByNode[2], 1)
+						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
+						assert.Len(t, snapshot.nodesByUser[1], 2)
+					},
+				},
+				{
+					name: "add third node different user",
+					action: func(store *NodeStore) {
+						node3 := createTestNode(3, 2, "user2", "node3")
+						store.PutNode(node3)
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 3)
+						assert.Len(t, snapshot.allNodes, 3)
+						assert.Len(t, snapshot.peersByNode, 3)
+						assert.Len(t, snapshot.nodesByUser, 2)
+
+						// All nodes should see the other 2 as peers
+						assert.Len(t, snapshot.peersByNode[1], 2)
+						assert.Len(t, snapshot.peersByNode[2], 2)
+						assert.Len(t, snapshot.peersByNode[3], 2)
+
+						// User groupings
+						assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
+						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
+					},
+				},
+			},
+		},
+		{
+			name: "test node deletion",
+			setupFunc: func(t *testing.T) *NodeStore {
+				node1 := createTestNode(1, 1, "user1", "node1")
+				node2 := createTestNode(2, 1, "user1", "node2")
+				node3 := createTestNode(3, 2, "user2", "node3")
+				initialNodes := types.Nodes{&node1, &node2, &node3}
+				return NewNodeStore(initialNodes, allowAllPeersFunc)
+			},
+			steps: []testStep{
+				{
+					name: "verify initial 3 nodes",
+					action: func(store *NodeStore) {
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 3)
+						assert.Len(t, snapshot.allNodes, 3)
+						assert.Len(t, snapshot.peersByNode, 3)
+						assert.Len(t, snapshot.nodesByUser, 2)
+					},
+				},
+				{
+					name: "delete middle node",
+					action: func(store *NodeStore) {
+						store.DeleteNode(2)
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 2)
+						assert.Len(t, snapshot.allNodes, 2)
+						assert.Len(t, snapshot.peersByNode, 2)
+						assert.Len(t, snapshot.nodesByUser, 2)
+
+						// Node 2 should be gone
+						assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
+
+						// Remaining nodes should see each other as peers
+						assert.Len(t, snapshot.peersByNode[1], 1)
+						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
+						assert.Len(t, snapshot.peersByNode[3], 1)
+						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
+
+						// User groupings updated
+						assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
+						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
+					},
+				},
+				{
+					name: "delete all remaining nodes",
+					action: func(store *NodeStore) {
+						store.DeleteNode(1)
+						store.DeleteNode(3)
+
+						snapshot := store.data.Load()
+						assert.Empty(t, snapshot.nodesByID)
+						assert.Empty(t, snapshot.allNodes)
+						assert.Empty(t, snapshot.peersByNode)
+						assert.Empty(t, snapshot.nodesByUser)
+					},
+				},
+			},
+		},
+		{
+			name: "test node updates",
+			setupFunc: func(t *testing.T) *NodeStore {
+				node1 := createTestNode(1, 1, "user1", "node1")
+				node2 := createTestNode(2, 1, "user1", "node2")
+				initialNodes := types.Nodes{&node1, &node2}
+				return NewNodeStore(initialNodes, allowAllPeersFunc)
+			},
+			steps: []testStep{
+				{
+					name: "verify initial hostnames",
+					action: func(store *NodeStore) {
+						snapshot := store.data.Load()
+						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
+						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
+					},
+				},
+				{
+					name: "update node hostname",
+					action: func(store *NodeStore) {
+						store.UpdateNode(1, func(n *types.Node) {
+							n.Hostname = "updated-node1"
+							n.GivenName = "updated-node1"
+						})
+
+						snapshot := store.data.Load()
+						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
+						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
+						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
+
+						// Peers should still work correctly
+						assert.Len(t, snapshot.peersByNode[1], 1)
+						assert.Len(t, snapshot.peersByNode[2], 1)
+					},
+				},
+			},
+		},
+		{
+			name: "test with odd-even peers filtering",
+			setupFunc: func(t *testing.T) *NodeStore {
+				return NewNodeStore(nil, oddEvenPeersFunc)
+			},
+			steps: []testStep{
+				{
+					name: "add nodes with odd-even filtering",
+					action: func(store *NodeStore) {
+						// Add nodes in sequence
+						store.PutNode(createTestNode(1, 1, "user1", "node1"))
+						store.PutNode(createTestNode(2, 2, "user2", "node2"))
+						store.PutNode(createTestNode(3, 3, "user3", "node3"))
+						store.PutNode(createTestNode(4, 4, "user4", "node4"))
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 4)
+
+						// Verify odd-even peer relationships
+						require.Len(t, snapshot.peersByNode[1], 1)
+						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
+
+						require.Len(t, snapshot.peersByNode[2], 1)
+						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
+
+						require.Len(t, snapshot.peersByNode[3], 1)
+						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
+
+						require.Len(t, snapshot.peersByNode[4], 1)
+						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
+					},
+				},
+				{
+					name: "delete odd node and verify even nodes unaffected",
+					action: func(store *NodeStore) {
+						store.DeleteNode(1)
+
+						snapshot := store.data.Load()
+						assert.Len(t, snapshot.nodesByID, 3)
+
+						// Node 3 (odd) should now have no peers
+						assert.Empty(t, snapshot.peersByNode[3])
+
+						// Even nodes should still see each other
+						require.Len(t, snapshot.peersByNode[2], 1)
+						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
+						require.Len(t, snapshot.peersByNode[4], 1)
+						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
+					},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			store := tt.setupFunc(t)
+			store.Start()
+			defer store.Stop()
+
+			for _, step := range tt.steps {
+				t.Run(step.name, func(t *testing.T) {
+					step.action(store)
+				})
+			}
+		})
+	}
+}
+
+type testStep struct {
+	name   string
+	action func(store *NodeStore)
+}
diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go
index 02d5d3cd..1768ea97 100644
--- a/hscontrol/state/state.go
+++ b/hscontrol/state/state.go
@@ -27,6 +27,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/ptr"
+	"tailscale.com/types/views"
 	zcache "zgo.at/zcache/v2"
 )
@@ -49,6 +50,9 @@ type State struct {
 	// cfg holds the current Headscale configuration
 	cfg *types.Config
 
+	// nodeStore provides an in-memory cache for nodes.
+	nodeStore *NodeStore
+
 	// subsystem keeping state
 	// db provides persistent storage and database operations
 	db *hsdb.HSDatabase
@@ -107,6 +111,12 @@ func NewState(cfg *types.Config) (*State, error) {
 		return nil, fmt.Errorf("init policy manager: %w", err)
 	}
 
+	nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
+		_, matchers := polMan.Filter()
+		return policy.BuildPeerMap(views.SliceOf(nodes), matchers)
+	})
+	nodeStore.Start()
+
 	return &State{
 		cfg: cfg,
 
@@ -117,11 +127,14 @@ func NewState(cfg *types.Config) (*State, error) {
 		polMan:            polMan,
 		registrationCache: registrationCache,
 		primaryRoutes:     routes.New(),
+		nodeStore:         nodeStore,
 	}, nil
 }
 
 // Close gracefully shuts down the State instance and releases all resources.
 func (s *State) Close() error {
+	s.nodeStore.Stop()
+
 	if err := s.db.Close(); err != nil {
 		return fmt.Errorf("closing database: %w", err)
 	}
@@ -204,43 +217,42 @@ func (s *State) AutoApproveNodes() error {
 }
 
 // CreateUser creates a new user and updates the policy manager.
-// Returns the created user, whether policies changed, and any error.
-func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
+// Returns the created user, change set, and any error.
+func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
 	if err := s.db.DB.Save(&user).Error; err != nil {
-		return nil, false, fmt.Errorf("creating user: %w", err)
+		return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
 	}
 
 	// Check if policy manager needs updating
-	policyChanged, err := s.updatePolicyManagerUsers()
+	c, err := s.updatePolicyManagerUsers()
 	if err != nil {
 		// Log the error but don't fail the user creation
-		return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
+		return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err)
 	}
 
 	// Even if the policy manager doesn't detect a filter change, SSH policies
 	// might now be resolvable when they weren't before. If there are existing
 	// nodes, we should send a policy change to ensure they get updated SSH policies.
-	if !policyChanged {
+	if c.Empty() {
 		nodes, err := s.ListNodes()
-		if err == nil && len(nodes) > 0 {
-			policyChanged = true
+		if err == nil && nodes.Len() > 0 {
+			c = change.PolicyChange()
 		}
 	}
 
-	log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
+	log.Info().Str("user", user.Name).Bool("policyChanged", !c.Empty()).Msg("User created, policy manager updated")
 
 	// TODO(kradalby): implement the user in-memory cache
 
-	return &user, policyChanged, nil
+	return &user, c, nil
 }
 
 // UpdateUser modifies an existing user using the provided update function within a transaction.
-// Returns the updated user, whether policies changed, and any error.
-func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) {
+// Returns the updated user, change set, and any error.
+func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -261,18 +273,18 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
 		return user, nil
 	})
 	if err != nil {
-		return nil, false, err
+		return nil, change.EmptySet, err
 	}
 
 	// Check if policy manager needs updating
-	policyChanged, err := s.updatePolicyManagerUsers()
+	c, err := s.updatePolicyManagerUsers()
 	if err != nil {
-		return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err)
+		return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err)
 	}
 
 	// TODO(kradalby): implement the user in-memory cache
 
-	return user, policyChanged, nil
+	return user, c, nil
 }
 
 // DeleteUser permanently removes a user and all associated data (nodes, API keys, etc).
@@ -282,7 +294,7 @@ func (s *State) DeleteUser(userID types.UserID) error {
 }
 
 // RenameUser changes a user's name. The new name must be unique.
-func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) {
+func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) {
 	return s.UpdateUser(userID, func(user *types.User) error {
 		user.Name = newName
 		return nil
@@ -315,28 +327,31 @@ func (s *State) ListAllUsers() ([]types.User, error) {
 }
 
 // CreateNode creates a new node and updates the policy manager.
-// Returns the created node, whether policies changed, and any error.
-func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+// Returns the created node, change set, and any error.
+func (s *State) CreateNode(node *types.Node) (types.NodeView, change.ChangeSet, error) {
+	s.nodeStore.PutNode(*node)
 
 	if err := s.db.DB.Save(node).Error; err != nil {
-		return nil, false, fmt.Errorf("creating node: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("creating node: %w", err)
 	}
 
 	// Check if policy manager needs updating
-	policyChanged, err := s.updatePolicyManagerNodes()
+	c, err := s.updatePolicyManagerNodes()
 	if err != nil {
-		return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err)
+		return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node creation: %w", err)
 	}
 
 	// TODO(kradalby): implement the node in-memory cache
 
-	return node, policyChanged, nil
+	if c.Empty() {
+		c = change.NodeAdded(node.ID)
+	}
+
+	return node.View(), c, nil
}
 
 // updateNodeTx performs a database transaction to update a node and refresh the policy manager.
-func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
+func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (types.NodeView, change.ChangeSet, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -357,70 +372,72 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
 		return node, nil
 	})
 	if err != nil {
-		return nil, change.EmptySet, err
+		return types.NodeView{}, change.EmptySet, err
 	}
 
 	// Check if policy manager needs updating
-	policyChanged, err := s.updatePolicyManagerNodes()
+	c, err := s.updatePolicyManagerNodes()
 	if err != nil {
-		return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
+		return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
 	}
 
 	// TODO(kradalby): implement the node in-memory cache
 
-	var c change.ChangeSet
-	if policyChanged {
-		c = change.PolicyChange()
-	} else {
+	if c.Empty() {
 		// Basic node change without specific details since this is a generic update
 		c = change.NodeAdded(node.ID)
 	}
 
+	return node.View(), c, nil
+}
+
+// SaveNode persists an existing node to the database and updates the policy manager.
+func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	nodePtr := node.AsStruct()
+	s.nodeStore.PutNode(*nodePtr)
+
+	if err := s.db.DB.Save(nodePtr).Error; err != nil {
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err)
+	}
+
+	// Check if policy manager needs updating
+	c, err := s.updatePolicyManagerNodes()
+	if err != nil {
+		return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
+	}
+
+	// TODO(kradalby): implement the node in-memory cache
+
+	if c.Empty() {
+		c = change.NodeAdded(node.ID())
+	}
+
 	return node, c, nil
 }
 
-// SaveNode persists an existing node to the database and updates the policy manager.
-func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	if err := s.db.DB.Save(node).Error; err != nil {
-		return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
-	}
-
-	// Check if policy manager needs updating
-	policyChanged, err := s.updatePolicyManagerNodes()
-	if err != nil {
-		return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
-	}
-
-	// TODO(kradalby): implement the node in-memory cache
-
-	if policyChanged {
-		return node, change.PolicyChange(), nil
-	}
-
-	return node, change.EmptySet, nil
-}
-
 // DeleteNode permanently removes a node and cleans up associated resources.
 // Returns whether policies changed and any error. This operation is irreversible.
-func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
-	err := s.db.DeleteNode(node)
+func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
+	s.nodeStore.DeleteNode(node.ID())
+
+	err := s.db.DeleteNode(node.AsStruct())
 	if err != nil {
 		return change.EmptySet, err
 	}
 
-	c := change.NodeRemoved(node.ID)
+	c := change.NodeRemoved(node.ID())
 
 	// Check if policy manager needs updating after node deletion
-	policyChanged, err := s.updatePolicyManagerNodes()
+	policyChange, err := s.updatePolicyManagerNodes()
 	if err != nil {
 		return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
 	}
 
-	if policyChanged {
-		c = change.PolicyChange()
+	if !policyChange.Empty() {
+		c = policyChange
 	}
 
 	return c, nil
@@ -434,12 +451,26 @@ func (s *State) Connect(node *types.Node) change.ChangeSet {
 		c = change.NodeAdded(node.ID)
 	}
 
+	// Update nodestore with online status - node is connecting so it's online
+	// Use immediate update to ensure online status changes are not delayed by batching
+	s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
+		// Set the online status in the node's ephemeral field
+		n.IsOnline = ptr.To(true)
+	})
+
 	return c
 }
 
 func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
 	c := change.NodeOffline(node.ID)
 
+	// Update nodestore with offline status
+	// Use immediate update to ensure online status changes are not delayed by batching
+	s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
+		// Set the online status to false in the node's ephemeral field
+		n.IsOnline = ptr.To(false)
+	})
+
 	_, _, err := s.SetLastSeen(node.ID, time.Now())
 	if err != nil {
 		return c, fmt.Errorf("disconnecting node: %w", err)
@@ -454,70 +485,95 @@
 }
 
 // GetNodeByID retrieves a node by ID.
-func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) {
-	return s.db.GetNodeByID(nodeID)
-}
-
-// GetNodeViewByID retrieves a node view by ID.
-func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
-	node, err := s.db.GetNodeByID(nodeID)
-	if err != nil {
-		return types.NodeView{}, err
-	}
-
-	return node.View(), nil
+func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, error) {
+	return s.nodeStore.GetNode(nodeID), nil
 }
 
 // GetNodeByNodeKey retrieves a node by its Tailscale public key.
-func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
-	return s.db.GetNodeByNodeKey(nodeKey)
-}
-
-// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key.
-func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
-	node, err := s.db.GetNodeByNodeKey(nodeKey)
-	if err != nil {
-		return types.NodeView{}, err
-	}
-
-	return node.View(), nil
+func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
+	return s.nodeStore.GetNodeByNodeKey(nodeKey), nil
 }
 
 // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
-func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
+func (s *State) ListNodes(nodeIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
 	if len(nodeIDs) == 0 {
-		return s.db.ListNodes()
+		return s.nodeStore.ListNodes(), nil
 	}
 
-	return s.db.ListNodes(nodeIDs...)
+	// Filter nodes by the requested IDs
+	allNodes := s.nodeStore.ListNodes()
+	nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs))
+	for _, id := range nodeIDs {
+		nodeIDSet[id] = struct{}{}
+	}
+
+	var filteredNodes []types.NodeView
+	for _, node := range allNodes.All() {
+		if _, exists := nodeIDSet[node.ID()]; exists {
+			filteredNodes = append(filteredNodes, node)
+		}
+	}
+
+	return views.SliceOf(filteredNodes), nil
 }
 
 // ListNodesByUser retrieves all nodes belonging to a specific user.
-func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) {
-	return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
-		return hsdb.ListNodesByUser(rx, userID)
-	})
+func (s *State) ListNodesByUser(userID types.UserID) (views.Slice[types.NodeView], error) {
+	return s.nodeStore.ListNodesByUser(userID), nil
 }
 
 // ListPeers retrieves nodes that can communicate with the specified node based on policy.
-func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
-	return s.db.ListPeers(nodeID, peerIDs...)
+func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
+	if len(peerIDs) == 0 {
+		return s.nodeStore.ListPeers(nodeID), nil
+	}
+
+	// For specific peerIDs, filter from all nodes
+	allNodes := s.nodeStore.ListNodes()
+	nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs))
+	for _, id := range peerIDs {
+		nodeIDSet[id] = struct{}{}
+	}
+
+	var filteredNodes []types.NodeView
+	for _, node := range allNodes.All() {
+		if _, exists := nodeIDSet[node.ID()]; exists {
+			filteredNodes = append(filteredNodes, node)
+		}
+	}
+
+	return views.SliceOf(filteredNodes), nil
 }
 
 // ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system.
-func (s *State) ListEphemeralNodes() (types.Nodes, error) {
-	return s.db.ListEphemeralNodes()
+func (s *State) ListEphemeralNodes() (views.Slice[types.NodeView], error) {
+	allNodes := s.nodeStore.ListNodes()
+	var ephemeralNodes []types.NodeView
+
+	for _, node := range allNodes.All() {
+		// Check if node is ephemeral by checking its AuthKey
+		if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
+			ephemeralNodes = append(ephemeralNodes, node)
+		}
+	}
+
+	return views.SliceOf(ephemeralNodes), nil
 }
 
 // SetNodeExpiry updates the expiration time for a node.
-func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
+func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) {
 	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
 		return hsdb.NodeSetExpiry(tx, nodeID, expiry)
 	})
 	if err != nil {
-		return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
 	}
 
+	// Update nodestore with the same change
+	s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+		n.Expiry = &expiry
+	})
+
 	if !c.IsFull() {
 		c = change.KeyExpiry(nodeID)
 	}
@@ -526,14 +582,19 @@
 }
 
 // SetNodeTags assigns tags to a node for use in access control policies.
-func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
+func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
 	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
 		return hsdb.SetTags(tx, nodeID, tags)
 	})
 	if err != nil {
-		return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
 	}
 
+	// Update nodestore with the same change
+	s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+		n.ForcedTags = tags
+	})
+
 	if !c.IsFull() {
 		c = change.NodeAdded(nodeID)
 	}
@@ -542,16 +603,21 @@
 }
 
 // SetApprovedRoutes sets the network routes that a node is approved to advertise.
-func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
+func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) {
 	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
 		return hsdb.SetApprovedRoutes(tx, nodeID, routes)
 	})
 	if err != nil {
-		return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
 	}
 
+	// Update nodestore with the same change
+	s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+		n.ApprovedRoutes = routes
+	})
+
 	// Update primary routes after changing approved routes
-	routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
+	routeChange := s.primaryRoutes.SetRoutes(nodeID, n.AsStruct().SubnetRoutes()...)
 
 	if routeChange || !c.IsFull() {
 		c = change.PolicyChange()
@@ -561,14 +627,19 @@
 }
 
 // RenameNode changes the display name of a node.
-func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
+func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
 	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
 		return hsdb.RenameNode(tx, nodeID, newName)
 	})
 	if err != nil {
-		return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
 	}
 
+	// Update nodestore with the same change
+	s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+		n.GivenName = newName
+	})
+
 	if !c.IsFull() {
 		c = change.NodeAdded(nodeID)
 	}
@@ -577,21 +648,19 @@
 }
 
 // SetLastSeen updates when a node was last seen, used for connectivity monitoring.
-func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
-	return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
+func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (types.NodeView, change.ChangeSet, error) {
+	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
 		return hsdb.SetLastSeen(tx, nodeID, lastSeen)
 	})
-}
-
-// AssignNodeToUser transfers a node to a different user.
-func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
-	n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
-		return hsdb.AssignNodeToUser(tx, nodeID, userID)
-	})
 	if err != nil {
-		return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
+		return types.NodeView{}, change.EmptySet, fmt.Errorf("setting last seen: %w", err)
 	}
 
+	// Update nodestore with the same change
+	s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+		n.LastSeen = &lastSeen
+	})
+
 	if !c.IsFull() {
 		c = change.NodeAdded(nodeID)
 	}
@@ -599,6 +668,32 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*typ
 	return n, c, nil
 }
 
+// AssignNodeToUser transfers a node to a different user.
+func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) {
+	node, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
+		return hsdb.AssignNodeToUser(tx, nodeID, userID)
+	})
+	if err != nil {
+		return types.NodeView{}, change.EmptySet, err
+	}
+
+	// Update nodestore with the same change
+	// Get the updated user information from the database
+	user, err := s.GetUserByID(userID)
+	if err == nil {
+		s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
+			n.UserID = uint(userID)
+			n.User = *user
+		})
+	}
+
+	if !c.IsFull() {
+		c = change.NodeAdded(nodeID)
+	}
+
+	return node, c, nil
+}
+
 // BackfillNodeIPs assigns IP addresses to nodes that don't have them.
 func (s *State) BackfillNodeIPs() ([]string, error) {
 	return s.db.BackfillNodeIPs(s.ipAlloc)
@@ -631,8 +726,16 @@ func (s *State) SetPolicy(pol []byte) (bool, error) {
 }
 
 // AutoApproveRoutes checks if a node's routes should be auto-approved.
-func (s *State) AutoApproveRoutes(node *types.Node) bool {
-	return policy.AutoApproveRoutes(s.polMan, node)
+func (s *State) AutoApproveRoutes(node types.NodeView) bool {
+	nodePtr := node.AsStruct()
+	changed := policy.AutoApproveRoutes(s.polMan, nodePtr)
+	if changed {
+		s.nodeStore.PutNode(*nodePtr)
+		// Update primaryRoutes manager with the newly approved routes
+		// This is essential for actual packet forwarding to work
+		s.primaryRoutes.SetRoutes(nodePtr.ID, nodePtr.SubnetRoutes()...)
+	}
+	return changed
 }
 
 // PolicyDebugString returns a debug representation of the current policy.
@@ -742,34 +845,42 @@ func (s *State) HandleNodeFromAuthPath(
 	userID types.UserID,
 	expiry *time.Time,
 	registrationMethod string,
-) (*types.Node, change.ChangeSet, error) {
+) (types.NodeView, change.ChangeSet, error) {
 	ipv4, ipv6, err := s.ipAlloc.Next()
 	if err != nil {
-		return nil, change.EmptySet, err
+		return types.NodeView{}, change.EmptySet, err
 	}
 
-	return s.db.HandleNodeFromAuthPath(
+	node, nodeChange, err := s.db.HandleNodeFromAuthPath(
 		registrationID,
 		userID,
 		expiry,
 		util.RegisterMethodOIDC,
 		ipv4, ipv6,
 	)
+	if err != nil {
+		return types.NodeView{}, change.EmptySet, err
+	}
+
+	// Update nodestore with the newly registered/updated node
+	s.nodeStore.PutNode(*node)
+
+	return node.View(), nodeChange, nil
 }
 
 // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) (*types.Node, change.ChangeSet, bool, error) { +) (types.NodeView, change.ChangeSet, error) { pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } err = pak.Validate() if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } nodeToRegister := types.Node{ @@ -796,7 +907,7 @@ func (s *State) HandleNodeFromPreAuthKey( ipv4, ipv6, err := s.ipAlloc.Next() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) } node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { @@ -818,38 +929,44 @@ func (s *State) HandleNodeFromPreAuthKey( return node, nil }) if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) } // Check if this is a logout request for an ephemeral node if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { // This is a logout request for an ephemeral node, delete it immediately - c, err := s.DeleteNode(node) + c, err := s.DeleteNode(node.View()) if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) } - return nil, c, false, nil + return types.NodeView{}, c, nil } // Check if policy manager needs updating // This is necessary because we just created a new node. // We need to ensure that the policy manager is aware of this new node. // Also update users to ensure all users are known when evaluating policies. - usersChanged, err := s.updatePolicyManagerUsers() + usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager users after node registration: %w", err) } - nodesChanged, err := s.updatePolicyManagerNodes() + nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err) } - policyChanged := usersChanged || nodesChanged + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(node.ID) + } - c := change.NodeAdded(node.ID) + // Update nodestore with the newly registered node + s.nodeStore.PutNode(*node) - return node, c, policyChanged, nil + return node.View(), c, nil } // AllocateNextIPs allocates the next available IPv4 and IPv6 addresses. @@ -863,22 +980,25 @@ func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for users. // updatePolicyManagerUsers refreshes the policy manager with current user data. 
-func (s *State) updatePolicyManagerUsers() (bool, error) { +func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { users, err := s.ListAllUsers() if err != nil { - return false, fmt.Errorf("listing users for policy update: %w", err) + return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) } log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users") changed, err := s.polMan.SetUsers(users) if err != nil { - return false, fmt.Errorf("updating policy manager users: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err) } log.Debug().Bool("changed", changed).Msg("Policy manager users updated") - return changed, nil + if changed { + return change.PolicyChange(), nil + } + return change.EmptySet, nil } // updatePolicyManagerNodes updates the policy manager with current nodes. @@ -887,18 +1007,21 @@ func (s *State) updatePolicyManagerUsers() (bool, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for nodes. // updatePolicyManagerNodes refreshes the policy manager with current node data. -func (s *State) updatePolicyManagerNodes() (bool, error) { +func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { nodes, err := s.ListNodes() if err != nil { - return false, fmt.Errorf("listing nodes for policy update: %w", err) + return change.EmptySet, fmt.Errorf("listing nodes for policy update: %w", err) } - changed, err := s.polMan.SetNodes(nodes.ViewSlice()) + changed, err := s.polMan.SetNodes(nodes) if err != nil { - return false, fmt.Errorf("updating policy manager nodes: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err) } - return changed, nil + if changed { + return change.PolicyChange(), nil + } + return change.EmptySet, nil } // PingDB checks if the database connection is healthy. @@ -923,6 +1046,9 @@ func (s *State) autoApproveNodes() error { // TODO(kradalby): This change should probably be sent to the rest of the system. changed := policy.AutoApproveRoutes(s.polMan, node) if changed { + // Update nodestore first if available + s.nodeStore.PutNode(*node) + err = tx.Save(node).Error if err != nil { return err @@ -985,7 +1111,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques if routesChanged { // Auto approve any routes that have been defined in policy as // auto approved. Check if this actually changed the node. - _ = s.AutoApproveRoutes(node) + _ = s.AutoApproveRoutes(node.View()) // Update the routes of the given node in the route manager to // see if an update needs to be sent. @@ -998,7 +1124,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques // the hostname change. 
node.ApplyHostnameFromHostInfo(req.Hostinfo) - _, policyChange, err := s.SaveNode(node) + _, policyChange, err := s.SaveNode(node.View()) if err != nil { return change.EmptySet, err } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 81a2a86a..fa315bf5 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -13,6 +13,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/tsaddr" @@ -511,11 +512,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } if node.Hostname != hostInfo.Hostname { + log.Trace(). + Str("node_id", node.ID.String()). + Str("old_hostname", node.Hostname). + Str("new_hostname", hostInfo.Hostname). + Str("old_given_name", node.GivenName). + Bool("given_name_changed", node.GivenNameHasBeenChanged()). + Msg("Updating hostname from hostinfo") + if node.GivenNameHasBeenChanged() { node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) } node.Hostname = hostInfo.Hostname + + log.Trace(). + Str("node_id", node.ID.String()). + Str("new_hostname", node.Hostname). + Str("new_given_name", node.GivenName). + Msg("Hostname updated") } } @@ -759,6 +774,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix { return v.ж.ExitRoutes() } +// RequestTags returns the ACL tags that the node is requesting. +func (v NodeView) RequestTags() []string { + if !v.Valid() || !v.Hostinfo().Valid() { + return []string{} + } + return v.Hostinfo().RequestTags().AsSlice() +} + +// Proto converts the NodeView to a protobuf representation. +func (v NodeView) Proto() *v1.Node { + if !v.Valid() { + return nil + } + return v.ж.Proto() +} + // HasIP reports if a node has a given IP address. 
func (v NodeView) HasIP(i netip.Addr) bool { if !v.Valid() { diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index d7bc7897..97bb3da4 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Parse each hop line - hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") + hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) for i := 1; i < len(lines); i++ { matches := hopRegex.FindStringSubmatch(lines[i]) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index d118b643..394d219b 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.NoError(ct, err) assert.Equal(ct, "NeedsLogin", status.BackendState) } + assertTailscaleNodesLogout(t, allClients) }, shortAccessTTL+10*time.Second, 5*time.Second) } diff --git a/integration/general_test.go b/integration/general_test.go index 4e250854..cee17610 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -535,6 +535,8 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second var nodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { err := executeAndUnmarshal( @@ -630,27 +632,34 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "node", - "list", - "--output", - "json", - }, - &nodes, - ) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second + assert.Eventually(t, func() bool { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) - assertNoErr(t, err) - assert.Len(t, nodes, 3) + if err != nil || len(nodes) != 3 { + return false + } - for _, node := range nodes { - hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] - givenName := fmt.Sprintf("%d-givenname", node.GetId()) - assert.Equal(t, hostname+"NEW", node.GetName()) - assert.Equal(t, givenName, node.GetGivenName()) - } + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName { + return false + } + } + return true + }, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix") } func TestExpireNode(t *testing.T) { diff --git a/integration/route_test.go b/integration/route_test.go index 7243d3f2..bb13a47f 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, node.GetSubnetRoutes(), 1) } - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - // Verify that 
the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - assert.NotNil(t, peerStatus.PrimaryRoutes) - - assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3) - requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + assert.NotNil(c, peerStatus.PrimaryRoutes) + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + } } - } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") _, err = headscale.ApproveRoutes( 1, @@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - - for _, node := range nodes { - if node.GetId() == 1 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetSubnetRoutes()) - } else if node.GetId() == 2 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetApprovedRoutes()) - assert.Empty(t, node.GetSubnetRoutes()) - } else { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + for _, node := range nodes { + if node.GetId() == 1 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetSubnetRoutes()) + } else if node.GetId() == 2 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetApprovedRoutes()) + assert.Empty(c, node.GetSubnetRoutes()) + } else { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + } } - } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the clients can see the new routes for _, client := range allClients { @@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - time.Sleep(3 * time.Second) + // Wait for route configuration changes after advertising routes + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err := headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 0, 0) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved") // Verify that no routes has been sent to the client, // they are not yet 
enabled. @@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on first subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route") // Verify that the client has routes from the primary machine and can access // the webservice. @@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on second subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route") // Verify that the client has routes from the primary machine srs1 = subRouter1.MustStatus() @@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on third subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0) + }, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route") // Verify that the client has routes from the primary machine srs1 = subRouter1.MustStatus() @@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NoError(t, err) assert.Len(t, result, 13) - tr, err = client.Traceroute(webip) - require.NoError(t, err) - assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + // Wait for traceroute to work correctly through the expected router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + + // Get the expected router IP - use a more robust approach to handle temporary disconnections + ips, err := subRouter1.IPs() + assert.NoError(c, err) + assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") + + var expectedIP netip.Addr + for _, ip := range ips { + if ip.Is4() { + expectedIP = ip + break + } + } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid 
IPv4 address") + + assertTracerouteViaIPWithCollect(c, tr, expectedIP) + }, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1") // Take down the current primary t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname()) @@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter1.Down() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r1 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs2 = subRouter2.MustStatus() + clientStatus = client.MustStatus() - srs2 = subRouter2.MustStatus() - clientStatus = client.MustStatus() + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down") - assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up") - assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up") + assert.False(c, srs1PeerStatus.Online, "r1 should be offline") + assert.True(c, srs2PeerStatus.Online, "r2 should be online") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) require.NotNil(t, srs2PeerStatus.PrimaryRoutes) @@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter2.Down() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r2 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // TODO(kradalby): Check client status - // Both are expected to be down + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - // Verify that the route is not presented from either router - clientStatus, err = client.Status() - require.NoError(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down") - assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down") - assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down") + assert.False(c, srs1PeerStatus.Online, "r1 should be offline") + assert.False(c, srs2PeerStatus.Online, "r2 should be offline") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter1.Up() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r1 comes back up + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - require.NoError(t, err) + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + 
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down") - assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down") - assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available") + assert.True(c, srs1PeerStatus.Online, "r1 should be back online") + assert.False(c, srs2PeerStatus.Online, "r2 should still be offline") + assert.True(c, srs3PeerStatus.Online, "r3 should still be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter2.Up() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing to complete and online status to be updated + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - require.NoError(t, err) + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up") - assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up") - assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up") + assert.True(c, srs1PeerStatus.Online, "r1 should be online") + assert.True(c, srs2PeerStatus.Online, "r2 should be online") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname()) _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{}) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + // After disabling route on r3, r1 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, 
MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname()) _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{}) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + // After disabling route on r1, r2 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) { util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()), ) - time.Sleep(5 * time.Second) + // Wait for route state changes after re-enabling r1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - require.Len(t, nodes, 2) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 0, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0) + }, 10*time.Second, 
500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the client has routes from the primary machine srs1, _ := subRouter1.Status() @@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) { requireNodeRouteCount(t, nodes[0], 2, 2, 2) requireNodeRouteCount(t, nodes[1], 2, 2, 2) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - // Verify that the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - require.NotNil(t, peerStatus.AllowedIPs) - assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4) - assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) - assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6()) + assert.NotNil(c, peerStatus.AllowedIPs) + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6()) + } } - } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") } // TestSubnetRouterMultiNetwork is an evolution of the subnet router test. @@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes and clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 2) - requireNodeRouteCount(t, nodes[0], 1, 1, 1) + // Verify that the routes have been sent to the client + status, err = user2c.Status() + assert.NoError(c, err) - // Verify that the routes have been sent to the client. 
-	status, err = user2c.Status()
-	require.NoError(t, err)
+		for _, peerKey := range status.Peers() {
+			peerStatus := status.Peer[peerKey]
 
-	for _, peerKey := range status.Peers() {
-		peerStatus := status.Peer[peerKey]
-
-		assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
-		requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
-	}
+			assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
+			requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
+		}
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
 
 	usernet1, err := scenario.Network("usernet1")
 	require.NoError(t, err)
@@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
 	_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
 	require.NoError(t, err)
 
-	time.Sleep(5 * time.Second)
+	// Wait for route state changes to propagate to nodes and clients
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodes, err = headscale.ListNodes()
+		assert.NoError(c, err)
+		assert.Len(c, nodes, 2)
+		requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
 
-	nodes, err = headscale.ListNodes()
-	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
-	requireNodeRouteCount(t, nodes[0], 2, 2, 2)
+		// Verify that the routes have been sent to the client
+		status, err = user2c.Status()
+		assert.NoError(c, err)
 
-	// Verify that the routes have been sent to the client.
-	status, err = user2c.Status()
-	require.NoError(t, err)
+		for _, peerKey := range status.Peers() {
+			peerStatus := status.Peer[peerKey]
 
-	for _, peerKey := range status.Peers() {
-		peerStatus := status.Peer[peerKey]
-
-		requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
-	}
+			requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
+		}
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
 
 	// Tell user2c to use user1c as an exit node.
 	command = []string{
@@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
 
+	var nodes []*v1.Node
 	opts := []hsic.Option{
 		hsic.WithTestName("autoapprovemulti"),
 		hsic.WithEmbeddedDERPServerOnly(),
@@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		require.NoErrorf(t, err, "failed to advertise route: %s", err)
 	}
 
-	time.Sleep(5 * time.Second)
-
-	// These route should auto approve, so the node is expected to have a route
-	// for all counts.
-	nodes, err := headscale.ListNodes()
-	require.NoError(t, err)
-	requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	// Wait for route state changes to propagate
+	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+		// These routes should auto-approve, so the node is expected to have a route
+		// for all counts.
+		nodes, err := headscale.ListNodes()
+		assert.NoError(c, err)
+		requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+	}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 	// Verify that the routes have been sent to the client.
 	status, err := client.Status()
@@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		err = headscale.SetPolicy(tt.pol)
 		require.NoError(t, err)
 
-		time.Sleep(5 * time.Second)
-
-		// These route should auto approve, so the node is expected to have a route
-		// for all counts.
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			// These routes should auto-approve, so the node is expected to have a route
+			// for all counts.
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 		// Verify that the routes have been sent to the client.
 		status, err = client.Status()
@@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		)
 		require.NoError(t, err)
 
-		time.Sleep(5 * time.Second)
-
-		// These route should auto approve, so the node is expected to have a route
-		// for all counts.
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			// These routes should auto-approve, so the node is expected to have a route
+			// for all counts.
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 		// Verify that the routes have been sent to the client.
 		status, err = client.Status()
@@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		err = headscale.SetPolicy(tt.pol)
 		require.NoError(t, err)
 
-		time.Sleep(5 * time.Second)
-
-		// These route should auto approve, so the node is expected to have a route
-		// for all counts.
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			// These routes should auto-approve, so the node is expected to have a route
+			// for all counts.
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 		// Verify that the routes have been sent to the client.
 		status, err = client.Status()
@@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		_, _, err = routerSubRoute.Execute(command)
 		require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-		time.Sleep(5 * time.Second)
-
-		// These route should auto approve, so the node is expected to have a route
-		// for all counts.
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			// These routes should auto-approve, so the node is expected to have a route
+			// for all counts.
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 		requireNodeRouteCount(t, nodes[1], 1, 1, 1)
 
 		// Verify that the routes have been sent to the client.
@@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		_, _, err = routerSubRoute.Execute(command)
 		require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-		time.Sleep(5 * time.Second)
-
-		// These route should auto approve, so the node is expected to have a route
-		// for all counts.
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			// These routes should auto-approve, so the node is expected to have a route
+			// for all counts.
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 		requireNodeRouteCount(t, nodes[1], 1, 1, 0)
 		requireNodeRouteCount(t, nodes[2], 0, 0, 0)
 
@@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 		_, _, err = routerExitNode.Execute(command)
 		require.NoErrorf(t, err, "failed to advertise route: %s", err)
 
-		time.Sleep(5 * time.Second)
-
-		nodes, err = headscale.ListNodes()
-		require.NoError(t, err)
-		requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
-		requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-		requireNodeRouteCount(t, nodes[2], 2, 2, 2)
+		// Wait for route state changes to propagate
+		assert.EventuallyWithT(t, func(c *assert.CollectT) {
+			nodes, err = headscale.ListNodes()
+			assert.NoError(c, err)
+			requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+			requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+			requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
+		}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
 
 		// Verify that the routes have been sent to the client.
 		status, err = client.Status()
@@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
 	require.Equal(t, tr.Route[0].IP, ip)
 }
 
+// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT.
+func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
+	assert.NotNil(c, tr)
+	assert.True(c, tr.Success)
+	assert.NoError(c, tr.Err)
+	assert.NotEmpty(c, tr.Route)
+	assert.Equal(c, tr.Route[0].IP, ip)
+}
+
 // requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) { t.Helper() @@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected } } +func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) { + if status.AllowedIPs.Len() <= 2 && len(expected) != 0 { + assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected)) + return + } + + if len(expected) == 0 { + expected = []netip.Prefix{} + } + + got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool { + if tsaddr.IsExitRoute(p) { + return true + } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) + }) + + if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { + assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff)) + } +} + func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) { t.Helper() require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) @@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) } +func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) { + assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) + assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes())) + assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) +} + // TestSubnetRouteACLFiltering tests that a node can only access subnet routes // that are explicitly allowed in the ACL. 
func TestSubnetRouteACLFiltering(t *testing.T) { @@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) { ) require.NoError(t, err) - // Give some time for the routes to propagate - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // List nodes and verify the router has 3 available routes + nodes, err = headscale.NodesByUser() + assert.NoError(c, err) + assert.Len(c, nodes, 2) - // List nodes and verify the router has 3 available routes - nodes, err = headscale.NodesByUser() - require.NoError(t, err) - require.Len(t, nodes, 2) + // Find the router node + routerNode = nodes[routerUser][0] - // Find the router node - routerNode = nodes[routerUser][0] - - // Check that the router has 3 routes now approved and available - requireNodeRouteCount(t, routerNode, 3, 3, 3) + // Check that the router has 3 routes now approved and available + requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Now check the client node status nodeStatus, err := nodeClient.Status() diff --git a/integration/scenario.go b/integration/scenario.go index 817d927b..8ce54b89 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -14,7 +14,6 @@ import ( "net/netip" "net/url" "os" - "sort" "strconv" "strings" "sync" @@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { return nil, fmt.Errorf("no network named: %s", name) } - for _, ipam := range net.Network.IPAM.Config { - pref, err := netip.ParsePrefix(ipam.Subnet) - if err != nil { - return nil, err - } - - return &pref, nil + if len(net.Network.IPAM.Config) == 0 { + return nil, fmt.Errorf("no IPAM config found in network: %s", name) } - return nil, fmt.Errorf("no prefix found in network: %s", name) + pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) + if err != nil { + return nil, err + } + + return &pref, nil } func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { @@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv( return err } - sort.Strings(s.spec.Users) for _, user := range s.spec.Users { u, err := s.CreateUser(user) if err != nil {
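Editorial note on the integration-test changes above: every fixed time.Sleep has been replaced with testify's assert.EventuallyWithT, which retries a closure until all assertions recorded on the *assert.CollectT pass or the timeout elapses. A minimal sketch of the pattern as used throughout this diff, assuming only testify; fetchNodeCount is a hypothetical stand-in for calls like headscale.ListNodes():

```go
package integration

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// waitForNodeCount polls until the expected number of nodes is reported,
// mirroring the Sleep-to-EventuallyWithT migration in the tests above.
func waitForNodeCount(t *testing.T, fetchNodeCount func() (int, error), want int) {
	t.Helper()
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		// Failed assertions against c mark this attempt as unsuccessful
		// and schedule a retry at the next tick instead of failing t.
		got, err := fetchNodeCount()
		assert.NoError(c, err)
		assert.Equal(c, want, got)
	}, 10*time.Second, 500*time.Millisecond, "node count should converge")
}
```

This removes the race between a fixed sleep and the NodeStore's 500 ms batching window: the test now fails only if the state never converges within the deadline.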