From 2e20652fdfcb17820d4e9c2d0f6f50921ddbacc4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 5 Jul 2025 23:30:47 +0200 Subject: [PATCH] state/nodestore: in memory representation of nodes Initial work on a nodestore which stores all of the nodes and their relations in memory with relationship for peers precalculated. It is a copy-on-write structure, replacing the "snapshot" when a change to the structure occurs. It is optimised for reads, and while batches are not fast, they are grouped together to do less of the expensive peer calculation if there are many changes rapidly. Writes will block until commited, while reads are never blocked. Signed-off-by: Kristoffer Dalby --- cmd/headscale/cli/nodes.go | 3 +- hscontrol/app.go | 47 +- hscontrol/auth.go | 76 +- hscontrol/db/node.go | 179 +-- hscontrol/db/node_test.go | 121 +- hscontrol/db/preauth_keys.go | 4 +- hscontrol/db/users.go | 21 +- hscontrol/debug.go | 6 +- hscontrol/grpcv1.go | 153 +- hscontrol/handlers.go | 14 +- hscontrol/mapper/batcher.go | 7 +- hscontrol/mapper/batcher_lockfree.go | 250 ++- hscontrol/mapper/batcher_test.go | 303 +++- hscontrol/mapper/builder.go | 133 +- hscontrol/mapper/mapper.go | 57 +- hscontrol/mapper/mapper_test.go | 2 +- hscontrol/mapper/tail.go | 13 +- hscontrol/noise.go | 12 +- hscontrol/oidc.go | 24 +- hscontrol/policy/policy.go | 86 +- hscontrol/policy/policy_autoapprove_test.go | 339 ++++ .../policy/policy_route_approval_test.go | 361 +++++ hscontrol/policy/route_approval_test.go | 23 + hscontrol/policy/v2/policy.go | 8 +- hscontrol/poll.go | 91 +- hscontrol/state/node_store.go | 403 +++++ hscontrol/state/node_store_test.go | 501 ++++++ hscontrol/state/state.go | 1369 ++++++++++++----- hscontrol/types/change/change.go | 1 + hscontrol/types/node.go | 37 + hscontrol/util/util.go | 2 +- integration/auth_oidc_test.go | 1 + integration/general_test.go | 47 +- integration/route_test.go | 565 ++++--- integration/scenario.go | 18 +- 35 files changed, 3960 insertions(+), 1317 deletions(-) create mode 100644 hscontrol/policy/policy_autoapprove_test.go create mode 100644 hscontrol/policy/policy_route_approval_test.go create mode 100644 hscontrol/state/node_store.go create mode 100644 hscontrol/state/node_store_test.go diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index caac986c..6d6476fb 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -551,13 +551,12 @@ be assigned to nodes.`, } } - if confirm || force { ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force }) + changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force}) if err != nil { ErrorOutput( err, diff --git a/hscontrol/app.go b/hscontrol/app.go index 774aec46..47b38c83 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -137,9 +137,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { // Initialize ephemeral garbage collector ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { - node, err := app.state.GetNodeByID(ni) - if err != nil { - log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion") + node, ok := app.state.GetNodeByID(ni) + if !ok { + log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") + log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") return } @@ -379,15 +380,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writer http.ResponseWriter, req *http.Request, ) { - log.Trace(). - Caller(). - Str("client_address", req.RemoteAddr). - Msg("HTTP authentication invoked") - - authHeader := req.Header.Get("authorization") - - if !strings.HasPrefix(authHeader, AuthPrefix) { - log.Error(). + if err := func() error { + log.Trace(). Caller(). Str("client_address", req.RemoteAddr). Msg(`missing "Bearer " prefix in "Authorization" header`) @@ -501,11 +495,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { + var err error capver.CanOldCodeBeCleanedUp() if profilingEnabled { if profilingPath != "" { - err := os.MkdirAll(profilingPath, os.ModePerm) + err = os.MkdirAll(profilingPath, os.ModePerm) if err != nil { log.Fatal().Err(err).Msg("failed to create profiling directory") } @@ -559,12 +554,9 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() - ephmNodes, err := h.state.ListEphemeralNodes() - if err != nil { - return fmt.Errorf("failed to list ephemeral nodes: %w", err) - } - for _, node := range ephmNodes { - h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout) + ephmNodes := h.state.ListEphemeralNodes() + for _, node := range ephmNodes.All() { + h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) } if h.cfg.DNSConfig.ExtraRecordsPath != "" { @@ -794,23 +786,14 @@ func (h *Headscale) Serve() error { continue } - changed, err := h.state.ReloadPolicy() + changes, err := h.state.ReloadPolicy() if err != nil { log.Error().Err(err).Msgf("reloading policy") continue } - if changed { - log.Info(). - Msg("ACL policy successfully reloaded, notifying nodes of change") + h.Change(changes...) - err = h.state.AutoApproveNodes() - if err != nil { - log.Error().Err(err).Msg("failed to approve routes after new policy") - } - - h.Change(change.PolicySet) - } default: info := func(msg string) { log.Info().Msg(msg) } log.Info(). @@ -1020,6 +1003,6 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { // Change is used to send changes to nodes. // All change should be enqueued here and empty will be automatically // ignored. -func (h *Headscale) Change(c change.ChangeSet) { - h.mapBatcher.AddWork(c) +func (h *Headscale) Change(cs ...change.ChangeSet) { + h.mapBatcher.AddWork(cs...) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index cb284173..81032640 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -12,7 +12,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/rs/zerolog/log" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -29,28 +28,10 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, err := h.state.GetNodeByNodeKey(regReq.NodeKey) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("looking up node in database: %w", err) - } + node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey) - if node != nil { - // If an existing node is trying to register with an auth key, - // we need to validate the auth key even for existing nodes - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) - if err != nil { - // Preserve HTTPError types so they can be handled properly by the HTTP layer - var httpErr HTTPError - if errors.As(err, &httpErr) { - return nil, httpErr - } - return nil, fmt.Errorf("handling register with auth key for existing node: %w", err) - } - return resp, nil - } - - resp, err := h.handleExistingNode(node, regReq, machineKey) + if ok { + resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey) if err != nil { return nil, fmt.Errorf("handling existing node: %w", err) } @@ -70,6 +51,7 @@ func (h *Headscale) handleRegister( if errors.As(err, &httpErr) { return nil, httpErr } + return nil, fmt.Errorf("handling register with auth key: %w", err) } @@ -89,13 +71,22 @@ func (h *Headscale) handleExistingNode( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - if node.MachineKey != machineKey { return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() + // If the node is expired and this is not a re-authentication attempt, + // force the client to re-authenticate + if expired && regReq.Auth == nil { + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + AuthURL: "", // Client will need to re-authenticate + }, nil + } + if !expired && !regReq.Expiry.IsZero() { requestExpiry := regReq.Expiry @@ -107,7 +98,7 @@ func (h *Headscale) handleExistingNode( // If the request expiry is in the past, we consider it a logout. if requestExpiry.Before(time.Now()) { if node.IsEphemeral() { - c, err := h.state.DeleteNode(node) + c, err := h.state.DeleteNode(node.View()) if err != nil { return nil, fmt.Errorf("deleting ephemeral node: %w", err) } @@ -118,15 +109,19 @@ func (h *Headscale) handleExistingNode( } } - _, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) + updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } h.Change(c) - } - return nodeToRegisterResponse(node), nil + // CRITICAL: Use the updated node view for the response + // The original node object has stale expiry information + node = updatedNode.AsStruct() + } + + return nodeToRegisterResponse(node), nil } func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { @@ -177,7 +172,7 @@ func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey( + node, changed, err := h.state.HandleNodeFromPreAuthKey( regReq, machineKey, ) @@ -193,8 +188,8 @@ func (h *Headscale) handleRegisterWithAuthKey( return nil, err } - // If node is nil, it means an ephemeral node was deleted during logout - if node == nil { + // If node is not valid, it means an ephemeral node was deleted during logout + if !node.Valid() { h.Change(changed) return nil, nil } @@ -213,26 +208,30 @@ func (h *Headscale) handleRegisterWithAuthKey( // TODO(kradalby): This needs to be ran as part of the batcher maybe? // now since we dont update the node/pol here anymore routeChange := h.state.AutoApproveRoutes(node) + if _, _, err := h.state.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } if routeChange && changed.Empty() { - changed = change.NodeAdded(node.ID) + changed = change.NodeAdded(node.ID()) } h.Change(changed) - // If policy changed due to node registration, send a separate policy change - if policyChanged { - policyChange := change.PolicyChange() - h.Change(policyChange) - } + // TODO(kradalby): I think this is covered above, but we need to validate that. + // // If policy changed due to node registration, send a separate policy change + // if policyChanged { + // policyChange := change.PolicyChange() + // h.Change(policyChange) + // } + + user := node.User() return &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), - User: *node.User.TailscaleUser(), - Login: *node.User.TailscaleLogin(), + User: *user.TailscaleUser(), + Login: *user.TailscaleLogin(), }, nil } @@ -266,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive( ) log.Info().Msgf("Starting node registration using key: %s", registrationId) + return &tailcfg.RegisterResponse{ AuthURL: h.authProvider.AuthURL(registrationId), }, nil diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 83d62d3d..3531fc49 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { } // RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. If the name is not unique, it will return an error. +// and renames it. Validation should be done in the state layer before calling this function. func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - err := util.CheckForFQDNRules( - newName, - ) - if err != nil { - return fmt.Errorf("renaming node: %w", err) + // Check if the new name is unique + var count int64 + if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to check name uniqueness: %w", err) } - uniq, err := isUniqueName(tx, newName) - if err != nil { - return fmt.Errorf("checking if name is unique: %w", err) - } - - if !uniq { - return fmt.Errorf("name is not unique: %s", newName) + if count > 0 { + return errors.New("name is not unique") } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -333,108 +327,19 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( }) } -// HandleNodeFromAuthPath is called from the OIDC or CLI auth path -// with a registrationID to register or reauthenticate a node. -// If the node found in the registration cache is not already registered, -// it will be registered with the user and the node will be removed from the cache. -// If the node is already registered, the expiry will be updated. -// The node, and a boolean indicating if it was a new node or not, will be returned. -func (hsdb *HSDatabase) HandleNodeFromAuthPath( - registrationID types.RegistrationID, - userID types.UserID, - nodeExpiry *time.Time, - registrationMethod string, - ipv4 *netip.Addr, - ipv6 *netip.Addr, -) (*types.Node, change.ChangeSet, error) { - var nodeChange change.ChangeSet - node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - if reg, ok := hsdb.regCache.Get(registrationID); ok { - if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { - user, err := GetUserByID(tx, userID) - if err != nil { - return nil, fmt.Errorf( - "failed to find user in register node from auth callback, %w", - err, - ) - } +// RegisterNodeForTest is used only for testing purposes to register a node directly in the database. +// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey. +func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { + if !testing.Testing() { + panic("RegisterNodeForTest can only be called during tests") + } - log.Debug(). - Str("registration_id", registrationID.String()). - Str("username", user.Username()). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). - Msg("Registering node from API/CLI or auth callback") - - // TODO(kradalby): This looks quite wrong? why ID 0? - // Why not always? - // Registration of expired node with different user - if reg.Node.ID != 0 && - reg.Node.UserID != user.ID { - return nil, ErrDifferentRegisteredUser - } - - reg.Node.UserID = user.ID - reg.Node.User = *user - reg.Node.RegisterMethod = registrationMethod - - if nodeExpiry != nil { - reg.Node.Expiry = nodeExpiry - } - - node, err := RegisterNode( - tx, - reg.Node, - ipv4, ipv6, - ) - - if err == nil { - hsdb.regCache.Delete(registrationID) - } - - // Signal to waiting clients that the machine has been registered. - select { - case reg.Registered <- node: - default: - } - close(reg.Registered) - - nodeChange = change.NodeAdded(node.ID) - - return node, err - } else { - // If the node is already registered, this is a refresh. - err := NodeSetExpiry(tx, node.ID, *nodeExpiry) - if err != nil { - return nil, err - } - - nodeChange = change.KeyExpiry(node.ID) - - return node, nil - } - } - - return nil, ErrNodeNotFoundRegistrationCache - }) - - return node, nodeChange, err -} - -func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - return RegisterNode(tx, node, ipv4, ipv6) - }) -} - -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. -func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). Str("user", node.User.Username()). - Msg("Registering node") + Msg("Registering test node") // If the a new node is registered with the same machine key, to the same user, // update the existing node. @@ -445,8 +350,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.ID = oldNode.ID node.GivenName = oldNode.GivenName node.ApprovedRoutes = oldNode.ApprovedRoutes - ipv4 = oldNode.IPv4 - ipv6 = oldNode.IPv6 + // Don't overwrite the provided IPs with old ones when they exist + if ipv4 == nil { + ipv4 = oldNode.IPv4 + } + if ipv6 == nil { + ipv6 = oldNode.IPv6 + } } // If the node exists and it already has IP(s), we just save it @@ -463,7 +373,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). Str("user", node.User.Username()). - Msg("Node authorized again") + Msg("Test node authorized again") return &node, nil } @@ -472,7 +382,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.IPv6 = ipv6 if node.GivenName == "" { - givenName, err := ensureUniqueGivenName(tx, node.Hostname) + givenName, err := EnsureUniqueGivenName(tx, node.Hostname) if err != nil { return nil, fmt.Errorf("failed to ensure unique given name: %w", err) } @@ -487,7 +397,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad log.Trace(). Caller(). Str("node", node.Hostname). - Msg("Node registered with the database") + Msg("Test node registered with the database") return &node, nil } @@ -560,7 +470,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) { return len(nodes) == 0, nil } -func ensureUniqueGivenName( +// EnsureUniqueGivenName generates a unique given name for a node based on its hostname. +func EnsureUniqueGivenName( tx *gorm.DB, name string, ) (string, error) { @@ -781,19 +692,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . node := hsdb.CreateNodeForTest(user, hostname...) - err := hsdb.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, *node, nil, nil) + // Allocate IPs for the test node using the database's IP allocator + // This is a simplified allocation for testing - in production this would use State.ipAlloc + ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID) + if err != nil { + panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err)) + } + + var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { + var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) return err }) if err != nil { panic(fmt.Sprintf("failed to register test node: %v", err)) } - registeredNode, err := hsdb.GetNodeByID(node.ID) - if err != nil { - panic(fmt.Sprintf("failed to get registered test node: %v", err)) - } - return registeredNode } @@ -842,3 +757,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int return nodes } + +// allocateTestIPs allocates sequential test IPs for nodes during testing. +func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) { + if !testing.Testing() { + panic("allocateTestIPs can only be called during tests") + } + + // Use simple sequential allocation for tests + // IPv4: 100.64.0.x (where x is nodeID) + // IPv6: fd7a:115c:a1e0::x (where x is nodeID) + + if nodeID > 254 { + return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID) + } + + ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)}) + ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)}) + + return &ipv4, &ipv6, nil +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 8819fbcf..84e30e0a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) { func TestAutoApproveRoutes(t *testing.T) { tests := []struct { - name string - acl string - routes []netip.Prefix - want []netip.Prefix - want2 []netip.Prefix + name string + acl string + routes []netip.Prefix + want []netip.Prefix + want2 []netip.Prefix + expectChange bool // whether to expect route changes }{ + { + name: "no-auto-approvers-empty-policy", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ] +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - no auto-approvers + want2: []netip.Prefix{}, // Should be empty - no auto-approvers + expectChange: false, // No changes expected + }, + { + name: "no-auto-approvers-explicit-empty", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ], + "autoApprovers": { + "routes": {}, + "exitNode": [] + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + expectChange: false, // No changes expected + }, { name: "2068-approve-issue-sub-kube", acl: ` @@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) { } } }`, - routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, - want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + expectChange: true, // Routes should be approved }, { name: "2068-approve-issue-sub-exit-tag", @@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) { tsaddr.AllIPv4(), tsaddr.AllIPv6(), }, + expectChange: true, // Routes should be approved }, } @@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) require.NotNil(t, pm) - changed1 := policy.AutoApproveRoutes(pm, &node) - assert.True(t, changed1) + newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes) + assert.Equal(t, tt.expectChange, changed1) - err = adb.DB.Save(&node).Error - require.NoError(t, err) + if changed1 { + err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1) + require.NoError(t, err) + } - _ = policy.AutoApproveRoutes(pm, &nodeTagged) - - err = adb.DB.Save(&nodeTagged).Error - require.NoError(t, err) + newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes) + if changed2 { + err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2) + require.NoError(t, err) + } node1ByID, err := adb.GetNodeByID(1) require.NoError(t, err) - if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { + // For empty auto-approvers tests, handle nil vs empty slice comparison + expectedRoutes1 := tt.want + if len(expectedRoutes1) == 0 { + expectedRoutes1 = nil + } + if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } node2ByID, err := adb.GetNodeByID(2) require.NoError(t, err) - if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { + expectedRoutes2 := tt.want2 + if len(expectedRoutes2) == 0 { + expectedRoutes2 = nil + } + if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } }) @@ -620,11 +679,11 @@ func TestRenameNode(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node, nil, nil) + _, err := RegisterNodeForTest(tx, node, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -721,11 +780,11 @@ func TestListPeers(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node1, nil, nil) + _, err := RegisterNodeForTest(tx, node1, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) { // No parameter means no filter, should return all peers nodes, err = db.ListPeers(1) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Empty node list should return all peers nodes, err = db.ListPeers(1, types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // No match in IDs should return empty list and no error @@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) { // Partial match in IDs nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs, but node ID is still filtered out nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) } @@ -806,11 +865,11 @@ func TestListNodes(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node1, nil, nil) + _, err := RegisterNodeForTest(tx, node1, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) { // No parameter means no filter, should return all nodes nodes, err = db.ListNodes() require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) // Empty node list should return all nodes nodes, err = db.ListNodes(types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) @@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) { // Partial match in IDs nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 2e60de2e..a36c1f13 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "slices" "strings" "time" @@ -47,8 +48,9 @@ func CreatePreAuthKey( return nil, err } - // Remove duplicates + // Remove duplicates and sort for consistency aclTags = set.SetOf(aclTags).Slice() + slices.Sort(aclTags) // TODO(kradalby): factor out and create a reusable tag validation, // check if there is one in Tailscale's lib. diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 1b333792..26d10060 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { } // AssignNodeToUser assigns a Node to a user. +// Note: Validation should be done in the state layer before calling this function. func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error { - node, err := GetNodeByID(tx, nodeID) - if err != nil { - return err + // Check if the user exists + var userExists bool + if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) } - user, err := GetUserByID(tx, uid) - if err != nil { - return err + + if !userExists { + return ErrUserNotFound } - node.User = *user - node.UserID = user.ID - if result := tx.Save(&node); result.Error != nil { - return result.Error + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil { + return fmt.Errorf("failed to assign node to user: %w", err) } return nil diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 60676a1d..c2b478b1 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { } sshPol := make(map[string]*tailcfg.SSHPolicy) - for _, node := range nodes { - pol, err := h.state.SSHPolicy(node.View()) + for _, node := range nodes.All() { + pol, err := h.state.SSHPolicy(node) if err != nil { httpError(w, err) return } - sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol + sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol } sshJSON, err := json.MarshalIndent(sshPol, "", " ") diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 722f8421..1b1a22e2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "github.com/samber/lo" "google.golang.org/grpc/codes" @@ -25,6 +24,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/views" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/state" @@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser( return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) } - c := change.UserAdded(types.UserID(user.ID)) - if policyChanged { + + // TODO(kradalby): Both of these might be policy changes, find a better way to merge. + if !policyChanged.Empty() { c.Change = change.Policy } @@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser( return nil, err } - _, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) + _, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } // Send policy update notifications if needed - if policyChanged { - api.h.Change(change.PolicyChange()) - } + api.h.Change(c) newUser, err := api.h.state.GetUserByName(request.GetNewName()) if err != nil { @@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode( ctx context.Context, request *v1.GetNodeRequest, ) (*v1.GetNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } resp := node.Proto() - // Populate the online field based on - // currently connected nodes. - resp.Online = api.h.mapBatcher.IsConnected(node.ID) - return &v1.GetNodeResponse{Node: resp}, nil } @@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Strs("tags", request.GetTags()). Msg("Changing tags of node") @@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes( ctx context.Context, request *v1.SetApprovedRoutesRequest, ) (*v1.SetApprovedRoutesResponse, error) { - var routes []netip.Prefix + log.Debug(). + Caller(). + Uint64("node.id", request.GetNodeId()). + Strs("requestedRoutes", request.GetRoutes()). + Msg("gRPC SetApprovedRoutes called") + + var newApproved []netip.Prefix for _, route := range request.GetRoutes() { prefix, err := netip.ParsePrefix(route) if err != nil { @@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes( // If the prefix is an exit route, add both. The client expect both // to annotate the node as an exit node. if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() { - routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6()) + newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6()) } else { - routes = append(routes, prefix) + newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(routes) - routes = slices.Compact(routes) + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) - node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) + node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) if err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } - routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) - // Always propagate node changes from SetApprovedRoutes api.h.Change(nodeChange) - // If routes changed, propagate those changes too - if !routeChange.Empty() { - api.h.Change(routeChange) - } - proto := node.Proto() - proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID)) + // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the + // routes that are actively served from the node (per architectural requirement in types/node.go) + primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID()) + proto.SubnetRoutes = util.PrefixesToString(primaryRoutes) + + log.Debug(). + Caller(). + Uint64("node.id", node.ID().Uint64()). + Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)). + Strs("finalSubnetRoutes", proto.SubnetRoutes). + Msg("gRPC SetApprovedRoutes completed") return &v1.SetApprovedRoutesResponse{Node: proto}, nil } @@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, ) (*v1.DeleteNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } nodeChange, err := api.h.state.DeleteNode(node) @@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). - Time("expiry", *node.Expiry). + Caller(). + Str("node", node.Hostname()). + Time("expiry", *node.AsStruct().Expiry). Msg("node expired") return &v1.ExpireNodeResponse{Node: node.Proto()}, nil @@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Str("new_name", request.GetNewName()). Msg("node renamed") @@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes( // the filtering of nodes by user, vs nodes as a whole can // probably be done once. // TODO(kradalby): This should be done in one tx. - - IsConnected := api.h.mapBatcher.ConnectedMap() if request.GetUser() != "" { user, err := api.h.state.GetUserByName(request.GetUser()) if err != nil { return nil, err } - nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID)) - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodesByUser(types.UserID(user.ID)) - response := nodesToProto(api.h.state, IsConnected, nodes) + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodes() - sort.Slice(nodes, func(i, j int) bool { - return nodes[i].ID < nodes[j].ID - }) - - response := nodesToProto(api.h.state, IsConnected, nodes) + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } -func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node { - response := make([]*v1.Node, len(nodes)) - for index, node := range nodes { +func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node { + response := make([]*v1.Node, nodes.Len()) + for index, node := range nodes.All() { resp := node.Proto() - // Populate the online field based on - // currently connected nodes. - if val, ok := IsConnected.Load(node.ID); ok && val { - resp.Online = true - } - var tags []string for _, tag := range node.RequestTags() { - if state.NodeCanHaveTag(node.View(), tag) { + if state.NodeCanHaveTag(node, tag) { tags = append(tags, tag) } } - resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...)) - resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...)) + resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...)) + + resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) response[index] = resp } + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) + return response } @@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy( // a scenario where they might be allowed if the server has no nodes // yet, but it should help for the general case and for hot reloading // configurations. - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) - } - changed, err := api.h.state.SetPolicy([]byte(p)) + nodes := api.h.state.ListNodes() + + _, err := api.h.state.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } - if len(nodes) > 0 { - _, err = api.h.state.SSHPolicy(nodes[0].View()) + if nodes.Len() > 0 { + _, err = api.h.state.SSHPolicy(nodes.At(0)) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - // Only send update if the packet filter has changed. - if changed { - err = api.h.state.AutoApproveNodes() - if err != nil { - return nil, err - } + // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed. + // This ensures that routes are re-evaluated for auto-approval in cases where routes + // were manually disabled but could now be auto-approved with the current policy. + cs, err := api.h.state.ReloadPolicy() + if err != nil { + return nil, fmt.Errorf("reloading policy: %w", err) + } - api.h.Change(change.PolicyChange()) + if len(cs) > 0 { + api.h.Change(cs...) + } else { + log.Debug(). + Caller(). + Msg("No policy changes to distribute because ReloadPolicy returned empty changeset") } response := &v1.SetPolicyResponse{ diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 2d664104..cac4ff0f 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -94,13 +94,19 @@ func (h *Headscale) handleVerifyRequest( return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) } - nodes, err := h.state.ListNodes() - if err != nil { - return fmt.Errorf("cannot list nodes: %w", err) + nodes := h.state.ListNodes() + + // Check if any node has the requested NodeKey + var nodeKeyFound bool + for _, node := range nodes.All() { + if node.NodeKey() == derpAdmitClientRequest.NodePublic { + nodeKeyFound = true + break + } } resp := &tailcfg.DERPAdmitClientResponse{ - Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), + Allow: nodeKeyFound, } return json.NewEncoder(writer).Encode(resp) diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index bb69eac2..1299ed54 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "fmt" "time" @@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher type Batcher interface { Start() Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error - RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] AddWork(c change.ChangeSet) @@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { if nc == nil { - return fmt.Errorf("nodeConnection is nil") + return errors.New("nodeConnection is nil") } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index e733e29a..7476b72f 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -21,8 +21,7 @@ type LockFreeBatcher struct { mapper *mapper workers int - // Lock-free concurrent maps - nodes *xsync.Map[types.NodeID, *nodeConn] + nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] connected *xsync.Map[types.NodeID, *time.Time] // Work queue channel @@ -32,7 +31,6 @@ type LockFreeBatcher struct { // Batching state pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] - batchMutex sync.RWMutex // Metrics totalNodes atomic.Int64 @@ -45,65 +43,63 @@ type LockFreeBatcher struct { // AddNode registers a new node connection with the batcher and sends an initial map response. // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. -// TODO(kradalby): See if we can move the isRouter argument somewhere else. -func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error { - // First validate that we can generate initial map before doing anything else - fullSelfChange := change.FullSelf(id) +func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + addNodeStart := time.Now() - // TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange. - // This currently means that the goroutine for the node connection will do the processing - // which means that we might have uncontrolled concurrency. - // When we use MapResponseFromChange, it will be processed by the same worker pool, causing - // it to be processed in a more controlled manner. - initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange) - if err != nil { - return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + // Generate connection ID + connID := generateConnectionID() + + // Create new connection entry + now := time.Now() + newEntry := &connectionEntry{ + id: connID, + c: c, + version: version, + created: now, } // Only after validation succeeds, create or update node connection newConn := newNodeConn(id, c, version, b.mapper) - var conn *nodeConn - if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded { - // Update existing connection - existing.updateConnection(c, version) - conn = existing - } else { + if !loaded { b.totalNodes.Add(1) conn = newConn } - // Mark as connected only after validation succeeds b.connected.Store(id, nil) // nil = connected - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher") + if err != nil { + log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + } - // Send the validated initial map - if initialMap != nil { - if err := conn.send(initialMap); err != nil { - // Clean up the connection state on send failure - b.nodes.Delete(id) - b.connected.Delete(id) - return fmt.Errorf("failed to send initial map to node %d: %w", id, err) - } - - // Notify other nodes that this node came online - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter}) + // Use a blocking send with timeout for initial map since the channel should be ready + // and we want to avoid the race condition where the receiver isn't ready yet + select { + case c <- initialMap: + // Success + case <-time.After(5 * time.Second): + log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). + Msg("Initial map send timed out because channel was blocked or receiver not ready") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to send initial map to node %d: timeout", id) } return nil } // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. -// It validates the connection channel matches the current one, closes the connection, -// and notifies other nodes that this node has gone offline. -func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) { - // Check if this is the current connection and mark it as closed - if existing, ok := b.nodes.Load(id); ok { - if !existing.matchesChannel(c) { - log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring") - return // Not the current connection, not an error - } +// It validates the connection channel matches one of the current connections, closes that specific connection, +// and keeps the node entry alive for rapid reconnections instead of aggressive deletion. +// Reports if the node still has active connections after removal. +func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + nodeConn, exists := b.nodes.Load(id) + if !exists { + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher") + return false + } // Mark the connection as closed to prevent further sends if connData := existing.connData.Load(); connData != nil { @@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo } } - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline") + // Check if node has any remaining active connections + if nodeConn.hasActiveConnections() { + log.Debug().Caller().Uint64("node.id", id.Uint64()). + Int("active.connections", nodeConn.getActiveConnectionCount()). + Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections + } // Remove node and mark disconnected atomically b.nodes.Delete(id) b.connected.Store(id, ptr.To(time.Now())) b.totalNodes.Add(-1) - // Notify other nodes that this node went offline - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter}) + return false } // AddWork queues a change to be processed by the batcher. @@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) { return } - duration := time.Since(startTime) - if duration > 100*time.Millisecond { - log.Warn(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Dur("duration", duration). - Msg("slow synchronous work processing") - } continue } @@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) { // that should be processed and sent to the node instead of // returned to the caller. if nc, exists := b.nodes.Load(w.nodeID); exists { - // Check if this connection is still active before processing - if connData := nc.connData.Load(); connData != nil && connData.closed.Load() { - log.Debug(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Msg("skipping work for closed connection") - continue - } - + // Apply change to node - this will handle offline nodes gracefully + // and queue work for when they reconnect err := nc.change(w.c) if err != nil { b.workErrors.Add(1) @@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) { Str("change", w.c.Change.String()). Msg("failed to apply change") } - } else { - log.Debug(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Msg("node not found for asynchronous work - node may have disconnected") } - - duration := time.Since(startTime) - if duration > 100*time.Millisecond { - log.Warn(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Dur("duration", duration). - Msg("slow asynchronous work processing") - } - case <-b.ctx.Done(): return } } } -func (b *LockFreeBatcher) addWork(c change.ChangeSet) { - // For critical changes that need immediate processing, send directly - if b.shouldProcessImmediately(c) { - if c.SelfUpdateOnly { - b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil}) - return - } - b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { - if c.NodeID == nodeID && !c.AlsoSelf() { - return true - } - b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) - return true - }) - return - } - - // For non-critical changes, add to batch - b.addToBatch(c) +func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) { + b.addToBatch(c...) } -// queueWork safely queues work +// queueWork safely queues work. func (b *LockFreeBatcher) queueWork(w work) { b.workQueuedCount.Add(1) @@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) { } } -// shouldProcessImmediately determines if a change should bypass batching -func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { - // Process these changes immediately to avoid delaying critical functionality - switch c.Change { - case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy: - return true - default: - return false +// addToBatch adds a change to the pending batch. +func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) { + // Short circuit if any of the changes is a full update, which + // means we can skip sending individual changes. + if change.HasFull(c) { + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}}) + + return true + }) + return } } -// addToBatch adds a change to the pending batch -func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if c.SelfUpdateOnly { - changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{}) - changes = append(changes, c) - b.pendingChanges.Store(c.NodeID, changes) return } @@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) changes = append(changes, c) b.pendingChanges.Store(nodeID, changes) + return true }) } -// processBatchedChanges processes all pending batched changes +// processBatchedChanges processes all pending batched changes. func (b *LockFreeBatcher) processBatchedChanges() { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if b.pendingChanges == nil { return } @@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() { // Clear the pending changes for this node b.pendingChanges.Delete(nodeID) + return true }) } // IsConnected is lock-free read. func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { - if val, ok := b.connected.Load(id); ok { - // nil means connected - return val == nil + // First check if we have active connections for this node + if nodeConn, exists := b.nodes.Load(id); exists { + if nodeConn.hasActiveConnections() { + return true + } } + + // Check disconnected timestamp with grace period + val, ok := b.connected.Load(id) + if !ok { + return false + } + + // nil means connected + if val == nil { + return true + } + return false } @@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret := xsync.NewMap[types.NodeID, bool]() + // First, add all nodes with active connections + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + if nodeConn.hasActiveConnections() { + ret.Store(id, true) + } + return true + }) + + // Then add all entries from the connected map b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // nil means connected - ret.Store(id, val == nil) + // Only add if not already added as connected above + if _, exists := ret.Load(id); !exists { + if val == nil { + // nil means connected + ret.Store(id, true) + } else { + // timestamp means disconnected + ret.Store(id, false) + } + } return true }) @@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error { return fmt.Errorf("node %d: connection closed", nc.id) } - // TODO(kradalby): We might need some sort of timeout here if the client is not reading - // the channel. That might mean that we are sending to a node that has gone offline, but - // the channel is still open. - connData.c <- data - nc.updateCount.Add(1) - return nil + // Add all entries from the connected map to capture both connected and disconnected nodes + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already processed above + if _, exists := result[id]; !exists { + // Use immediate connection status for debug (no grace period) + connected := (val == nil) // nil means connected, timestamp means disconnected + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + } + return true + }) + + return result } func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 12bb37be..6cf63dca 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -27,6 +27,60 @@ type batcherTestCase struct { fn batcherFunc } +// testBatcherWrapper wraps a real batcher to add online/offline notifications +// that would normally be sent by poll.go in production. +type testBatcherWrapper struct { + Batcher + state *state.State +} + +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + // Mark node as online in state before AddNode to match production behavior + // This ensures the NodeStore has correct online status for change processing + if t.state != nil { + // Use Connect to properly mark node online in NodeStore but don't send its changes + _ = t.state.Connect(id) + } + + // First add the node to the real batcher + err := t.Batcher.AddNode(id, c, version) + if err != nil { + return err + } + + // Send the online notification that poll.go would normally send + // This ensures other nodes get notified about this node coming online + t.AddWork(change.NodeOnline(id)) + + return nil +} + +func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + // Mark node as offline in state BEFORE removing from batcher + // This ensures the NodeStore has correct offline status when the change is processed + if t.state != nil { + // Use Disconnect to properly mark node offline in NodeStore but don't send its changes + _, _ = t.state.Disconnect(id) + } + + // Send the offline notification that poll.go would normally send + // Do this BEFORE removing from batcher so the change can be processed + t.AddWork(change.NodeOffline(id)) + + // Finally remove from the real batcher + removed := t.Batcher.RemoveNode(id, c) + if !removed { + return false + } + + return true +} + +// wrapBatcherForTest wraps a batcher with test-specific behavior. +func wrapBatcherForTest(b Batcher, state *state.State) Batcher { + return &testBatcherWrapper{Batcher: b, state: state} +} + // allBatcherFunctions contains all batcher implementations to test. var allBatcherFunctions = []batcherTestCase{ {"LockFree", NewBatcherAndMapper}, @@ -183,8 +237,8 @@ func setupBatcherWithTestData( "acls": [ { "action": "accept", - "users": ["*"], - "ports": ["*:*"] + "src": ["*"], + "dst": ["*:*"] } ] }` @@ -194,8 +248,8 @@ func setupBatcherWithTestData( t.Fatalf("Failed to set allow-all policy: %v", err) } - // Create batcher with the state - batcher := bf(cfg, state) + // Create batcher with the state and wrap it for testing + batcher := wrapBatcherForTest(bf(cfg, state), state) batcher.Start() testData := &TestData{ @@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) time.Sleep(100 * time.Millisecond) // Let connection settle // Generate some work @@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) for i := range allNodes { node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullSet) @@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Disconnect all nodes for i := range allNodes { node := &allNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for final updates to process @@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, false, 100) + batcher.AddNode(tn.n.ID, tn.ch, 100) + if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") } @@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) // Add the second node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, false, 100) + batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. select { case data := <-tn.ch: assertOnlineMapResponse(t, data, true) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) { } // Disconnect the second node - batcher.RemoveNode(tn2.n.ID, tn2.ch, false) - assert.False(t, batcher.IsConnected(tn2.n.ID)) + batcher.RemoveNode(tn2.n.ID, tn2.ch) + // Note: IsConnected may return true during grace period for DNS resolution // First node should get update that second has disconnected. select { case data := <-tn.ch: assertOnlineMapResponse(t, data, false) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) { // } // Test RemoveNode - batcher.RemoveNode(tn.n.ID, tn.ch, false) - if batcher.IsConnected(tn.n.ID) { - t.Error("Node should be disconnected after RemoveNode") - } + batcher.RemoveNode(tn.n.ID, tn.ch) + // Note: IsConnected may return true during grace period for DNS resolution + // The node is actually removed from active connections but grace period allows DNS lookups }) } } @@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100)) + + batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }() // Add real work during connection chaos @@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(1 * time.Microsecond) - batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }() // Remove second connection @@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(2 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch2, false) + batcher.RemoveNode(testNode.n.ID, ch2) }() wg.Wait() @@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPSet) // Consumer goroutine to validate data and detect channel issues @@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Rapid removal creates race between worker and removal time.Sleep(time.Duration(i%3) * 100 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch, false) + batcher.RemoveNode(testNode.n.ID, ch) // Give workers time to process and close channels time.Sleep(5 * time.Millisecond) @@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for _, node := range stableNodes { ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Lock() churningChannels[nodeID] = ch churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) { ch, exists := churningChannels[nodeID] churningChannelsMutex.Unlock() if exists { - batcher.RemoveNode(nodeID, ch, false) + batcher.RemoveNode(nodeID, ch) } }(node.n.ID) } @@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) { var connectedNodesMutex sync.RWMutex for i := range testNodes { node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true connectedNodesMutex.Unlock() @@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) { connectedNodesMutex.RUnlock() if isConnected { - batcher.RemoveNode(nodeID, channel, false) + batcher.RemoveNode(nodeID, channel) connectedNodesMutex.Lock() connectedNodes[nodeID] = false connectedNodesMutex.Unlock() @@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) { // Now disconnect all nodes from batcher to stop new updates for i := range testNodes { node := &testNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for enhanced tracking goroutines to process any remaining data in channels @@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time to avoid overwhelming the work queue for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Small delay between connections to allow NodeCameOnline processing time.Sleep(50 * time.Millisecond) @@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Check how many peers each node should see for i, node := range allNodes { - peers, err := testData.State.ListPeers(node.n.ID) - if err != nil { - t.Errorf("Error listing peers for node %d: %v", i, err) - } else { - t.Logf("Node %d should see %d peers from state", i, len(peers)) - } + peers := testData.State.ListPeers(node.n.ID) + t.Logf("Node %d should see %d peers from state", i, peers.Len()) } // Send a full update - this should generate full peer lists @@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { foundFullUpdate := false // Read all available updates for each node - for i := range len(allNodes) { + for i := range allNodes { nodeUpdates := 0 t.Logf("Reading updates for node %d:", i) @@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) { t.Logf("=== WORK QUEUE TRACING TEST ===") - // Connect first node - batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100)) - t.Logf("Connected node %d", nodes[0].n.ID) + time.Sleep(100 * time.Millisecond) // Let connections settle // Wait for initial NodeCameOnline to be processed time.Sleep(200 * time.Millisecond) @@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) { t.Errorf("ERROR: Received unknown update type!") } - // Check if there should be peers available - peers, err := testData.State.ListPeers(nodes[0].n.ID) - if err != nil { - t.Errorf("Error getting peers from state: %v", err) - } else { - t.Logf("State shows %d peers available for this node", len(peers)) - if len(peers) > 0 && len(data.Peers) == 0 { - t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers)) + batcher := testData.Batcher + node1 := testData.Nodes[0] + node2 := testData.Nodes[1] + + t.Logf("=== MULTI-CONNECTION TEST ===") + + // Phase 1: Connect first node with initial connection + t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node1: %v", err) + } + + // Connect second node for comparison + err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node2: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 2: Add second connection for node1 (multi-connection scenario) + t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add second connection for node1: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 3: Add third connection for node1 + t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add third connection for node1: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 4: Verify debug status shows correct connection count + t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + + if info, exists := debugInfo[node1.n.ID]; exists { + t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 3 { + t.Errorf("Node1 should have 3 active connections, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 3 active connections") + } + } + if connected, ok := infoMap["connected"].(bool); ok && !connected { + t.Errorf("Node1 should show as connected with 3 active connections") + } + } + } + + if info, exists := debugInfo[node2.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 1 { + t.Errorf("Node2 should have 1 active connection, got %d", activeConnections) + } + } + } + } + } + + // Phase 5: Send update and verify ALL connections receive it + t.Logf("Phase 5: Testing update distribution to all connections...") + + // Clear any existing updates from all channels + clearChannel := func(ch chan *tailcfg.MapResponse) { + for { + select { + case <-ch: + // drain + default: + return + } + } + } + + clearChannel(node1.ch) + clearChannel(secondChannel) + clearChannel(thirdChannel) + clearChannel(node2.ch) + + // Send a change notification from node2 (so node1 should receive it on all connections) + testChangeSet := change.ChangeSet{ + NodeID: node2.n.ID, + Change: change.NodeNewOrUpdate, + SelfUpdateOnly: false, + } + + batcher.AddWork(testChangeSet) + + time.Sleep(100 * time.Millisecond) // Let updates propagate + + // Verify all three connections for node1 receive the update + connection1Received := false + connection2Received := false + connection3Received := false + + select { + case mapResp := <-node1.ch: + connection1Received = (mapResp != nil) + t.Logf("Node1 connection 1 received update: %t", connection1Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 1 did not receive update") + } + + select { + case mapResp := <-secondChannel: + connection2Received = (mapResp != nil) + t.Logf("Node1 connection 2 received update: %t", connection2Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 2 did not receive update") + } + + select { + case mapResp := <-thirdChannel: + connection3Received = (mapResp != nil) + t.Logf("Node1 connection 3 received update: %t", connection3Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 3 did not receive update") + } + + if connection1Received && connection2Received && connection3Received { + t.Logf("SUCCESS: All three connections for node1 received the update") + } else { + t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t", + connection1Received, connection2Received, connection3Received) + } + + // Phase 6: Test connection removal and verify remaining connections still work + t.Logf("Phase 6: Testing connection removal...") + + // Remove the second connection + removed := batcher.RemoveNode(node1.n.ID, secondChannel) + if !removed { + t.Errorf("Failed to remove second connection for node1") + } + + time.Sleep(50 * time.Millisecond) + + // Verify debug status shows 2 connections now + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + if info, exists := debugInfo[node1.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 2 { + t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal") + } } } } else { diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index dfe9d68d..dc43b933 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "net/netip" "sort" "time" @@ -12,7 +13,7 @@ import ( "tailscale.com/util/multierr" ) -// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse +// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse. type MapResponseBuilder struct { resp *tailcfg.MapResponse mapper *mapper @@ -21,7 +22,17 @@ type MapResponseBuilder struct { errs []error } -// NewMapResponseBuilder creates a new builder with basic fields set +type debugType string + +const ( + fullResponseDebug debugType = "full" + patchResponseDebug debugType = "patch" + removeResponseDebug debugType = "remove" + changeResponseDebug debugType = "change" + derpResponseDebug debugType = "derp" +) + +// NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() return &MapResponseBuilder{ @@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder } } -// addError adds an error to the builder's error list +// addError adds an error to the builder's error list. func (b *MapResponseBuilder) addError(err error) { if err != nil { b.errs = append(b.errs, err) } } -// hasErrors returns true if the builder has accumulated any errors +// hasErrors returns true if the builder has accumulated any errors. func (b *MapResponseBuilder) hasErrors() bool { return len(b.errs) > 0 } -// WithCapabilityVersion sets the capability version for the response +// WithCapabilityVersion sets the capability version for the response. func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { b.capVer = capVer return b } -// WithSelfNode adds the requesting node to the response +// WithSelfNode adds the requesting node to the response. func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } + // Always use batcher's view of online status for self node + // The batcher respects grace periods for logout scenarios + node := nodeView.AsStruct() + // if b.mapper.batcher != nil { + // node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID)) + // } + _, matchers := b.mapper.state.Filter() tailnode, err := tailNode( node.View(), b.capVer, b.mapper.state, @@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { } b.resp.Node = tailnode + return b } -// WithDERPMap adds the DERP map to the response +func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder { + if debugDumpMapResponsePath != "" { + b.debugType = t + } + + return b +} + +// WithDERPMap adds the DERP map to the response. func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct() return b } -// WithDomain adds the domain configuration +// WithDomain adds the domain configuration. func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { b.resp.Domain = b.mapper.cfg.Domain() return b } -// WithCollectServicesDisabled sets the collect services flag to false +// WithCollectServicesDisabled sets the collect services flag to false. func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { b.resp.CollectServices.Set(false) return b } // WithDebugConfig adds debug configuration -// It disables log tailing if the mapper's LogTail is not enabled +// It disables log tailing if the mapper's LogTail is not enabled. func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, @@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { return b } -// WithSSHPolicy adds SSH policy configuration for the requesting node +// WithSSHPolicy adds SSH policy configuration for the requesting node. func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } - sshPolicy, err := b.mapper.state.SSHPolicy(node.View()) + sshPolicy, err := b.mapper.state.SSHPolicy(node) if err != nil { b.addError(err) return b } b.resp.SSHPolicy = sshPolicy + return b } -// WithDNSConfig adds DNS configuration for the requesting node +// WithDNSConfig adds DNS configuration for the requesting node. func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) + return b } -// WithUserProfiles adds user profiles for the requesting node and given peers -func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) +// WithUserProfiles adds user profiles for the requesting node and given peers. +func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.UserProfiles = generateUserProfiles(node, peers) + return b } -// WithPacketFilters adds packet filter rules based on policy +// WithPacketFilters adds packet filter rules based on policy. func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } @@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node.View(), filter), + "base": policy.ReduceFilterRules(node, filter), } return b } -// WithPeers adds full peer list with policy filtering (for full map response) -func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { - +// WithPeers adds full peer list with policy filtering (for full map response). +func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder { tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { } b.resp.Peers = tailPeers + return b } -// WithPeerChanges adds changed peers with policy filtering (for incremental updates) -func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder { - +// WithPeerChanges adds changed peers with policy filtering (for incremental updates). +func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder { tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil } b.resp.PeersChanged = tailPeers + return b } -// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting -func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - return nil, err +// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting. +func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + return nil, errors.New("node not found") } filter, matchers := b.mapper.state.Filter() @@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, // access each-other at all and remove them from the peers. var changedViews views.Slice[types.NodeView] if len(filter) > 0 { - changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers) + changedViews = policy.ReduceNodes(node, peers, matchers) } else { - changedViews = peers.ViewSlice() + changedViews = peers } tailPeers, err := tailNodes( changedViews, b.capVer, b.mapper.state, func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) + return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers) }, b.mapper.cfg) if err != nil { @@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, return tailPeers, nil } -// WithPeerChangedPatch adds peer change patches +// WithPeerChangedPatch adds peer change patches. func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { b.resp.PeersChangedPatch = changes return b } -// WithPeersRemoved adds removed peer IDs +// WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } b.resp.PeersRemoved = tailscaleIDs + return b } @@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { return nil, multierr.New(b.errs...) } if debugDumpMapResponsePath != "" { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - return nil, err - } - writeDebugMapResponse(b.resp, node) + writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } return b.resp, nil diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 59c92e24..bb8340d0 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -19,6 +19,7 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/views" ) const ( @@ -69,16 +70,18 @@ func newMapper( } func generateUserProfiles( - node *types.Node, - peers types.Nodes, + node types.NodeView, + peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { userMap := make(map[uint]*types.User) ids := make([]uint, 0, len(userMap)) - userMap[node.User.ID] = &node.User - ids = append(ids, node.User.ID) - for _, peer := range peers { - userMap[peer.User.ID] = &peer.User - ids = append(ids, peer.User.ID) + user := node.User() + userMap[user.ID] = &user + ids = append(ids, user.ID) + for _, peer := range peers.All() { + peerUser := peer.User() + userMap[peerUser.ID] = &peerUser + ids = append(ids, peerUser.ID) } slices.Sort(ids) @@ -95,7 +98,7 @@ func generateUserProfiles( func generateDNSConfig( cfg *types.Config, - node *types.Node, + node types.NodeView, ) *tailcfg.DNSConfig { if cfg.TailcfgDNSConfig == nil { return nil @@ -115,12 +118,12 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ - "device_name": []string{node.Hostname}, - "device_model": []string{node.Hostinfo.OS}, + "device_name": []string{node.Hostname()}, + "device_model": []string{node.Hostinfo().OS()}, } if len(node.IPs()) > 0 { @@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse( capVer tailcfg.CapabilityVersion, messages ...string, ) (*tailcfg.MapResponse, error) { - peers, err := m.listPeers(nodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). @@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse( capVer tailcfg.CapabilityVersion, changedNodeID types.NodeID, ) (*tailcfg.MapResponse, error) { - peers, err := m.listPeers(nodeID, changedNodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID, changedNodeID) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). @@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse( func writeDebugMapResponse( resp *tailcfg.MapResponse, - node *types.Node, + t debugType, + nodeID types.NodeID, ) { body, err := json.MarshalIndent(resp, "", " ") if err != nil { @@ -236,25 +234,6 @@ func writeDebugMapResponse( } } -// listPeers returns peers of node, regardless of any Policy or if the node is expired. -// If no peer IDs are given, all peers are returned. -// If at least one peer ID is given, only these peer nodes will be returned. -func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - peers, err := m.state.ListPeers(nodeID, peerIDs...) - if err != nil { - return nil, err - } - - // TODO(kradalby): Add back online via batcher. This was removed - // to avoid a circular dependency between the mapper and the notification. - for _, peer := range peers { - online := m.batcher.IsConnected(peer.ID) - peer.IsOnline = &online - } - - return peers, nil -} - // routeFilterFunc is a function that takes a node ID and returns a list of // netip.Prefixes that are allowed for that node. It is used to filter routes // from the primary route manager to the node. diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 198ba6c4..b801f7dd 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) { &types.Config{ TailcfgDNSConfig: &dnsConfigOrig, }, - nodeInShared1, + nodeInShared1.View(), ) if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9729301d..3a518d94 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -133,13 +133,12 @@ func tailNode( tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } - if !node.IsOnline().Valid() || !node.IsOnline().Get() { - // LastSeen is only set when node is - // not connected to the control server. - if node.LastSeen().Valid() { - lastSeen := node.LastSeen().Get() - tNode.LastSeen = &lastSeen - } + // Set LastSeen only for offline nodes to avoid confusing Tailscale clients + // during rapid reconnection cycles. Online nodes should not have LastSeen set + // as this can make clients interpret them as "not online" despite Online=true. + if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() { + lastSeen := node.LastSeen().Get() + tNode.LastSeen = &lastSeen } return &tNode, nil diff --git a/hscontrol/noise.go b/hscontrol/noise.go index db39992e..bb59fea6 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -13,7 +13,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" - "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" @@ -296,16 +295,11 @@ func (ns *noiseServer) NoiseRegistrationHandler( // getAndValidateNode retrieves the node from the database using the NodeKey // and validates that it matches the MachineKey from the Noise session. func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { - node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) - } - return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil) + nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) + if !ok { + return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) } - nv := node.View() - // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. if ns.machineKey != nv.MachineKey() { return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 68361cae..021a6272 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( util.LogErr(err, "could not get userinfo; only using claims from id token") } - // The user claims are now updated from the the userinfo endpoint so we can verify the user a + // The user claims are now updated from the userinfo endpoint so we can verify the user // against allowed emails, email domains, and groups. if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { httpError(writer, err) @@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims) + user, c, err := a.createOrUpdateUserFromClaim(&claims) if err != nil { log.Error(). Err(err). @@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } // Send policy update notifications if needed - if policyChanged { - a.h.Change(change.PolicyChange()) - } + a.h.Change(c) // TODO(kradalby): Is this comment right? // If the node exists, then the node should be reauthenticated, @@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( claims *types.OIDCClaims, -) (*types.User, bool, error) { +) (*types.User, change.ChangeSet, error) { var user *types.User var err error var newUser bool - var policyChanged bool + var c change.ChangeSet user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, false, fmt.Errorf("creating or updating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err) } // if the user is still not found, create a new empty user. @@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( user.FromClaim(claims) if newUser { - user, policyChanged, err = a.h.state.CreateUser(*user) + user, c, err = a.h.state.CreateUser(*user) if err != nil { - return nil, false, fmt.Errorf("creating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) } } else { - _, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { + _, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { *u = *user return nil }) if err != nil { - return nil, false, fmt.Errorf("updating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("updating user: %w", err) } } - return user, policyChanged, nil + return user, c, nil } func (a *AuthProviderOIDC) handleRegistration( diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 52457c9b..6a74e59f 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "github.com/samber/lo" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -138,39 +139,74 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf return ret } -// AutoApproveRoutes approves any route that can be autoapproved from -// the nodes perspective according to the given policy. -// It reports true if any routes were approved. -// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes. -func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { +// ApproveRoutesWithPolicy checks if the node can approve the announced routes +// and returns the new list of approved routes. +// The approved routes will include: +// 1. ALL previously approved routes (regardless of whether they're still advertised) +// 2. New routes from announcedRoutes that can be auto-approved by policy +// This ensures that: +// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes) +// - New routes can be auto-approved according to policy +// - Routes can only be removed by explicit admin action (not by auto-approval). +func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) { if pm == nil { - return false + return currentApproved, false } - nodeView := node.View() - var newApproved []netip.Prefix - for _, route := range nodeView.AnnouncedRoutes() { - if pm.NodeCanApproveRoute(nodeView, route) { + + // Start with ALL currently approved routes - we never remove approved routes + newApproved := make([]netip.Prefix, len(currentApproved)) + copy(newApproved, currentApproved) + + // Then, check for new routes that can be auto-approved + for _, route := range announcedRoutes { + // Skip if already approved + if slices.Contains(newApproved, route) { + continue + } + + // Check if this new route can be auto-approved by policy + canApprove := pm.NodeCanApproveRoute(nv, route) + if canApprove { newApproved = append(newApproved, route) } } - // Only modify ApprovedRoutes if we have new routes to approve. - // This prevents clearing existing approved routes when nodes - // temporarily don't have announced routes during policy changes. - if len(newApproved) > 0 { - combined := append(newApproved, node.ApprovedRoutes...) - tsaddr.SortPrefixes(combined) - combined = slices.Compact(combined) - combined = lo.Filter(combined, func(route netip.Prefix, index int) bool { - return route.IsValid() - }) + // Sort and deduplicate + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) + newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { + return route.IsValid() + }) - // Only update if the routes actually changed - if !slices.Equal(node.ApprovedRoutes, combined) { - node.ApprovedRoutes = combined - return true + // Sort the current approved for comparison + sortedCurrent := make([]netip.Prefix, len(currentApproved)) + copy(sortedCurrent, currentApproved) + tsaddr.SortPrefixes(sortedCurrent) + + // Only update if the routes actually changed + if !slices.Equal(sortedCurrent, newApproved) { + // Log what changed + var added, kept []netip.Prefix + for _, route := range newApproved { + if !slices.Contains(sortedCurrent, route) { + added = append(added, route) + } else { + kept = append(kept, route) + } } + + if len(added) > 0 { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.added", util.PrefixesToString(added)). + Strs("routes.kept", util.PrefixesToString(kept)). + Int("routes.total", len(newApproved)). + Msg("Routes auto-approved by policy") + } + + return newApproved, true } - return false + return newApproved, false } diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go new file mode 100644 index 00000000..6c0908b9 --- /dev/null +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -0,0 +1,339 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/types/key" + "tailscale.com/types/ptr" + "tailscale.com/types/views" +) + +func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { + user1 := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser@", + } + user2 := types.User{ + Model: gorm.Model{ID: 2}, + Name: "otheruser@", + } + users := []types.User{user1, user2} + + node1 := &types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test-node", + UserID: user1.ID, + User: user1, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ForcedTags: []string{"tag:test"}, + } + + node2 := &types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "other-node", + UserID: user2.ID, + User: user2, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } + + // Create a policy that auto-approves specific routes + policyJSON := `{ + "groups": { + "group:test": ["testuser@"] + }, + "tagOwners": { + "tag:test": ["testuser@"] + }, + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + } + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["testuser@", "tag:test"], + "10.1.0.0/24": ["testuser@"], + "10.2.0.0/24": ["testuser@"], + "192.168.0.0/24": ["tag:test"] + } + } + }` + + pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) + assert.NoError(t, err) + + tests := []struct { + name string + node *types.Node + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + description string + }{ + { + name: "previously_approved_route_no_longer_advertised_should_remain", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Should still be here! + }, + wantChanged: false, + description: "Previously approved routes should never be removed even when no longer advertised", + }, + { + name: "add_new_auto_approved_route_keeps_old_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8) + netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept + }, + wantChanged: true, + description: "New auto-approved routes should be added while keeping old approved routes", + }, + { + name: "no_announced_routes_keeps_all_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + description: "All approved routes should remain when no routes are announced", + }, + { + name: "no_changes_when_announced_equals_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + description: "No changes should occur when announced routes match approved routes", + }, + { + name: "auto_approve_multiple_new_routes", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) + netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved + netip.MustParsePrefix("172.16.0.0/24"), // Original kept + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + }, + wantChanged: true, + description: "Multiple new routes should be auto-approved while keeping existing approved routes", + }, + { + name: "node_without_permission_no_auto_approval", + node: node2, // Different node without the tag + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route + }, + wantChanged: false, + description: "Routes should not be auto-approved for nodes without proper permissions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) + + // Sort for comparison since ApproveRoutesWithPolicy sorts the results + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) + + // Verify that all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should never happen", prevRoute) + } + }) + } +} + +func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { + // Create a basic policy for edge case testing + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.1.0.0/24": ["test@"], + }, + }, +}` + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_policy_manager", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "nil_announced_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: nil, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "duplicate_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_slices", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{}, + wantApproved: []netip.Prefix{}, + wantChanged: false, + }, + } + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + // Create policy manager or use nil if specified + var pm PolicyManager + var err error + if tt.name != "nil_policy_manager" { + pm, err = pmf(users, nodes.ViewSlice()) + assert.NoError(t, err) + } else { + pm = nil + } + + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch") + + // Handle nil vs empty slice comparison + if tt.wantApproved == nil { + assert.Nil(t, gotApproved, "expected nil approved routes") + } else { + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") + } + }) + } + } +} diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go new file mode 100644 index 00000000..610ce7b1 --- /dev/null +++ b/hscontrol/policy/policy_route_approval_test.go @@ -0,0 +1,361 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { + // Test policy that allows specific routes to be auto-approved + aclPolicy := ` +{ + "groups": { + "group:admins": ["test@"], + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/24": ["test@"], + "192.168.0.0/24": ["group:admins"], + "172.16.0.0/16": ["tag:approved"], + }, + }, + "tagOwners": { + "tag:approved": ["test@"], + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + nodeHostname string + nodeUser string + nodeTags []string + wantApproved []netip.Prefix + wantChanged bool + wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result + }{ + { + name: "previously_approved_route_no_longer_advertised_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Should remain! + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed + }, + { + name: "add_new_auto_approved_route_keeps_existing", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Still advertised + netip.MustParsePrefix("192.168.0.0/24"), // New route + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group + }, + wantChanged: true, + }, + { + name: "no_announced_routes_keeps_all_approved", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced anymore + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + }, + { + name: "manually_approved_route_not_in_policy_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved + }, + wantChanged: true, + }, + { + name: "tagged_node_gets_tag_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route + }, + nodeUser: "test", + nodeTags: []string{"tag:approved"}, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved + netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + }, + wantChanged: true, + }, + { + name: "complex_scenario_multiple_changes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised + netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable + netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) + netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised + }, + wantChanged: true, + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: tt.nodeUser, + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: tt.nodeHostname, + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + ForcedTags: tt.nodeTags, + } + nodes := types.Nodes{&node} + + // Create policy manager + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, pm) + + // Test ApproveRoutesWithPolicy + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + // Check change flag + assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch") + + // Check approved routes match expected + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Logf("Want: %v", tt.wantApproved) + t.Logf("Got: %v", gotApproved) + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + + // Verify all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should NEVER happen", prevRoute) + } + + // Verify no routes were incorrectly removed + for _, removedRoute := range tt.wantRemovedRoutes { + assert.NotContains(t, gotApproved, removedRoute, + "route %s should have been removed but wasn't", removedRoute) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["test@"], + }, + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_current_approved", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "duplicate_routes_handled", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, // Duplicates are removed, so it's a change + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + assert.Equal(t, tt.wantChanged, gotChanged) + + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + + currentApproved := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + } + announcedRoutes := []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + } + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: currentApproved, + } + + // With nil policy manager, should return current approved unchanged + gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes) + + assert.False(t, gotChanged) + assert.Equal(t, currentApproved, gotApproved) +} diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 5e332fd3..1e6fabf3 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, canApprove: false, }, + { + name: "policy-without-autoApprovers-section", + node: normalNode, + route: p("10.33.0.0/16"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admin"], + "dst": ["group:admin:*"] + }, + { + "action": "accept", + "src": ["group:admin"], + "dst": ["10.33.0.0/16:*"] + } + ] + }`, + canApprove: false, + }, } for _, tt := range tests { diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index de839770..5e7aa34b 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // The fast path is that a node requests to approve a prefix // where there is an exact entry, e.g. 10.0.0.0/8, then // check and return quickly - if _, ok := pm.autoApproveMap[route]; ok { - if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) { + if approvers, ok := pm.autoApproveMap[route]; ok { + canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains) + if canApprove { return true } } @@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // Check if prefix is larger (so containing) and then overlaps // the route to see if the node can approve a subset of an autoapprover if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { - if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) { + canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains) + if canApprove { return true } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 1833f060..4809257b 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -10,7 +10,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" @@ -112,6 +111,15 @@ func (m *mapSession) serve() { // This is the mechanism where the node gives us information about its // current configuration. // + // Process the MapRequest to update node state (endpoints, hostinfo, etc.) + c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + httpError(m.w, err) + return + } + + m.h.Change(c) + // If OmitPeers is true and Stream is false // then the server will let clients update their endpoints without // breaking existing long-polling (Stream == true) connections. @@ -122,14 +130,6 @@ func (m *mapSession) serve() { // the response and just wants a 200. // !req.stream && req.OmitPeers if m.isEndpointUpdate() { - c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req) - if err != nil { - httpError(m.w, err) - return - } - - m.h.Change(c) - m.w.WriteHeader(http.StatusOK) mapResponseEndpointUpdates.WithLabelValues("ok").Inc() } @@ -142,6 +142,8 @@ func (m *mapSession) serve() { func (m *mapSession) serveLongPoll() { m.beforeServeLongPoll() + log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected") + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -149,18 +151,38 @@ func (m *mapSession) serveLongPoll() { close(m.cancelCh) m.cancelChMu.Unlock() - // TODO(kradalby): This can likely be made more effective, but likely most - // nodes has access to the same routes, so it might not be a big deal. - disconnectChange, err := m.h.state.Disconnect(m.node) - if err != nil { - m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) + + // When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather). + // Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects. + // If it does reconnect, the existing mapSession will be replaced and the node remains online. + // If it doesn't reconnect within the timeout, we mark it as offline. + // + // This avoids flapping nodes in the UI and unnecessary churn in the network. + // This is not my favourite solution, but it kind of works in our eventually consistent world. + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + disconnected := true + // Wait up to 10 seconds for the node to reconnect. + // 10 seconds was arbitrary chosen as a reasonable time to reconnect. + for range 10 { + if m.h.mapBatcher.IsConnected(m.node.ID) { + disconnected = false + break + } + <-ticker.C } - m.h.Change(disconnectChange) - m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter()) + if disconnected { + disconnectChanges, err := m.h.state.Disconnect(m.node.ID) + if err != nil { + m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + } - m.afterServeLongPoll() - m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) + m.h.Change(disconnectChanges...) + m.afterServeLongPoll() + m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) + } }() // Set up the client stream @@ -172,25 +194,25 @@ func (m *mapSession) serveLongPoll() { m.keepAliveTicker = time.NewTicker(m.keepAlive) - // Add node to batcher BEFORE sending Connect change to prevent race condition - // where the change is sent before the node is in the batcher's node map - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil { - m.errf(err, "failed to add node to batcher") - // Send empty response to client to fail fast for invalid/non-existent nodes - select { - case m.ch <- &tailcfg.MapResponse{}: - default: - // Channel might be closed - } + // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.) + // CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly + // synchronized. When nodes reconnect, they send their hostinfo with announced routes + // in the MapRequest. We need this data in NodeStore before Connect() sets up the + // primary routes, otherwise SubnetRoutes() returns empty and the node is removed + // from AvailableRoutes. + mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + m.errf(err, "failed to update node from initial MapRequest") return } - // Now send the Connect change - the batcher handles NodeCameOnline internally - // but we still need to update routes and other state-level changes - connectChange := m.h.state.Connect(m.node) - if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline { - m.h.Change(connectChange) - } + // Connect the node after its state has been updated. + // We send two separate change notifications because these are distinct operations: + // 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo) + // 2. Connect: marks the node online and recalculates primary routes based on the updated state + // While this results in two notifications, it ensures route data is synchronized before + // primary route selection occurs, which is critical for proper HA subnet router failover. + connectChanges := m.h.state.Connect(m.node.ID) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) @@ -235,6 +257,7 @@ func (m *mapSession) serveLongPoll() { mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) } mapResponseSent.WithLabelValues("ok", "keepalive").Inc() + m.resetKeepAlive() } } } diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go new file mode 100644 index 00000000..3fd50d26 --- /dev/null +++ b/hscontrol/state/node_store.go @@ -0,0 +1,403 @@ +package state + +import ( + "fmt" + "maps" + "strings" + "sync/atomic" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +const ( + batchSize = 10 + batchTimeout = 500 * time.Millisecond +) + +const ( + put = 1 + del = 2 + update = 3 +) + +const prometheusNamespace = "headscale" + +var ( + nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operations_total", + Help: "Total number of NodeStore operations", + }, []string{"operation"}) + nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operation_duration_seconds", + Help: "Duration of NodeStore operations", + Buckets: prometheus.DefBuckets, + }, []string{"operation"}) + nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_size", + Help: "Size of NodeStore write batches", + Buckets: []float64{1, 2, 5, 10, 20, 50, 100}, + }) + nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_duration_seconds", + Help: "Duration of NodeStore batch processing", + Buckets: prometheus.DefBuckets, + }) + nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_snapshot_build_duration_seconds", + Help: "Duration of NodeStore snapshot building from nodes", + Buckets: prometheus.DefBuckets, + }) + nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_nodes_total", + Help: "Total number of nodes in the NodeStore", + }) + nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_peers_calculation_duration_seconds", + Help: "Duration of peers calculation in NodeStore", + Buckets: prometheus.DefBuckets, + }) + nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_queue_depth", + Help: "Current depth of NodeStore write queue", + }) +) + +// NodeStore is a thread-safe store for nodes. +// It is a copy-on-write structure, replacing the "snapshot" +// when a change to the structure occurs. It is optimised for reads, +// and while batches are not fast, they are grouped together +// to do less of the expensive peer calculation if there are many +// changes rapidly. +// +// Writes will block until committed, while reads are never +// blocked. This means that the caller of a write operation +// is responsible for ensuring an update depending on a write +// is not issued before the write is complete. +type NodeStore struct { + data atomic.Pointer[Snapshot] + + peersFunc PeersFunc + writeQueue chan work +} + +func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore { + nodes := make(map[types.NodeID]types.Node, len(allNodes)) + for _, n := range allNodes { + nodes[n.ID] = *n + } + snap := snapshotFromNodes(nodes, peersFunc) + + store := &NodeStore{ + peersFunc: peersFunc, + } + store.data.Store(&snap) + + // Initialize node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + return store +} + +// Snapshot is the representation of the current state of the NodeStore. +// It contains all nodes and their relationships. +// It is a copy-on-write structure, meaning that when a write occurs, +// a new Snapshot is created with the updated state, +// and replaces the old one atomically. +type Snapshot struct { + // nodesByID is the main source of truth for nodes. + nodesByID map[types.NodeID]types.Node + + // calculated from nodesByID + nodesByNodeKey map[key.NodePublic]types.NodeView + peersByNode map[types.NodeID][]types.NodeView + nodesByUser map[types.UserID][]types.NodeView + allNodes []types.NodeView +} + +// PeersFunc is a function that takes a list of nodes and returns a map +// with the relationships between nodes and their peers. +// This will typically be used to calculate which nodes can see each other +// based on the current policy. +type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView + +// work represents a single operation to be performed on the NodeStore. +type work struct { + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} +} + +// PutNode adds or updates a node in the store. +// If the node already exists, it will be replaced. +// If the node does not exist, it will be added. +// This is a blocking operation that waits for the write to complete. +func (s *NodeStore) PutNode(n types.Node) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) + defer timer.ObserveDuration() + + work := work{ + op: put, + nodeID: n.ID, + node: n, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("put").Inc() +} + +// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. +type UpdateNodeFunc func(n *types.Node) + +// UpdateNode applies a function to modify a specific node in the store. +// This is a blocking operation that waits for the write to complete. +// This is analogous to a database "transaction", or, the caller should +// rather collect all data they want to change, and then call this function. +// Fewer calls are better. +// +// TODO(kradalby): Technically we could have a version of this that modifies the node +// in the current snapshot if _we know_ that the change will not affect the peer relationships. +// This is because the main nodesByID map contains the struct, and every other map is using a +// pointer to the underlying struct. The gotcha with this is that we will need to introduce +// a lock around the nodesByID map to ensure that no other writes are happening +// while we are modifying the node. Which mean we would need to implement read-write locks +// on all read operations. +func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) + defer timer.ObserveDuration() + + work := work{ + op: update, + nodeID: nodeID, + updateFn: updateFn, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("update").Inc() +} + +// DeleteNode removes a node from the store by its ID. +// This is a blocking operation that waits for the write to complete. +func (s *NodeStore) DeleteNode(id types.NodeID) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete")) + defer timer.ObserveDuration() + + work := work{ + op: del, + nodeID: id, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("delete").Inc() +} + +// Start initializes the NodeStore and starts processing the write queue. +func (s *NodeStore) Start() { + s.writeQueue = make(chan work) + go s.processWrite() +} + +// Stop stops the NodeStore. +func (s *NodeStore) Stop() { + close(s.writeQueue) +} + +// processWrite processes the write queue in batches. +func (s *NodeStore) processWrite() { + c := time.NewTicker(batchTimeout) + defer c.Stop() + batch := make([]work, 0, batchSize) + + for { + select { + case w, ok := <-s.writeQueue: + if !ok { + // Channel closed, apply any remaining batch and exit + if len(batch) != 0 { + s.applyBatch(batch) + } + return + } + batch = append(batch, w) + if len(batch) >= batchSize { + s.applyBatch(batch) + batch = batch[:0] + c.Reset(batchTimeout) + } + case <-c.C: + if len(batch) != 0 { + s.applyBatch(batch) + batch = batch[:0] + } + c.Reset(batchTimeout) + } + } +} + +// applyBatch applies a batch of work to the node store. +// This means that it takes a copy of the current nodes, +// then applies the batch of operations to that copy, +// runs any precomputation needed (like calculating peers), +// and finally replaces the snapshot in the store with the new one. +// The replacement of the snapshot is atomic, ensuring that reads +// are never blocked by writes. +// Each write item is blocked until the batch is applied to ensure +// the caller knows the operation is complete and do not send any +// updates that are dependent on a read that is yet to be written. +func (s *NodeStore) applyBatch(batch []work) { + timer := prometheus.NewTimer(nodeStoreBatchDuration) + defer timer.ObserveDuration() + + nodeStoreBatchSize.Observe(float64(len(batch))) + + nodes := make(map[types.NodeID]types.Node) + maps.Copy(nodes, s.data.Load().nodesByID) + + for _, w := range batch { + switch w.op { + case put: + nodes[w.nodeID] = w.node + case update: + // Update the specific node identified by nodeID + if n, exists := nodes[w.nodeID]; exists { + w.updateFn(&n) + nodes[w.nodeID] = n + } + case del: + delete(nodes, w.nodeID) + } + } + + newSnap := snapshotFromNodes(nodes, s.peersFunc) + s.data.Store(&newSnap) + + // Update node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + for _, w := range batch { + close(w.result) + } +} + +// snapshotFromNodes creates a new Snapshot from the provided nodes. +// It builds a lot of "indexes" to make lookups fast for datasets we +// that is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser. +// This is not a fast operation, it is the "slow" part of our copy-on-write +// structure, but it allows us to have fast reads and efficient lookups. +func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot { + timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration) + defer timer.ObserveDuration() + + allNodes := make([]types.NodeView, 0, len(nodes)) + for _, n := range nodes { + allNodes = append(allNodes, n.View()) + } + + newSnap := Snapshot{ + nodesByID: nodes, + allNodes: allNodes, + nodesByNodeKey: make(map[key.NodePublic]types.NodeView), + + // peersByNode is most likely the most expensive operation, + // it will use the list of all nodes, combined with the + // current policy to precalculate which nodes are peers and + // can see each other. + peersByNode: func() map[types.NodeID][]types.NodeView { + peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) + defer peersTimer.ObserveDuration() + return peersFunc(allNodes) + }(), + nodesByUser: make(map[types.UserID][]types.NodeView), + } + + // Build nodesByUser and nodesByNodeKey maps + for _, n := range nodes { + nodeView := n.View() + newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView) + newSnap.nodesByNodeKey[n.NodeKey] = nodeView + } + + return newSnap +} + +// GetNode retrieves a node by its ID. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get").Inc() + + n, exists := s.data.Load().nodesByID[id] + if !exists { + return types.NodeView{}, false + } + + return n.View(), true +} + +// GetNodeByNodeKey retrieves a node by its NodeKey. +func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView { + return s.data.Load().nodesByNodeKey[nodeKey] +} + +// ListNodes returns a slice of all nodes in the store. +func (s *NodeStore) ListNodes() views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list").Inc() + + return views.SliceOf(s.data.Load().allNodes) +} + +// ListPeers returns a slice of all peers for a given node ID. +func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_peers").Inc() + + return views.SliceOf(s.data.Load().peersByNode[id]) +} + +// ListNodesByUser returns a slice of all nodes for a given user ID. +func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_by_user").Inc() + + return views.SliceOf(s.data.Load().nodesByUser[uid]) +} diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go new file mode 100644 index 00000000..9666e5db --- /dev/null +++ b/hscontrol/state/node_store_test.go @@ -0,0 +1,501 @@ +package state + +import ( + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/key" +) + +func TestSnapshotFromNodes(t *testing.T) { + tests := []struct { + name string + setupFunc func() (map[types.NodeID]types.Node, PeersFunc) + validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) + }{ + { + name: "empty nodes", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := make(map[types.NodeID]types.Node) + peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + return make(map[types.NodeID][]types.NodeView) + } + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "single node", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + } + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers + assert.Len(t, snapshot.nodesByUser[1], 1) + assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID()) + }, + }, + { + name: "multiple nodes same user", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 1, "user1", "node2"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Each node sees the other as peer (but not itself) + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "multiple nodes different users", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 1, "user1", "node3"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // Each node should have 2 peers (all others, but not itself) + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2 + }, + }, + { + name: "odd-even peers filtering", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 3, "user3", "node3"), + 4: createTestNode(4, 4, "user4", "node4"), + } + peersFunc := oddEvenPeersFunc + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 4) + assert.Len(t, snapshot.allNodes, 4) + assert.Len(t, snapshot.peersByNode, 4) + assert.Len(t, snapshot.nodesByUser, 4) + + // Odd nodes should only see other odd nodes as peers + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // Even nodes should only see other even nodes as peers + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nodes, peersFunc := tt.setupFunc() + snapshot := snapshotFromNodes(nodes, peersFunc) + tt.validate(t, nodes, snapshot) + }) + } +} + +// Helper functions + +func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node { + now := time.Now() + machineKey := key.NewMachine() + nodeKey := key.NewNode() + discoKey := key.NewDisco() + + ipv4 := netip.MustParseAddr("100.64.0.1") + ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1") + + return types.Node{ + ID: nodeID, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + DiscoKey: discoKey.Public(), + Hostname: hostname, + GivenName: hostname, + UserID: userID, + User: types.User{ + Name: username, + DisplayName: username, + }, + RegisterMethod: "test", + IPv4: &ipv4, + IPv6: &ipv6, + CreatedAt: now, + UpdatedAt: now, + } +} + +// Peer functions + +func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + for _, n := range nodes { + if n.ID() != node.ID() { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret +} + +func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 + + for _, n := range nodes { + if n.ID() == node.ID() { + continue + } + + peerIsOdd := n.ID()%2 == 1 + + // Only add peer if both are odd or both are even + if nodeIsOdd == peerIsOdd { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret +} + +func TestNodeStoreOperations(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) *NodeStore + steps []testStep + }{ + { + name: "create empty store and add single node", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify empty store", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "add first node", + action: func(store *NodeStore) { + node := createTestNode(1, 1, "user1", "node1") + store.PutNode(node) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, node.ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no peers yet + assert.Len(t, snapshot.nodesByUser[1], 1) + }, + }, + }, + }, + { + name: "create store with initial node and add more", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + initialNodes := types.Nodes{&node1} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + assert.Empty(t, snapshot.peersByNode[1]) + }, + }, + { + name: "add second node same user", + action: func(store *NodeStore) { + node2 := createTestNode(2, 1, "user1", "node2") + store.PutNode(node2) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Now both nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "add third node different user", + action: func(store *NodeStore) { + node3 := createTestNode(3, 2, "user2", "node3") + store.PutNode(node3) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // All nodes should see the other 2 as peers + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3 + }, + }, + }, + }, + { + name: "test node deletion", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + node3 := createTestNode(3, 2, "user2", "node3") + initialNodes := types.Nodes{&node1, &node2, &node3} + + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial 3 nodes", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + }, + }, + { + name: "delete middle node", + action: func(store *NodeStore) { + store.DeleteNode(2) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 2) + + // Node 2 should be gone + assert.NotContains(t, snapshot.nodesByID, types.NodeID(2)) + + // Remaining nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // User groupings updated + assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3 + }, + }, + { + name: "delete all remaining nodes", + action: func(store *NodeStore) { + store.DeleteNode(1) + store.DeleteNode(3) + + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + }, + }, + { + name: "test node updates", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial hostnames", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "update node hostname", + action: func(store *NodeStore) { + store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "updated-node1" + n.GivenName = "updated-node1" + }) + + snapshot := store.data.Load() + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged + + // Peers should still work correctly + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Len(t, snapshot.peersByNode[2], 1) + }, + }, + }, + }, + { + name: "test with odd-even peers filtering", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, oddEvenPeersFunc) + }, + steps: []testStep{ + { + name: "add nodes with odd-even filtering", + action: func(store *NodeStore) { + // Add nodes in sequence + store.PutNode(createTestNode(1, 1, "user1", "node1")) + store.PutNode(createTestNode(2, 2, "user2", "node2")) + store.PutNode(createTestNode(3, 3, "user3", "node3")) + store.PutNode(createTestNode(4, 4, "user4", "node4")) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 4) + + // Verify odd-even peer relationships + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + { + name: "delete odd node and verify even nodes unaffected", + action: func(store *NodeStore) { + store.DeleteNode(1) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + + // Node 3 (odd) should now have no peers + assert.Empty(t, snapshot.peersByNode[3]) + + // Even nodes should still see each other + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := tt.setupFunc(t) + store.Start() + defer store.Stop() + + for _, step := range tt.steps { + t.Run(step.name, func(t *testing.T) { + step.action(store) + }) + } + }) + } +} + +type testStep struct { + name string + action func(store *NodeStore) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 0a743184..958a2f52 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1,5 +1,6 @@ // Package state provides core state management for Headscale, coordinating // between subsystems like database, IP allocation, policy management, and DERP routing. + package state import ( @@ -9,6 +10,8 @@ import ( "io" "net/netip" "os" + "slices" + "sync" "sync/atomic" "time" @@ -21,12 +24,13 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" - xslices "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" + "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -49,6 +53,9 @@ type State struct { // cfg holds the current Headscale configuration cfg *types.Config + // nodeStore provides an in-memory cache for nodes. + nodeStore *NodeStore + // subsystem keeping state // db provides persistent storage and database operations db *hsdb.HSDatabase @@ -90,6 +97,12 @@ func NewState(cfg *types.Config) (*State, error) { if err != nil { return nil, fmt.Errorf("loading nodes: %w", err) } + + // On startup, all nodes should be marked as offline until they reconnect + // This ensures we don't have stale online status from previous runs + for _, node := range nodes { + node.IsOnline = ptr.To(false) + } users, err := db.ListUsers() if err != nil { return nil, fmt.Errorf("loading users: %w", err) @@ -105,7 +118,13 @@ func NewState(cfg *types.Config) (*State, error) { return nil, fmt.Errorf("init policy manager: %w", err) } - s := &State{ + nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + _, matchers := polMan.Filter() + return policy.BuildPeerMap(views.SliceOf(nodes), matchers) + }) + nodeStore.Start() + + return &State{ cfg: cfg, db: db, @@ -113,13 +132,14 @@ func NewState(cfg *types.Config) (*State, error) { polMan: polMan, registrationCache: registrationCache, primaryRoutes: routes.New(), - } - - return s, nil + nodeStore: nodeStore, + }, nil } // Close gracefully shuts down the State instance and releases all resources. func (s *State) Close() error { + s.nodeStore.Stop() + if err := s.db.Close(); err != nil { return fmt.Errorf("closing database: %w", err) } @@ -180,69 +200,78 @@ func (s *State) DERPMap() tailcfg.DERPMapView { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. -func (s *State) ReloadPolicy() (bool, error) { +func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { pol, err := policyBytes(s.db, s.cfg) if err != nil { - return false, fmt.Errorf("loading policy: %w", err) + return nil, fmt.Errorf("loading policy: %w", err) } - changed, err := s.polMan.SetPolicy(pol) + policyChanged, err := s.polMan.SetPolicy(pol) if err != nil { - return false, fmt.Errorf("setting policy: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } - if changed { - err := s.autoApproveNodes() - if err != nil { - return false, fmt.Errorf("auto approving nodes: %w", err) - } + cs := []change.ChangeSet{change.PolicyChange()} + + // Always call autoApproveNodes during policy reload, regardless of whether + // the policy content has changed. This ensures that routes are re-evaluated + // when they might have been manually disabled but could now be auto-approved + // with the current policy. + rcs, err := s.autoApproveNodes() + if err != nil { + return nil, fmt.Errorf("auto approving nodes: %w", err) } - return changed, nil -} + // TODO(kradalby): These changes can probably be safely ignored. + // If the PolicyChange is happening, that will lead to a full update + // meaning that we do not need to send individual route changes. + cs = append(cs, rcs...) -// AutoApproveNodes processes pending nodes and auto-approves those meeting policy criteria. -func (s *State) AutoApproveNodes() error { - return s.autoApproveNodes() + if len(rcs) > 0 || policyChanged { + log.Info(). + Bool("policy.changed", policyChanged). + Int("route.changes", len(rcs)). + Int("total.changes", len(cs)). + Msg("Policy reload completed with changes") + } + + return cs, nil } // CreateUser creates a new user and updates the policy manager. -// Returns the created user, whether policies changed, and any error. -func (s *State) CreateUser(user types.User) (*types.User, bool, error) { +// Returns the created user, change set, and any error. +func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() if err := s.db.DB.Save(&user).Error; err != nil { - return nil, false, fmt.Errorf("creating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerUsers() + c, err := s.updatePolicyManagerUsers() if err != nil { // Log the error but don't fail the user creation - return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err) + return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err) } // Even if the policy manager doesn't detect a filter change, SSH policies // might now be resolvable when they weren't before. If there are existing // nodes, we should send a policy change to ensure they get updated SSH policies. - if !policyChanged { - nodes, err := s.ListNodes() - if err == nil && len(nodes) > 0 { - policyChanged = true - } + // TODO(kradalby): detect this, or rebuild all SSH policies so we can determine + // this upstream. + if c.Empty() { + c = change.PolicyChange() } - log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated") + log.Info().Str("user.name", user.Name).Msg("User created") - // TODO(kradalby): implement the user in-memory cache - - return &user, policyChanged, nil + return &user, c, nil } // UpdateUser modifies an existing user using the provided update function within a transaction. -// Returns the updated user, whether policies changed, and any error. -func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) { +// Returns the updated user, change set, and any error. +func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() @@ -263,18 +292,18 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return user, nil }) if err != nil { - return nil, false, err + return nil, change.EmptySet, err } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerUsers() + c, err := s.updatePolicyManagerUsers() if err != nil { - return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err) + return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err) } - // TODO(kradalby): implement the user in-memory cache + // TODO(kradalby): We might want to update nodestore with the user data - return user, policyChanged, nil + return user, c, nil } // DeleteUser permanently removes a user and all associated data (nodes, API keys, etc). @@ -284,7 +313,7 @@ func (s *State) DeleteUser(userID types.UserID) error { } // RenameUser changes a user's name. The new name must be unique. -func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) { +func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) { return s.UpdateUser(userID, func(user *types.User) error { user.Name = newName return nil @@ -316,33 +345,16 @@ func (s *State) ListAllUsers() ([]types.User, error) { return s.db.ListUsers() } -// CreateNode creates a new node and updates the policy manager. -// Returns the created node, whether policies changed, and any error. -func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if err := s.db.DB.Save(node).Error; err != nil { - return nil, false, fmt.Errorf("creating node: %w", err) - } - - // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() - if err != nil { - return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - return node, policyChanged, nil -} - // updateNodeTx performs a database transaction to update a node and refresh the policy manager. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) { +// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore +// BEFORE calling this function with the EXACT same changes that the database update will make. +// This ensures the NodeStore is the source of truth for the batcher and maintains consistency. +// Returns error only; callers should get the updated NodeView from NodeStore to maintain consistency. +func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) error { s.mu.Lock() defer s.mu.Unlock() - node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + _, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { if err := updateFn(tx); err != nil { return nil, err } @@ -358,166 +370,283 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err return node, nil }) - if err != nil { - return nil, change.EmptySet, err + return err +} + +// persistNodeToDB saves the current state of a node from NodeStore to the database. +// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency. +func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // CRITICAL: Always get the latest node from NodeStore to ensure we save the current state + node, found := s.nodeStore.GetNode(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + nodePtr := node.AsStruct() + + if err := s.db.DB.Save(nodePtr).Error; err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err) } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() + c, err := s.updatePolicyManagerNodes() if err != nil { - return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) } - // TODO(kradalby): implement the node in-memory cache - - var c change.ChangeSet - if policyChanged { - c = change.PolicyChange() - } else { - // Basic node change without specific details since this is a generic update - c = change.NodeAdded(node.ID) + if c.Empty() { + c = change.NodeAdded(node.ID()) } return node, c, nil } -// SaveNode persists an existing node to the database and updates the policy manager. -func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) { + // Update NodeStore first + nodePtr := node.AsStruct() - if err := s.db.DB.Save(node).Error; err != nil { - return nil, change.EmptySet, fmt.Errorf("saving node: %w", err) - } + s.nodeStore.PutNode(*nodePtr) - // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() - if err != nil { - return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - if policyChanged { - return node, change.PolicyChange(), nil - } - - return node, change.EmptySet, nil + // Then save to database + return s.persistNodeToDB(node.ID()) } // DeleteNode permanently removes a node and cleans up associated resources. // Returns whether policies changed and any error. This operation is irreversible. -func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) { - err := s.db.DeleteNode(node) +func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { + s.nodeStore.DeleteNode(node.ID()) + + err := s.db.DeleteNode(node.AsStruct()) if err != nil { return change.EmptySet, err } - c := change.NodeRemoved(node.ID) + c := change.NodeRemoved(node.ID()) // Check if policy manager needs updating after node deletion - policyChanged, err := s.updatePolicyManagerNodes() + policyChange, err := s.updatePolicyManagerNodes() if err != nil { return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err) } - if policyChanged { - c = change.PolicyChange() + if !policyChange.Empty() { + c = policyChange } return c, nil } -func (s *State) Connect(node *types.Node) change.ChangeSet { - c := change.NodeOnline(node.ID) - routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) +// Connect marks a node as connected and updates its primary routes in the state. +func (s *State) Connect(id types.NodeID) []change.ChangeSet { + // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification + // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, + // the NodeStore already reflects the correct online status for full map generation. + // now := time.Now() + s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.IsOnline = ptr.To(true) + // n.LastSeen = ptr.To(now) + }) + c := []change.ChangeSet{change.NodeOnline(id)} + + // Get fresh node data from NodeStore after the online status update + node, found := s.GetNodeByID(id) + if !found { + return nil + } + + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") + + // Use the node's current routes for primary route update + // SubnetRoutes() returns only the intersection of announced AND approved routes + // We MUST use SubnetRoutes() to maintain the security model + routeChange := s.primaryRoutes.SetRoutes(id, node.SubnetRoutes()...) if routeChange { - c = change.NodeAdded(node.ID) + c = append(c, change.NodeAdded(id)) } return c } -func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) { - c := change.NodeOffline(node.ID) +// Disconnect marks a node as disconnected and updates its primary routes in the state. +func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { + now := time.Now() - _, _, err := s.SetLastSeen(node.ID, time.Now()) + // Get node info before updating for logging + node, found := s.GetNodeByID(id) + var nodeName string + if found { + nodeName = node.Hostname() + } + + s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.LastSeen = ptr.To(now) + // NodeStore is the source of truth for all node state including online status. + n.IsOnline = ptr.To(false) + }) + + if found { + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Node disconnected") + } + + err := s.updateNodeTx(id, func(tx *gorm.DB) error { + // Update last_seen in the database + // Note: IsOnline is managed only in NodeStore (marked with gorm:"-"), not persisted to database + return hsdb.SetLastSeen(tx, id, now) + }) if err != nil { - return c, fmt.Errorf("disconnecting node: %w", err) + // Log error but don't fail the disconnection - NodeStore is already updated + // and we need to send change notifications to peers + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update last seen in database") } - if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange { - c = change.PolicyChange() + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + // Log error but continue - disconnection must proceed + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update policy manager after node disconnect") + c = change.EmptySet } - // TODO(kradalby): This node should update the in memory state - return c, nil + // The node is disconnecting so make sure that none of the routes it + // announced are served to any nodes. + routeChange := s.primaryRoutes.SetRoutes(id) + + cs := []change.ChangeSet{change.NodeOffline(id), c} + + // If we have a policy change or route change, return that as it's more comprehensive + // Otherwise, return the NodeOffline change to ensure nodes are notified + if c.IsFull() || routeChange { + cs = append(cs, change.PolicyChange()) + } + + return cs, nil } // GetNodeByID retrieves a node by ID. -func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) { - return s.db.GetNodeByID(nodeID) -} - -// GetNodeViewByID retrieves a node view by ID. -func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) { - node, err := s.db.GetNodeByID(nodeID) - if err != nil { - return types.NodeView{}, err - } - - return node.View(), nil +// GetNodeByID retrieves a node by its ID. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, bool) { + return s.nodeStore.GetNode(nodeID) } // GetNodeByNodeKey retrieves a node by its Tailscale public key. -func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { - return s.db.GetNodeByNodeKey(nodeKey) +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) { + return s.nodeStore.GetNodeByNodeKey(nodeKey) } -// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key. -func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) { - node, err := s.db.GetNodeByNodeKey(nodeKey) - if err != nil { - return types.NodeView{}, err - } - - return node.View(), nil +// GetNodeByMachineKey retrieves a node by its machine key. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) { + return s.nodeStore.GetNodeByMachineKey(machineKey) } // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. -func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { +func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] { if len(nodeIDs) == 0 { - return s.db.ListNodes() + return s.nodeStore.ListNodes() } - return s.db.ListNodes(nodeIDs...) + // Filter nodes by the requested IDs + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs)) + for _, id := range nodeIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) } // ListNodesByUser retrieves all nodes belonging to a specific user. -func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) { - return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) { - return hsdb.ListNodesByUser(rx, userID) - }) +func (s *State) ListNodesByUser(userID types.UserID) views.Slice[types.NodeView] { + return s.nodeStore.ListNodesByUser(userID) } // ListPeers retrieves nodes that can communicate with the specified node based on policy. -func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - return s.db.ListPeers(nodeID, peerIDs...) +func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) views.Slice[types.NodeView] { + if len(peerIDs) == 0 { + return s.nodeStore.ListPeers(nodeID) + } + + // For specific peerIDs, filter from all nodes + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs)) + for _, id := range peerIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) } // ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system. -func (s *State) ListEphemeralNodes() (types.Nodes, error) { - return s.db.ListEphemeralNodes() +func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { + allNodes := s.nodeStore.ListNodes() + var ephemeralNodes []types.NodeView + + for _, node := range allNodes.All() { + // Check if node is ephemeral by checking its AuthKey + if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + ephemeralNodes = append(ephemeralNodes, node) + } + } + + return views.SliceOf(ephemeralNodes) } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + // If the database update fails, the NodeStore change will remain, but since we return + // an error, no change notification will be sent to the batcher. + expiryPtr := expiry + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.Expiry = &expiryPtr + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.NodeSetExpiry(tx, nodeID, expiry) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -528,12 +657,32 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Nod } // SetNodeTags assigns tags to a node for use in access control policies. -func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ForcedTags = tags + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetTags(tx, nodeID, tags) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -544,16 +693,42 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, ch } // SetApprovedRoutes sets the network routes that a node is approved to advertise. -func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) { + // TODO(kradalby): In principle we should call the AutoApprove logic here + // because even if the CLI removes an auto-approved route, it will be added + // back automatically. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ApprovedRoutes = routes + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetApprovedRoutes(tx, nodeID, routes) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) } - // Update primary routes after changing approved routes - routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...) + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } + + // Get the node from NodeStore to ensure we have the latest state + nodeView, ok := s.GetNodeByID(nodeID) + if !ok { + return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID) + } + // Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set + // primary routes for routes that are both announced AND approved + routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) if routeChange || !c.IsFull() { c = change.PolicyChange() @@ -563,12 +738,48 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (* } // RenameNode changes the display name of a node. -func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) { + // Validate the new name before making any changes + if err := util.CheckForFQDNRules(newName); err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + // Check name uniqueness + nodes, err := s.db.ListNodes() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err) + } + for _, node := range nodes { + if node.ID != nodeID && node.GivenName == newName { + return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) + } + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.GivenName = newName + }) + + err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.RenameNode(tx, nodeID, newName) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -578,20 +789,45 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, ch return n, c, nil } -// SetLastSeen updates when a node was last seen, used for connectivity monitoring. -func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetLastSeen(tx, nodeID, lastSeen) - }) -} - // AssignNodeToUser transfers a node to a different user. -func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) { + // Validate that both node and user exist + _, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found: %d", nodeID) + } + + user, err := s.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("user not found: %w", err) + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { + n.User = *user + n.UserID = uint(userID) + }) + + err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.AssignNodeToUser(tx, nodeID, userID) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err) + return types.NodeView{}, change.EmptySet, err + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -603,13 +839,59 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*typ // BackfillNodeIPs assigns IP addresses to nodes that don't have them. func (s *State) BackfillNodeIPs() ([]string, error) { - return s.db.BackfillNodeIPs(s.ipAlloc) + changes, err := s.db.BackfillNodeIPs(s.ipAlloc) + if err != nil { + return nil, err + } + + // Refresh NodeStore after IP changes to ensure consistency + if len(changes) > 0 { + nodes, err := s.db.ListNodes() + if err != nil { + return changes, fmt.Errorf("failed to refresh NodeStore after IP backfill: %w", err) + } + + for _, node := range nodes { + // Preserve online status when refreshing from database + existingNode, exists := s.nodeStore.GetNode(node.ID) + if exists && existingNode.Valid() { + node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + } + // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. + // We should avoid PutNode here. + s.nodeStore.PutNode(*node) + } + } + + return changes, nil } // ExpireExpiredNodes finds and processes expired nodes since the last check. // Returns next check time, state update with expired nodes, and whether any were found. func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) { - return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck) + // Why capture start time: We need to ensure we don't miss nodes that expire + // while this function is running by using a consistent timestamp for the next check + started := time.Now() + + var updates []change.ChangeSet + + for _, node := range s.nodeStore.ListNodes().All() { + if !node.Valid() { + continue + } + + // Why check After(lastCheck): We only want to notify about nodes that + // expired since the last check to avoid duplicate notifications + if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { + updates = append(updates, change.KeyExpiry(node.ID())) + } + } + + if len(updates) > 0 { + return started, updates, true + } + + return started, nil, false } // SSHPolicy returns the SSH access policy for a node. @@ -633,13 +915,35 @@ func (s *State) SetPolicy(pol []byte) (bool, error) { } // AutoApproveRoutes checks if a node's routes should be auto-approved. -func (s *State) AutoApproveRoutes(node *types.Node) bool { - return policy.AutoApproveRoutes(s.polMan, node) -} +// AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them. +func (s *State) AutoApproveRoutes(nv types.NodeView) bool { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) + if changed { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.announced", util.PrefixesToString(nv.AnnouncedRoutes())). + Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())). + Strs("routes.approved.new", util.PrefixesToString(approved)). + Msg("Single node auto-approval detected route changes") -// PolicyDebugString returns a debug representation of the current policy. -func (s *State) PolicyDebugString() string { - return s.polMan.DebugString() + // Persist the auto-approved routes to database and NodeStore via SetApprovedRoutes + // This ensures consistency between database and NodeStore + _, _, err := s.SetApprovedRoutes(nv.ID(), approved) + if err != nil { + log.Error(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Err(err). + Msg("Failed to persist auto-approved routes") + + return false + } + + log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved") + } + + return changed } // GetPolicy retrieves the current policy from the database. @@ -744,36 +1048,238 @@ func (s *State) HandleNodeFromAuthPath( userID types.UserID, expiry *time.Time, registrationMethod string, -) (*types.Node, change.ChangeSet, error) { - ipv4, ipv6, err := s.ipAlloc.Next() - if err != nil { - return nil, change.EmptySet, err +) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Get the registration entry from cache + regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + if !ok { + return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache } - return s.db.HandleNodeFromAuthPath( - registrationID, - userID, - expiry, - util.RegisterMethodOIDC, - ipv4, ipv6, - ) + // Get the user + user, err := s.db.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) + } + + // Check if node already exists by node key + existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey) + if exists && existingNodeView.Valid() { + // Node exists - this is a refresh/re-registration + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Username()). + Str("registrationMethod", registrationMethod). + Str("node.name", existingNodeView.Hostname()). + Uint64("node.id", existingNodeView.ID().Uint64()). + Msg("Refreshing existing node registration") + + // Update NodeStore first with the new expiry + s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { + if expiry != nil { + node.Expiry = expiry + } + // Mark as offline since node is reconnecting + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + + // Save to database + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry) + if err != nil { + return nil, err + } + // Return the node to satisfy the Write signature + return hsdb.GetNodeByID(tx, existingNodeView.ID()) + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err) + } + + // Get updated node from NodeStore + updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) + + return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil + } + + // New node registration + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Username()). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", expiry)). + Msg("Registering new node from auth callback") + + // Check if node exists with same machine key + var existingMachineNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() { + existingMachineNode = nv.AsStruct() + } + + // Prepare the node for registration + nodeToRegister := regEntry.Node + nodeToRegister.UserID = uint(userID) + nodeToRegister.User = *user + nodeToRegister.RegisterMethod = registrationMethod + if expiry != nil { + nodeToRegister.Expiry = expiry + } + + // Handle IP allocation + var ipv4, ipv6 *netip.Addr + if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { + // Reuse existing IPs and properties + nodeToRegister.ID = existingMachineNode.ID + nodeToRegister.GivenName = existingMachineNode.GivenName + nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes + ipv4 = existingMachineNode.IPv4 + ipv6 = existingMachineNode.IPv6 + } else { + // Allocate new IPs + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + var savedNode *types.Node + if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.DiscoKey = nodeToRegister.DiscoKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + if expiry != nil { + node.Expiry = expiry + } + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) + } + + // Delete from registration cache + s.registrationCache.Delete(registrationID) + + // Signal to waiting clients + select { + case regEntry.Registered <- savedNode: + default: + } + close(regEntry.Registered) + + // Update policy manager + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err) + } + + if !nodesChange.Empty() { + return savedNode.View(), nodesChange, nil + } + + return savedNode.View(), change.NodeAdded(savedNode.ID), nil } // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) (*types.Node, change.ChangeSet, bool, error) { +) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } err = pak.Validate() if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } + // Check if this is a logout request for an ephemeral node + if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { + // Find the node to delete + var nodeToDelete types.NodeView + for _, nv := range s.nodeStore.ListNodes().All() { + if nv.Valid() && nv.MachineKey() == machineKey { + nodeToDelete = nv + break + } + } + if nodeToDelete.Valid() { + c, err := s.DeleteNode(nodeToDelete) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) + } + + return types.NodeView{}, c, nil + } + + return types.NodeView{}, change.EmptySet, nil + } + + log.Debug(). + Caller(). + Str("node.name", regReq.Hostinfo.Hostname). + Str("machine.key", machineKey.ShortString()). + Str("node.key", regReq.NodeKey.ShortString()). + Str("user.name", pak.User.Username()). + Msg("Registering node with pre-auth key") + + // Check if node already exists with same machine key + var existingNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { + existingNode = nv.AsStruct() + } + + // Prepare the node for registration nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, UserID: pak.User.ID, @@ -783,75 +1289,133 @@ func (s *State) HandleNodeFromPreAuthKey( Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), RegisterMethod: util.RegisterMethodAuthKey, - - // TODO(kradalby): This should not be set on the node, - // they should be looked up through the key, which is - // attached to the node. - ForcedTags: pak.Proto().GetAclTags(), - AuthKey: pak, - AuthKeyID: &pak.ID, + ForcedTags: pak.Proto().GetAclTags(), + AuthKey: pak, + AuthKeyID: &pak.ID, } if !regReq.Expiry.IsZero() { nodeToRegister.Expiry = ®Req.Expiry } - ipv4, ipv6, err := s.ipAlloc.Next() - if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err) + // Handle IP allocation and existing node properties + var ipv4, ipv6 *netip.Addr + if existingNode != nil && existingNode.UserID == pak.User.ID { + // Reuse existing node properties + nodeToRegister.ID = existingNode.ID + nodeToRegister.GivenName = existingNode.GivenName + nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes + ipv4 = existingNode.IPv4 + ipv6 = existingNode.IPv6 + } else { + // Allocate new IPs + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } } - node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - node, err := hsdb.RegisterNode(tx, - nodeToRegister, - ipv4, ipv6, - ) - if err != nil { - return nil, fmt.Errorf("registering node: %w", err) - } + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 - if !pak.Reusable { - err = hsdb.UsePreAuthKey(tx, pak) - if err != nil { - return nil, fmt.Errorf("using pre auth key: %w", err) + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + var savedNode *types.Node + if existingNode != nil && existingNode.UserID == pak.User.ID { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + node.ForcedTags = nodeToRegister.ForcedTags + node.AuthKey = nodeToRegister.AuthKey + node.AuthKeyID = nodeToRegister.AuthKeyID + if nodeToRegister.Expiry != nil { + node.Expiry = nodeToRegister.Expiry } - } + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) - return node, nil - }) - if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err) - } + log.Trace(). + Caller(). + Str("node.name", nodeToRegister.Hostname). + Uint64("node.id", existingNode.ID.Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", regReq.NodeKey.ShortString()). + Str("user.name", pak.User.Username()). + Msg("Node re-authorized") - // Check if this is a logout request for an ephemeral node - if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { - // This is a logout request for an ephemeral node, delete it immediately - c, err := s.DeleteNode(node) + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) } - return nil, c, false, nil + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) } - // Check if policy manager needs updating - // This is necessary because we just created a new node. - // We need to ensure that the policy manager is aware of this new node. - // Also update users to ensure all users are known when evaluating policies. - usersChanged, err := s.updatePolicyManagerUsers() + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err) + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) } - nodesChanged, err := s.updatePolicyManagerNodes() + nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err) + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) } - policyChanged := usersChanged || nodesChanged + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(savedNode.ID) + } - c := change.NodeAdded(node.ID) - - return node, c, policyChanged, nil + return savedNode.View(), c, nil } // AllocateNextIPs allocates the next available IPv4 and IPv6 addresses. @@ -865,22 +1429,26 @@ func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for users. // updatePolicyManagerUsers refreshes the policy manager with current user data. -func (s *State) updatePolicyManagerUsers() (bool, error) { +func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { users, err := s.ListAllUsers() if err != nil { - return false, fmt.Errorf("listing users for policy update: %w", err) + return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) } log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users") changed, err := s.polMan.SetUsers(users) if err != nil { - return false, fmt.Errorf("updating policy manager users: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err) } - log.Debug().Bool("changed", changed).Msg("Policy manager users updated") + log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished") - return changed, nil + if changed { + return change.PolicyChange(), nil + } + + return change.EmptySet, nil } // updatePolicyManagerNodes updates the policy manager with current nodes. @@ -889,18 +1457,19 @@ func (s *State) updatePolicyManagerUsers() (bool, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for nodes. // updatePolicyManagerNodes refreshes the policy manager with current node data. -func (s *State) updatePolicyManagerNodes() (bool, error) { - nodes, err := s.ListNodes() +func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { + nodes := s.ListNodes() + + changed, err := s.polMan.SetNodes(nodes) if err != nil { - return false, fmt.Errorf("listing nodes for policy update: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err) } - changed, err := s.polMan.SetNodes(nodes.ViewSlice()) - if err != nil { - return false, fmt.Errorf("updating policy manager nodes: %w", err) + if changed { + return change.PolicyChange(), nil } - return changed, nil + return change.EmptySet, nil } // PingDB checks if the database connection is healthy. @@ -914,147 +1483,235 @@ func (s *State) PingDB(ctx context.Context) error { // TODO(kradalby): This is kind of messy, maybe this is another +1 // for an event bus. See example comments here. // autoApproveNodes automatically approves nodes based on policy rules. -func (s *State) autoApproveNodes() error { - err := s.db.Write(func(tx *gorm.DB) error { - nodes, err := hsdb.ListNodes(tx) - if err != nil { - return err - } +func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { + nodes := s.ListNodes() - for _, node := range nodes { - // TODO(kradalby): This change should probably be sent to the rest of the system. - changed := policy.AutoApproveRoutes(s.polMan, node) + // Approve routes concurrently, this should make it likely + // that the writes end in the same batch in the nodestore write. + var errg errgroup.Group + var cs []change.ChangeSet + var mu sync.Mutex + for _, nv := range nodes.All() { + errg.Go(func() error { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) if changed { - err = tx.Save(node).Error + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())). + Strs("routes.approved.new", util.PrefixesToString(approved)). + Msg("Routes auto-approved by policy") + + _, c, err := s.SetApprovedRoutes(nv.ID(), approved) if err != nil { return err } - // TODO(kradalby): This should probably be done outside of the transaction, - // and the result of this should be propagated to the system. - s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) + mu.Lock() + cs = append(cs, c) + mu.Unlock() } - } - return nil - }) - if err != nil { - return fmt.Errorf("auto approving routes for nodes: %w", err) + return nil + }) } - return nil + err := errg.Wait() + if err != nil { + return nil, err + } + + return cs, nil } -// TODO(kradalby): This should just take the node ID? -func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) { - // TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, - // which means we could shortcut the whole change thing if there are no other important updates. - peerChange := node.PeerChangeFromMapRequest(req) +// UpdateNodeFromMapRequest processes a MapRequest and updates the node. +// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, +// which means we could shortcut the whole change thing if there are no other important updates. +// When a field is added to this function, remember to also add it to: +// - node.PeerChangeFromMapRequest +// - node.ApplyPeerChange +// - logTracePeerChange in poll.go. +func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) { + var routeChange bool + var hostinfoChanged bool + var needsRouteApproval bool + // We need to ensure we update the node as it is in the NodeStore at + // the time of the request. + s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + peerChange := currentNode.PeerChangeFromMapRequest(req) + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) - node.ApplyPeerChange(&peerChange) + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(peerChange) && !hostinfoChanged { + return + } - sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo) + // Calculate route approval before NodeStore update to avoid calling View() inside callback + var autoApprovedRoutes []netip.Prefix + hasNewRoutes := req.Hostinfo != nil && len(req.Hostinfo.RoutableIPs) > 0 + needsRouteApproval = hostinfoChanged && (routesChanged(currentNode.View(), req.Hostinfo) || (hasNewRoutes && len(currentNode.ApprovedRoutes) == 0)) + if needsRouteApproval { + autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( + s.polMan, + currentNode.View(), + // We need to preserve currently approved routes to ensure + // routes outside of the policy approver is persisted. + currentNode.ApprovedRoutes, + // However, the node has updated its routable IPs, so we + // need to approve them using that as a context. + req.Hostinfo.RoutableIPs, + ) + } - // The node might not set NetInfo if it has not changed and if - // the full HostInfo object is overwritten, the information is lost. - // If there is no NetInfo, keep the previous one. - // From 1.66 the client only sends it if changed: - // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 - // TODO(kradalby): evaluate if we need better comparing of hostinfo - // before we take the changes. - if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil { - req.Hostinfo.NetInfo = node.Hostinfo.NetInfo - } - node.Hostinfo = req.Hostinfo + // Log when routes change but approval doesn't + if hostinfoChanged && req.Hostinfo != nil && routesChanged(currentNode.View(), req.Hostinfo) && !routeChange { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). + Strs("newAnnouncedRoutes", util.PrefixesToString(req.Hostinfo.RoutableIPs)). + Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Bool("routeChange", routeChange). + Msg("announced routes changed but approved routes did not") + } - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(peerChange) && !sendUpdate { - // mapResponseEndpointUpdates.WithLabelValues("noop").Inc() - return change.EmptySet, nil + currentNode.ApplyPeerChange(&peerChange) + + if hostinfoChanged { + // The node might not set NetInfo if it has not changed and if + // the full HostInfo object is overwritten, the information is lost. + // If there is no NetInfo, keep the previous one. + // From 1.66 the client only sends it if changed: + // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 + // TODO(kradalby): evaluate if we need better comparing of hostinfo + // before we take the changes. + // Preserve NetInfo only if the existing node actually has valid NetInfo + // This prevents copying nil NetInfo which would lose DERP relay assignments + if req.Hostinfo != nil && req.Hostinfo.NetInfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). + Msg("preserving NetInfo from previous Hostinfo in MapRequest") + req.Hostinfo.NetInfo = currentNode.Hostinfo.NetInfo + } else if req.Hostinfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { + // When MapRequest has no Hostinfo but we have existing NetInfo, create a minimal + // Hostinfo to preserve the NetInfo to maintain DERP connectivity + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). + Msg("creating minimal Hostinfo to preserve NetInfo in MapRequest") + req.Hostinfo = &tailcfg.Hostinfo{ + NetInfo: currentNode.Hostinfo.NetInfo, + } + } + currentNode.Hostinfo = req.Hostinfo + currentNode.ApplyHostnameFromHostInfo(req.Hostinfo) + + if routeChange { + // Apply pre-calculated route approval + // Always apply the route approval result to ensure consistency, + // regardless of whether the policy evaluation detected changes. + // This fixes the bug where routes weren't properly cleared when + // auto-approvers were removed from the policy. + log.Info(). + Uint64("node.id", id.Uint64()). + Strs("oldApprovedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Strs("newApprovedRoutes", util.PrefixesToString(autoApprovedRoutes)). + Bool("routeChanged", routeChange). + Msg("applying route approval results") + currentNode.ApprovedRoutes = autoApprovedRoutes + } + } + }) + + nodeRouteChange := change.EmptySet + + // Handle route changes after NodeStore update + // We need to update node routes if either: + // 1. The approved routes changed (routeChange is true), OR + // 2. The announced routes changed (even if approved routes stayed the same) + // This is because SubnetRoutes is the intersection of announced AND approved routes. + needsRouteUpdate := false + routesChangedButNotApproved := hostinfoChanged && req.Hostinfo != nil && needsRouteApproval && !routeChange + if routeChange { + needsRouteUpdate = true + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because approved routes changed") + } else if routesChangedButNotApproved { + needsRouteUpdate = true + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because announced routes changed but approved routes did not") } - c := change.EmptySet + if needsRouteUpdate { + // Get the updated node to access its subnet routes + updatedNode, exists := s.GetNodeByID(id) + if !exists { + return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id) + } - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change to - // the routable IPs of the host and update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if routesChanged { - // Auto approve any routes that have been defined in policy as - // auto approved. Check if this actually changed the node. - _ = s.AutoApproveRoutes(node) - - // Update the routes of the given node in the route manager to - // see if an update needs to be sent. - c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...) + // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() + // which returns only the intersection of announced AND approved routes. + // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())). + Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())). + Strs("subnetRoutes", util.PrefixesToString(updatedNode.SubnetRoutes())). + Msg("updating node routes for distribution") + nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) } - // Check if there has been a change to Hostname and update them - // in the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the hostname change. - node.ApplyHostnameFromHostInfo(req.Hostinfo) - - _, policyChange, err := s.SaveNode(node) + _, policyChange, err := s.persistNodeToDB(id) if err != nil { - return change.EmptySet, err + return change.EmptySet, fmt.Errorf("saving to database: %w", err) } if policyChange.IsFull() { - c = policyChange + return policyChange, nil + } + if !nodeRouteChange.Empty() { + return nodeRouteChange, nil } - if c.Empty() { - c = change.NodeAdded(node.ID) - } - - return c, nil + return change.NodeAdded(id), nil } -// hostInfoChanged reports if hostInfo has changed in two ways, -// - first bool reports if an update needs to be sent to nodes -// - second reports if there has been changes to routes -// the caller can then use this info to save and update nodes -// and routes as needed. -func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { - if old.Equal(new) { - return false, false +func hostinfoEqual(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + if !oldNode.Valid() && new == nil { + return true + } + if !oldNode.Valid() || new == nil { + return false + } + old := oldNode.AsStruct().Hostinfo + + return old.Equal(new) +} + +func routesChanged(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + var oldRoutes []netip.Prefix + if oldNode.Valid() && oldNode.AsStruct().Hostinfo != nil { + oldRoutes = oldNode.AsStruct().Hostinfo.RoutableIPs } - if old == nil && new != nil { - return true, true - } - - // Routes - oldRoutes := make([]netip.Prefix, 0) - if old != nil { - oldRoutes = old.RoutableIPs - } newRoutes := new.RoutableIPs + if newRoutes == nil { + newRoutes = []netip.Prefix{} + } tsaddr.SortPrefixes(oldRoutes) tsaddr.SortPrefixes(newRoutes) - if !xslices.Equal(oldRoutes, newRoutes) { - return true, true - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to each other. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. - if len(old.Services) != len(new.Services) { - return true, false - } - - return false, false + return !slices.Equal(oldRoutes, newRoutes) } func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 3301cb35..e38a98f6 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool { case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate: return true } + return false } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 81a2a86a..959572a2 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -13,6 +13,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/tsaddr" @@ -355,6 +356,7 @@ func (node *Node) Proto() *v1.Node { GivenName: node.GivenName, User: node.User.Proto(), ForcedTags: node.ForcedTags, + Online: node.IsOnline != nil && *node.IsOnline, // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has // to be populated manually with PrimaryRoute, to ensure it includes the @@ -419,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix { } // SubnetRoutes returns the list of routes that the node announces and are approved. +// +// IMPORTANT: This method is used for internal data structures and should NOT be used +// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually +// with PrimaryRoutes to ensure it includes only routes actively served by the node. +// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto. func (node *Node) SubnetRoutes() []netip.Prefix { var routes []netip.Prefix @@ -511,11 +518,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } if node.Hostname != hostInfo.Hostname { + log.Trace(). + Str("node.id", node.ID.String()). + Str("old_hostname", node.Hostname). + Str("new_hostname", hostInfo.Hostname). + Str("old_given_name", node.GivenName). + Bool("given_name_changed", node.GivenNameHasBeenChanged()). + Msg("Updating hostname from hostinfo") + if node.GivenNameHasBeenChanged() { node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) } node.Hostname = hostInfo.Hostname + + log.Trace(). + Str("node.id", node.ID.String()). + Str("new_hostname", node.Hostname). + Str("new_given_name", node.GivenName). + Msg("Hostname updated") } } @@ -759,6 +780,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix { return v.ж.ExitRoutes() } +// RequestTags returns the ACL tags that the node is requesting. +func (v NodeView) RequestTags() []string { + if !v.Valid() || !v.Hostinfo().Valid() { + return []string{} + } + return v.Hostinfo().RequestTags().AsSlice() +} + +// Proto converts the NodeView to a protobuf representation. +func (v NodeView) Proto() *v1.Node { + if !v.Valid() { + return nil + } + return v.ж.Proto() +} + // HasIP reports if a node has a given IP address. func (v NodeView) HasIP(i netip.Addr) bool { if !v.Valid() { diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index d7bc7897..97bb3da4 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Parse each hop line - hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") + hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) for i := 1; i < len(lines); i++ { matches := hopRegex.FindStringSubmatch(lines[i]) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index d118b643..394d219b 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.NoError(ct, err) assert.Equal(ct, "NeedsLogin", status.BackendState) } + assertTailscaleNodesLogout(t, allClients) }, shortAccessTTL+10*time.Second, 5*time.Second) } diff --git a/integration/general_test.go b/integration/general_test.go index 4bf36567..9da61958 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -547,6 +547,8 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second var nodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { err := executeAndUnmarshal( @@ -642,27 +644,34 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "node", - "list", - "--output", - "json", - }, - &nodes, - ) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second + assert.Eventually(t, func() bool { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) - assertNoErr(t, err) - assert.Len(t, nodes, 3) + if err != nil || len(nodes) != 3 { + return false + } - for _, node := range nodes { - hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] - givenName := fmt.Sprintf("%d-givenname", node.GetId()) - assert.Equal(t, hostname+"NEW", node.GetName()) - assert.Equal(t, givenName, node.GetGivenName()) - } + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName { + return false + } + } + return true + }, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix") } func TestExpireNode(t *testing.T) { diff --git a/integration/route_test.go b/integration/route_test.go index 7243d3f2..bb13a47f 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, node.GetSubnetRoutes(), 1) } - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - // Verify that the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - assert.NotNil(t, peerStatus.PrimaryRoutes) - - assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3) - requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + assert.NotNil(c, peerStatus.PrimaryRoutes) + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + } } - } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") _, err = headscale.ApproveRoutes( 1, @@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - - for _, node := range nodes { - if node.GetId() == 1 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetSubnetRoutes()) - } else if node.GetId() == 2 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetApprovedRoutes()) - assert.Empty(t, node.GetSubnetRoutes()) - } else { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + for _, node := range nodes { + if node.GetId() == 1 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetSubnetRoutes()) + } else if node.GetId() == 2 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetApprovedRoutes()) + assert.Empty(c, node.GetSubnetRoutes()) + } else { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + } } - } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the clients can see the new routes for _, client := range allClients { @@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - time.Sleep(3 * time.Second) + // Wait for route configuration changes after advertising routes + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err := headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 0, 0) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved") // Verify that no routes has been sent to the client, // they are not yet enabled. @@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on first subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route") // Verify that the client has routes from the primary machine and can access // the webservice. @@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on second subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route") // Verify that the client has routes from the primary machine srs1 = subRouter1.MustStatus() @@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on third subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0) + }, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route") // Verify that the client has routes from the primary machine srs1 = subRouter1.MustStatus() @@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) { require.NoError(t, err) assert.Len(t, result, 13) - tr, err = client.Traceroute(webip) - require.NoError(t, err) - assertTracerouteViaIP(t, tr, subRouter1.MustIPv4()) + // Wait for traceroute to work correctly through the expected router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + + // Get the expected router IP - use a more robust approach to handle temporary disconnections + ips, err := subRouter1.IPs() + assert.NoError(c, err) + assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") + + var expectedIP netip.Addr + for _, ip := range ips { + if ip.Is4() { + expectedIP = ip + break + } + } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") + + assertTracerouteViaIPWithCollect(c, tr, expectedIP) + }, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1") // Take down the current primary t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname()) @@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter1.Down() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r1 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs2 = subRouter2.MustStatus() + clientStatus = client.MustStatus() - srs2 = subRouter2.MustStatus() - clientStatus = client.MustStatus() + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down") - assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up") - assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up") + assert.False(c, srs1PeerStatus.Online, "r1 should be offline") + assert.True(c, srs2PeerStatus.Online, "r2 should be online") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) require.NotNil(t, srs2PeerStatus.PrimaryRoutes) @@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter2.Down() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r2 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // TODO(kradalby): Check client status - // Both are expected to be down + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - // Verify that the route is not presented from either router - clientStatus, err = client.Status() - require.NoError(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down") - assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down") - assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down") + assert.False(c, srs1PeerStatus.Online, "r1 should be offline") + assert.False(c, srs2PeerStatus.Online, "r2 should be offline") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter1.Up() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for router status changes after r1 comes back up + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - require.NoError(t, err) + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down") - assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down") - assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available") + assert.True(c, srs1PeerStatus.Online, "r1 should be back online") + assert.False(c, srs2PeerStatus.Online, "r2 should still be offline") + assert.True(c, srs3PeerStatus.Online, "r3 should still be online") + }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) { err = subRouter2.Up() require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing to complete and online status to be updated + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - require.NoError(t, err) + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] - - assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up") - assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up") - assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up") + assert.True(c, srs1PeerStatus.Online, "r1 should be online") + assert.True(c, srs2PeerStatus.Online, "r2 should be online") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + }, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2") assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes) @@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname()) _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{}) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + // After disabling route on r3, r1 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname()) _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{}) - time.Sleep(5 * time.Second) + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + // After disabling route on r1, r2 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) { util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()), ) - time.Sleep(5 * time.Second) + // Wait for route state changes after re-enabling r1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) - requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - require.Len(t, nodes, 2) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 0, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the client has routes from the primary machine srs1, _ := subRouter1.Status() @@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) { requireNodeRouteCount(t, nodes[0], 2, 2, 2) requireNodeRouteCount(t, nodes[1], 2, 2, 2) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - // Verify that the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - require.NotNil(t, peerStatus.AllowedIPs) - assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4) - assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) - assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6()) + assert.NotNil(c, peerStatus.AllowedIPs) + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6()) + } } - } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") } // TestSubnetRouterMultiNetwork is an evolution of the subnet router test. @@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes and clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 2) - requireNodeRouteCount(t, nodes[0], 1, 1, 1) + // Verify that the routes have been sent to the client + status, err = user2c.Status() + assert.NoError(c, err) - // Verify that the routes have been sent to the client. - status, err = user2c.Status() - require.NoError(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref) - requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref}) - } + assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref) + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref}) + } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients") usernet1, err := scenario.Network("usernet1") require.NoError(t, err) @@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()}) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes and clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 2) - requireNodeRouteCount(t, nodes[0], 2, 2, 2) + // Verify that the routes have been sent to the client + status, err = user2c.Status() + assert.NoError(c, err) - // Verify that the routes have been sent to the client. - status, err = user2c.Status() - require.NoError(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) - } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) + } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients") // Tell user2c to use user1c as an exit node. command = []string{ @@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) + var nodes []*v1.Node opts := []hsic.Option{ hsic.WithTestName("autoapprovemulti"), hsic.WithEmbeddedDERPServerOnly(), @@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { require.NoErrorf(t, err, "failed to advertise route: %s", err) } - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err := headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Verify that the routes have been sent to the client. status, err := client.Status() @@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { err = headscale.SetPolicy(tt.pol) require.NoError(t, err) - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Verify that the routes have been sent to the client. status, err = client.Status() @@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Verify that the routes have been sent to the client. status, err = client.Status() @@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { err = headscale.SetPolicy(tt.pol) require.NoError(t, err) - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Verify that the routes have been sent to the client. status, err = client.Status() @@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { _, _, err = routerSubRoute.Execute(command) require.NoErrorf(t, err, "failed to advertise route: %s", err) - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") requireNodeRouteCount(t, nodes[1], 1, 1, 1) // Verify that the routes have been sent to the client. @@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { _, _, err = routerSubRoute.Execute(command) require.NoErrorf(t, err, "failed to advertise route: %s", err) - time.Sleep(5 * time.Second) - - // These route should auto approve, so the node is expected to have a route - // for all counts. - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") requireNodeRouteCount(t, nodes[1], 1, 1, 0) requireNodeRouteCount(t, nodes[2], 0, 0, 0) @@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) { _, _, err = routerExitNode.Execute(command) require.NoErrorf(t, err, "failed to advertise route: %s", err) - time.Sleep(5 * time.Second) - - nodes, err = headscale.ListNodes() - require.NoError(t, err) - requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 1, 0) - requireNodeRouteCount(t, nodes[2], 2, 2, 2) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Verify that the routes have been sent to the client. status, err = client.Status() @@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) { require.Equal(t, tr.Route[0].IP, ip) } +// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT +func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { + assert.NotNil(c, tr) + assert.True(c, tr.Success) + assert.NoError(c, tr.Err) + assert.NotEmpty(c, tr.Route) + assert.Equal(c, tr.Route[0].IP, ip) +} + // requirePeerSubnetRoutes asserts that the peer has the expected subnet routes. func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) { t.Helper() @@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected } } +func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) { + if status.AllowedIPs.Len() <= 2 && len(expected) != 0 { + assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected)) + return + } + + if len(expected) == 0 { + expected = []netip.Prefix{} + } + + got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool { + if tsaddr.IsExitRoute(p) { + return true + } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) + }) + + if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { + assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff)) + } +} + func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) { t.Helper() require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) @@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) } +func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) { + assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) + assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes())) + assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) +} + // TestSubnetRouteACLFiltering tests that a node can only access subnet routes // that are explicitly allowed in the ACL. func TestSubnetRouteACLFiltering(t *testing.T) { @@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) { ) require.NoError(t, err) - // Give some time for the routes to propagate - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // List nodes and verify the router has 3 available routes + nodes, err = headscale.NodesByUser() + assert.NoError(c, err) + assert.Len(c, nodes, 2) - // List nodes and verify the router has 3 available routes - nodes, err = headscale.NodesByUser() - require.NoError(t, err) - require.Len(t, nodes, 2) + // Find the router node + routerNode = nodes[routerUser][0] - // Find the router node - routerNode = nodes[routerUser][0] - - // Check that the router has 3 routes now approved and available - requireNodeRouteCount(t, routerNode, 3, 3, 3) + // Check that the router has 3 routes now approved and available + requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate") // Now check the client node status nodeStatus, err := nodeClient.Status() diff --git a/integration/scenario.go b/integration/scenario.go index 817d927b..8ce54b89 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -14,7 +14,6 @@ import ( "net/netip" "net/url" "os" - "sort" "strconv" "strings" "sync" @@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { return nil, fmt.Errorf("no network named: %s", name) } - for _, ipam := range net.Network.IPAM.Config { - pref, err := netip.ParsePrefix(ipam.Subnet) - if err != nil { - return nil, err - } - - return &pref, nil + if len(net.Network.IPAM.Config) == 0 { + return nil, fmt.Errorf("no IPAM config found in network: %s", name) } - return nil, fmt.Errorf("no prefix found in network: %s", name) + pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) + if err != nil { + return nil, err + } + + return &pref, nil } func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { @@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv( return err } - sort.Strings(s.spec.Users) for _, user := range s.spec.Users { u, err := s.CreateUser(user) if err != nil {