From c24b9882478658b8c928844bec9c1623721904d6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 6 Aug 2025 08:46:12 +0200 Subject: [PATCH] rest Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 35 +- hscontrol/auth.go | 47 +- hscontrol/db/db.go | 2 +- hscontrol/db/node.go | 78 +- hscontrol/db/node_test.go | 109 ++- hscontrol/db/users.go | 21 +- hscontrol/grpcv1.go | 82 +- hscontrol/handlers.go | 5 +- hscontrol/mapper/batcher.go | 7 +- hscontrol/mapper/batcher_lockfree.go | 73 +- hscontrol/mapper/batcher_test.go | 116 ++- hscontrol/mapper/builder.go | 99 ++- hscontrol/mapper/builder_test.go | 114 +-- hscontrol/mapper/mapper.go | 56 +- hscontrol/mapper/tail.go | 14 +- hscontrol/noise.go | 10 +- hscontrol/policy/policy.go | 75 +- hscontrol/policy/policy_autoapprove_test.go | 339 +++++++ .../policy/policy_route_approval_test.go | 361 ++++++++ hscontrol/policy/route_approval_test.go | 23 + hscontrol/policy/v2/policy.go | 8 +- hscontrol/poll.go | 90 +- hscontrol/state/node_store.go | 214 ++++- hscontrol/state/node_store_test.go | 7 + hscontrol/state/state.go | 834 ++++++++++++------ hscontrol/types/change/change.go | 1 + hscontrol/types/node.go | 8 +- 27 files changed, 2073 insertions(+), 755 deletions(-) create mode 100644 hscontrol/policy/policy_autoapprove_test.go create mode 100644 hscontrol/policy/policy_route_approval_test.go diff --git a/hscontrol/app.go b/hscontrol/app.go index 27746b8e..daf64e39 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -136,9 +136,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { // Initialize ephemeral garbage collector ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { - node, err := app.state.GetNodeByID(ni) - if err != nil { - log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion") + node, ok := app.state.GetNodeByID(ni) + if !ok { + log.Warn().Uint64("node.id", ni.Uint64()).Msgf("ephemeral node not found for deletion") return } @@ -371,7 +371,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler Str("client_address", req.RemoteAddr). Msg("HTTP authentication invoked") - authHeader := req.Header.Get("authorization") + authHeader := req.Header.Get("Authorization") if !strings.HasPrefix(authHeader, AuthPrefix) { log.Error(). @@ -487,11 +487,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { + var err error capver.CanOldCodeBeCleanedUp() if profilingEnabled { if profilingPath != "" { - err := os.MkdirAll(profilingPath, os.ModePerm) + err = os.MkdirAll(profilingPath, os.ModePerm) if err != nil { log.Fatal().Err(err).Msg("failed to create profiling directory") } @@ -543,10 +544,7 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() - ephmNodes, err := h.state.ListEphemeralNodes() - if err != nil { - return fmt.Errorf("failed to list ephemeral nodes: %w", err) - } + ephmNodes := h.state.ListEphemeralNodes() for _, node := range ephmNodes.All() { h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) } @@ -778,23 +776,14 @@ func (h *Headscale) Serve() error { continue } - changed, err := h.state.ReloadPolicy() + changes, err := h.state.ReloadPolicy() if err != nil { log.Error().Err(err).Msgf("reloading policy") continue } - if changed { - log.Info(). 
- Msg("ACL policy successfully reloaded, notifying nodes of change") + h.Change(changes...) - err = h.state.AutoApproveNodes() - if err != nil { - log.Error().Err(err).Msg("failed to approve routes after new policy") - } - - h.Change(change.PolicySet) - } default: info := func(msg string) { log.Info().Msg(msg) } log.Info(). @@ -1004,6 +993,8 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { // Change is used to send changes to nodes. // All change should be enqueued here and empty will be automatically // ignored. -func (h *Headscale) Change(c change.ChangeSet) { - h.mapBatcher.AddWork(c) +func (h *Headscale) Change(cs ...change.ChangeSet) { + for _, c := range cs { + h.mapBatcher.AddWork(c) + } } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index ca46aec3..2a7f1976 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -11,7 +11,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -28,27 +27,9 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, err := h.state.GetNodeByNodeKey(regReq.NodeKey) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("looking up node in database: %w", err) - } - - if node.Valid() { - // If an existing node is trying to register with an auth key, - // we need to validate the auth key even for existing nodes - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) - if err != nil { - // Preserve HTTPError types so they can be handled properly by the HTTP layer - var httpErr HTTPError - if errors.As(err, &httpErr) { - return nil, httpErr - } - return nil, fmt.Errorf("handling register with auth key for existing node: %w", err) - } - return resp, nil - } + node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey) + if ok { resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey) if err != nil { return nil, fmt.Errorf("handling existing node: %w", err) @@ -69,6 +50,7 @@ func (h *Headscale) handleRegister( if errors.As(err, &httpErr) { return nil, httpErr } + return nil, fmt.Errorf("handling register with auth key: %w", err) } @@ -88,13 +70,22 @@ func (h *Headscale) handleExistingNode( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - if node.MachineKey != machineKey { return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() + // If the node is expired and this is not a re-authentication attempt, + // force the client to re-authenticate + if expired && regReq.Auth == nil { + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + AuthURL: "", // Client will need to re-authenticate + }, nil + } + if !expired && !regReq.Expiry.IsZero() { requestExpiry := regReq.Expiry @@ -117,12 +108,16 @@ func (h *Headscale) handleExistingNode( } } - _, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) + updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } h.Change(c) + + // CRITICAL: Use the updated node view for the response + // The original node object has stale expiry information + node = updatedNode.AsStruct() } return nodeToRegisterResponse(node), nil @@ -192,8 +187,8 @@ 
func (h *Headscale) handleRegisterWithAuthKey( return nil, err } - // If node is nil, it means an ephemeral node was deleted during logout - if node.Valid() { + // If node is not valid, it means an ephemeral node was deleted during logout + if !node.Valid() { h.Change(changed) return nil, nil } @@ -212,6 +207,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // TODO(kradalby): This needs to be ran as part of the batcher maybe? // now since we dont update the node/pol here anymore routeChange := h.state.AutoApproveRoutes(node) + if _, _, err := h.state.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } @@ -229,6 +225,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // } user := node.User() + return &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index d2f39ff0..53786bb6 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -936,7 +936,7 @@ AND auth_key_id NOT IN ( // - NEVER use gorm.AutoMigrate, write the exact migration steps needed // - AutoMigrate depends on the struct staying exactly the same, which it won't over time. // - Never write migrations that requires foreign keys to be disabled. - }, + }, ) if err := runMigrations(cfg, dbConn, migrations); err != nil { diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 83d62d3d..9a278276 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { } // RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. If the name is not unique, it will return an error. +// and renames it. Validation should be done in the state layer before calling this function. func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - err := util.CheckForFQDNRules( - newName, - ) - if err != nil { - return fmt.Errorf("renaming node: %w", err) + // Check if the new name is unique + var count int64 + if err := tx.Model(&types.Node{}).Where("given_name = ? 
AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to check name uniqueness: %w", err) } - - uniq, err := isUniqueName(tx, newName) - if err != nil { - return fmt.Errorf("checking if name is unique: %w", err) - } - - if !uniq { - return fmt.Errorf("name is not unique: %s", newName) + + if count > 0 { + return fmt.Errorf("name is not unique") } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -409,9 +403,16 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, err } + // CRITICAL: Reload the node to get the updated expiry + // Without this, we return stale node data to NodeStore + updatedNode, err := GetNodeByID(tx, node.ID) + if err != nil { + return nil, fmt.Errorf("failed to reload node after expiry update: %w", err) + } + nodeChange = change.KeyExpiry(node.ID) - return node, nil + return updatedNode, nil } } @@ -445,8 +446,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.ID = oldNode.ID node.GivenName = oldNode.GivenName node.ApprovedRoutes = oldNode.ApprovedRoutes - ipv4 = oldNode.IPv4 - ipv6 = oldNode.IPv6 + // Don't overwrite the provided IPs with old ones when they exist + if ipv4 == nil { + ipv4 = oldNode.IPv4 + } + if ipv6 == nil { + ipv6 = oldNode.IPv6 + } } // If the node exists and it already has IP(s), we just save it @@ -781,19 +787,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . node := hsdb.CreateNodeForTest(user, hostname...) - err := hsdb.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, *node, nil, nil) + // Allocate IPs for the test node using the database's IP allocator + // This is a simplified allocation for testing - in production this would use State.ipAlloc + ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID) + if err != nil { + panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err)) + } + + var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { + var err error + registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6) return err }) if err != nil { panic(fmt.Sprintf("failed to register test node: %v", err)) } - registeredNode, err := hsdb.GetNodeByID(node.ID) - if err != nil { - panic(fmt.Sprintf("failed to get registered test node: %v", err)) - } - return registeredNode } @@ -842,3 +852,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int return nodes } + +// allocateTestIPs allocates sequential test IPs for nodes during testing. 
+func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) { + if !testing.Testing() { + panic("allocateTestIPs can only be called during tests") + } + + // Use simple sequential allocation for tests + // IPv4: 100.64.0.x (where x is nodeID) + // IPv6: fd7a:115c:a1e0::x (where x is nodeID) + + if nodeID > 254 { + return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID) + } + + ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)}) + ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)}) + + return &ipv4, &ipv6, nil +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 8819fbcf..59d9941a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) { func TestAutoApproveRoutes(t *testing.T) { tests := []struct { - name string - acl string - routes []netip.Prefix - want []netip.Prefix - want2 []netip.Prefix + name string + acl string + routes []netip.Prefix + want []netip.Prefix + want2 []netip.Prefix + expectChange bool // whether to expect route changes }{ + { + name: "no-auto-approvers-empty-policy", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ] +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - no auto-approvers + want2: []netip.Prefix{}, // Should be empty - no auto-approvers + expectChange: false, // No changes expected + }, + { + name: "no-auto-approvers-explicit-empty", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ], + "autoApprovers": { + "routes": {}, + "exitNode": [] + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + expectChange: false, // No changes expected + }, { name: "2068-approve-issue-sub-kube", acl: ` @@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) { } } }`, - routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, - want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + expectChange: true, // Routes should be approved }, { name: "2068-approve-issue-sub-exit-tag", @@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) { tsaddr.AllIPv4(), tsaddr.AllIPv6(), }, + expectChange: true, // Routes should be approved }, } @@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) require.NotNil(t, pm) - changed1 := policy.AutoApproveRoutes(pm, &node) - assert.True(t, changed1) + newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes) + assert.Equal(t, tt.expectChange, changed1) - err = adb.DB.Save(&node).Error - require.NoError(t, err) + if changed1 { + err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1) + require.NoError(t, err) + } - _ = policy.AutoApproveRoutes(pm, &nodeTagged) - - err = adb.DB.Save(&nodeTagged).Error - require.NoError(t, err) + newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, 
tt.routes) + if changed2 { + err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2) + require.NoError(t, err) + } node1ByID, err := adb.GetNodeByID(1) require.NoError(t, err) - if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { + // For empty auto-approvers tests, handle nil vs empty slice comparison + expectedRoutes1 := tt.want + if len(expectedRoutes1) == 0 { + expectedRoutes1 = nil + } + if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } node2ByID, err := adb.GetNodeByID(2) require.NoError(t, err) - if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { + expectedRoutes2 := tt.want2 + if len(expectedRoutes2) == 0 { + expectedRoutes2 = nil + } + if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } }) @@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) { // No parameter means no filter, should return all peers nodes, err = db.ListPeers(1) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Empty node list should return all peers nodes, err = db.ListPeers(1, types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // No match in IDs should return empty list and no error @@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) { // Partial match in IDs nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs, but node ID is still filtered out nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) } @@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) { // No parameter means no filter, should return all nodes nodes, err = db.ListNodes() require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) // Empty node list should return all nodes nodes, err = db.ListNodes(types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) @@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) { // Partial match in IDs nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 1b333792..3684eb4a 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { } // AssignNodeToUser assigns a Node to a user. +// Note: Validation should be done in the state layer before calling this function. 
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error { - node, err := GetNodeByID(tx, nodeID) - if err != nil { - return err + // Check if the user exists + var userExists bool + if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) } - user, err := GetUserByID(tx, uid) - if err != nil { - return err + + if !userExists { + return ErrUserNotFound } - node.User = *user - node.UserID = user.ID - if result := tx.Save(&node); result.Error != nil { - return result.Error + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil { + return fmt.Errorf("failed to assign node to user: %w", err) } return nil diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 4ef52106..20ff25c1 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -288,9 +288,9 @@ func (api headscaleV1APIServer) GetNode( ctx context.Context, request *v1.GetNodeRequest, ) (*v1.GetNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } resp := node.Proto() @@ -334,7 +334,12 @@ func (api headscaleV1APIServer) SetApprovedRoutes( ctx context.Context, request *v1.SetApprovedRoutesRequest, ) (*v1.SetApprovedRoutesResponse, error) { - var routes []netip.Prefix + log.Debug(). + Uint64("node.id", request.GetNodeId()). + Strs("requestedRoutes", request.GetRoutes()). + Msg("gRPC SetApprovedRoutes called") + + var newApproved []netip.Prefix for _, route := range request.GetRoutes() { prefix, err := netip.ParsePrefix(route) if err != nil { @@ -344,31 +349,34 @@ func (api headscaleV1APIServer) SetApprovedRoutes( // If the prefix is an exit route, add both. The client expect both // to annotate the node as an exit node. if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() { - routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6()) + newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6()) } else { - routes = append(routes, prefix) + newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(routes) - routes = slices.Compact(routes) + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) - node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) + node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) if err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } - routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...) - // Always propagate node changes from SetApprovedRoutes api.h.Change(nodeChange) - // If routes changed, propagate those changes too - if !routeChange.Empty() { - api.h.Change(routeChange) - } - proto := node.Proto() - proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID())) + // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the + // routes that are actively served from the node (per architectural requirement in types/node.go) + primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID()) + proto.SubnetRoutes = util.PrefixesToString(primaryRoutes) + + log.Debug(). + Uint64("node.id", node.ID().Uint64()). 
+ Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)). + Strs("finalSubnetRoutes", proto.SubnetRoutes). + Msg("gRPC SetApprovedRoutes completed") return &v1.SetApprovedRoutesResponse{Node: proto}, nil } @@ -390,9 +398,9 @@ func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, ) (*v1.DeleteNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } nodeChange, err := api.h.state.DeleteNode(node) @@ -463,19 +471,13 @@ func (api headscaleV1APIServer) ListNodes( return nil, err } - nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID)) - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodesByUser(types.UserID(user.ID)) response := nodesToProto(api.h.state, IsConnected, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodes() response := nodesToProto(api.h.state, IsConnected, nodes) return &v1.ListNodesResponse{Nodes: response}, nil @@ -499,6 +501,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID, } } resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...)) + resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) response[index] = resp } @@ -674,11 +677,8 @@ func (api headscaleV1APIServer) SetPolicy( // a scenario where they might be allowed if the server has no nodes // yet, but it should help for the general case and for hot reloading // configurations. - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) - } - changed, err := api.h.state.SetPolicy([]byte(p)) + nodes := api.h.state.ListNodes() + _, err := api.h.state.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } @@ -695,16 +695,16 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - // Only send update if the packet filter has changed. - if changed { - err = api.h.state.AutoApproveNodes() - if err != nil { - return nil, err - } - - api.h.Change(change.PolicyChange()) + // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed. + // This ensures that routes are re-evaluated for auto-approval in cases where routes + // were manually disabled but could now be auto-approved with the current policy. + cs, err := api.h.state.ReloadPolicy() + if err != nil { + return nil, fmt.Errorf("reloading policy: %w", err) } + api.h.Change(cs...) 
+ response := &v1.SetPolicyResponse{ Policy: updated.Data, UpdatedAt: timestamppb.New(updated.UpdatedAt), diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index b1e2be8d..aece38b1 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -94,10 +94,7 @@ func (h *Headscale) handleVerifyRequest( return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err) } - nodes, err := h.state.ListNodes() - if err != nil { - return fmt.Errorf("cannot list nodes: %w", err) - } + nodes := h.state.ListNodes() // Check if any node has the requested NodeKey var nodeKeyFound bool diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 21b2209f..8174c0f2 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "fmt" "time" @@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher type Batcher interface { Start() Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error - RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] AddWork(c change.ChangeSet) @@ -119,7 +120,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { if nc == nil { - return fmt.Errorf("nodeConnection is nil") + return errors.New("nodeConnection is nil") } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index aeafa001..e4ff3237 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -3,7 +3,6 @@ package mapper import ( "context" "fmt" - "sync" "sync/atomic" "time" @@ -21,7 +20,6 @@ type LockFreeBatcher struct { mapper *mapper workers int - // Lock-free concurrent maps nodes *xsync.Map[types.NodeID, *nodeConn] connected *xsync.Map[types.NodeID, *time.Time] @@ -32,7 +30,6 @@ type LockFreeBatcher struct { // Batching state pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] - batchMutex sync.RWMutex // Metrics totalNodes atomic.Int64 @@ -46,16 +43,13 @@ type LockFreeBatcher struct { // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. // TODO(kradalby): See if we can move the isRouter argument somewhere else. -func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error { - // First validate that we can generate initial map before doing anything else - fullSelfChange := change.FullSelf(id) - +func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { // TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange. // This currently means that the goroutine for the node connection will do the processing // which means that we might have uncontrolled concurrency. 
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing // it to be processed in a more controlled manner. - initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange) + initialMap, err := generateMapResponse(id, version, b.mapper, change.FullSelf(id)) if err != nil { return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) } @@ -73,10 +67,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse conn = newConn } - // Mark as connected only after validation succeeds b.connected.Store(id, nil) // nil = connected - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher") + log.Info().Uint64("node.id", id.Uint64()).Msg("Node connected to batcher") // Send the validated initial map if initialMap != nil { @@ -86,9 +79,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse b.connected.Delete(id) return fmt.Errorf("failed to send initial map to node %d: %w", id, err) } - - // Notify other nodes that this node came online - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter}) } return nil @@ -97,12 +87,14 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. // It validates the connection channel matches the current one, closes the connection, // and notifies other nodes that this node has gone offline. -func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) { +// It reports if the node was actually closed. Returns false if the channel does not match the current connection, +// indicating that we are actually not disconnecting the node, but rather ignoring the request. +func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { // Check if this is the current connection and mark it as closed if existing, ok := b.nodes.Load(id); ok { if !existing.matchesChannel(c) { - log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring") - return // Not the current connection, not an error + log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called on a different channel, ignoring") + return false // Not the current connection, not an error } // Mark the connection as closed to prevent further sends @@ -111,15 +103,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo } } - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline") + log.Info().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher, marking as offline") // Remove node and mark disconnected atomically b.nodes.Delete(id) b.connected.Store(id, ptr.To(time.Now())) b.totalNodes.Add(-1) - // Notify other nodes that this node went offline - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter}) + return true } // AddWork queues a change to be processed by the batcher. @@ -214,6 +205,7 @@ func (b *LockFreeBatcher) worker(workerID int) { Dur("duration", duration). Msg("slow synchronous work processing") } + continue } @@ -228,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) { Uint64("node.id", w.nodeID.Uint64()). Str("change", w.c.Change.String()). 
Msg("skipping work for closed connection") + continue } @@ -240,12 +233,6 @@ func (b *LockFreeBatcher) worker(workerID int) { Str("change", w.c.Change.String()). Msg("failed to apply change") } - } else { - log.Debug(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Msg("node not found for asynchronous work - node may have disconnected") } duration := time.Since(startTime) @@ -276,8 +263,10 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) { return true } b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) + return true }) + return } @@ -285,7 +274,7 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) { b.addToBatch(c) } -// queueWork safely queues work +// queueWork safely queues work. func (b *LockFreeBatcher) queueWork(w work) { b.workQueuedCount.Add(1) @@ -298,7 +287,7 @@ func (b *LockFreeBatcher) queueWork(w work) { } } -// shouldProcessImmediately determines if a change should bypass batching +// shouldProcessImmediately determines if a change should bypass batching. func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { // Process these changes immediately to avoid delaying critical functionality switch c.Change { @@ -309,11 +298,8 @@ func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { } } -// addToBatch adds a change to the pending batch +// addToBatch adds a change to the pending batch. func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if c.SelfUpdateOnly { changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{}) changes = append(changes, c) @@ -329,15 +315,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) changes = append(changes, c) b.pendingChanges.Store(nodeID, changes) + return true }) } -// processBatchedChanges processes all pending batched changes +// processBatchedChanges processes all pending batched changes. func (b *LockFreeBatcher) processBatchedChanges() { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if b.pendingChanges == nil { return } @@ -355,17 +339,27 @@ func (b *LockFreeBatcher) processBatchedChanges() { // Clear the pending changes for this node b.pendingChanges.Delete(nodeID) + return true }) } // IsConnected is lock-free read. func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { - if val, ok := b.connected.Load(id); ok { - // nil means connected - return val == nil + val, ok := b.connected.Load(id) + if !ok { + return false } - return false + + // nil means connected + if val == nil { + return true + } + + // During grace period, always return true to allow DNS resolution + // for logout HTTP requests to complete successfully + gracePeriod := 45 * time.Second + return time.Since(*val) < gracePeriod } // ConnectedMap returns a lock-free map of all connected nodes. @@ -487,5 +481,6 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error { // the channel is still open. 
connData.c <- data nc.updateCount.Add(1) + return nil } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 9419a008..7903fe22 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -26,6 +26,43 @@ type batcherTestCase struct { fn batcherFunc } +// testBatcherWrapper wraps a real batcher to add online/offline notifications +// that would normally be sent by poll.go in production +type testBatcherWrapper struct { + Batcher +} + +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + // First add the node to the real batcher + err := t.Batcher.AddNode(id, c, version) + if err != nil { + return err + } + + // Then send the online notification that poll.go would normally send + t.Batcher.AddWork(change.NodeOnline(id)) + + return nil +} + +func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + // First remove from the real batcher + removed := t.Batcher.RemoveNode(id, c) + if !removed { + return false + } + + // Then send the offline notification that poll.go would normally send + t.Batcher.AddWork(change.NodeOffline(id)) + + return true +} + +// wrapBatcherForTest wraps a batcher with test-specific behavior +func wrapBatcherForTest(b Batcher) Batcher { + return &testBatcherWrapper{Batcher: b} +} + // allBatcherFunctions contains all batcher implementations to test. var allBatcherFunctions = []batcherTestCase{ {"LockFree", NewBatcherAndMapper}, @@ -176,8 +213,8 @@ func setupBatcherWithTestData( "acls": [ { "action": "accept", - "users": ["*"], - "ports": ["*:*"] + "src": ["*"], + "dst": ["*:*"] } ] }` @@ -187,8 +224,8 @@ func setupBatcherWithTestData( t.Fatalf("Failed to set allow-all policy: %v", err) } - // Create batcher with the state - batcher := bf(cfg, state) + // Create batcher with the state and wrap it for testing + batcher := wrapBatcherForTest(bf(cfg, state)) batcher.Start() testData := &TestData{ @@ -455,7 +492,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) time.Sleep(100 * time.Millisecond) // Let connection settle // Generate some work @@ -558,7 +595,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) for i := range allNodes { node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullSet) @@ -606,7 +643,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Disconnect all nodes for i := range allNodes { node := &allNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for final updates to process @@ -724,7 +761,7 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, false, 100) + batcher.AddNode(tn.n.ID, tn.ch, 100) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") } @@ -744,14 +781,14 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) // Add the second 
node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, false, 100) + batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. select { case data := <-tn.ch: assertOnlineMapResponse(t, data, true) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -765,19 +802,19 @@ func TestBatcherBasicOperations(t *testing.T) { len(data.Peers) >= 1 || data.Node != nil, "Should receive initial full map", ) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Second node should receive its initial full map") } // Disconnect the second node - batcher.RemoveNode(tn2.n.ID, tn2.ch, false) + batcher.RemoveNode(tn2.n.ID, tn2.ch) assert.False(t, batcher.IsConnected(tn2.n.ID)) // First node should get update that second has disconnected. select { case data := <-tn.ch: assertOnlineMapResponse(t, data, false) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -803,7 +840,7 @@ func TestBatcherBasicOperations(t *testing.T) { // } // Test RemoveNode - batcher.RemoveNode(tn.n.ID, tn.ch, false) + batcher.RemoveNode(tn.n.ID, tn.ch) if batcher.IsConnected(tn.n.ID) { t.Error("Node should be disconnected after RemoveNode") } @@ -949,7 +986,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1045,7 +1082,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }() // Add real work during connection chaos @@ -1059,7 +1096,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(1 * time.Microsecond) - batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }() // Remove second connection @@ -1067,7 +1104,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(2 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch2, false) + batcher.RemoveNode(testNode.n.ID, ch2) }() wg.Wait() @@ -1142,7 +1179,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPSet) // Consumer goroutine to validate data and detect channel issues @@ -1184,7 +1221,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Rapid removal creates race between worker and removal time.Sleep(time.Duration(i%3) * 100 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch, false) + batcher.RemoveNode(testNode.n.ID, ch) // Give workers time to process and close channels time.Sleep(5 * time.Millisecond) @@ -1254,7 +1291,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for _, node := range 
stableNodes { ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1312,7 +1349,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Lock() churningChannels[nodeID] = ch churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1349,7 +1386,7 @@ func TestBatcherConcurrentClients(t *testing.T) { ch, exists := churningChannels[nodeID] churningChannelsMutex.Unlock() if exists { - batcher.RemoveNode(nodeID, ch, false) + batcher.RemoveNode(nodeID, ch) } }(node.n.ID) } @@ -1599,7 +1636,7 @@ func XTestBatcherScalability(t *testing.T) { var connectedNodesMutex sync.RWMutex for i := range testNodes { node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true connectedNodesMutex.Unlock() @@ -1666,7 +1703,7 @@ func XTestBatcherScalability(t *testing.T) { connectedNodesMutex.RUnlock() if isConnected { - batcher.RemoveNode(nodeID, channel, false) + batcher.RemoveNode(nodeID, channel) connectedNodesMutex.Lock() connectedNodes[nodeID] = false connectedNodesMutex.Unlock() @@ -1690,7 +1727,6 @@ func XTestBatcherScalability(t *testing.T) { batcher.AddNode( nodeID, channel, - false, tailcfg.CapabilityVersion(100), ) connectedNodesMutex.Lock() @@ -1792,7 +1828,7 @@ func XTestBatcherScalability(t *testing.T) { // Now disconnect all nodes from batcher to stop new updates for i := range testNodes { node := &testNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for enhanced tracking goroutines to process any remaining data in channels @@ -1924,7 +1960,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time to avoid overwhelming the work queue for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Small delay between connections to allow NodeCameOnline processing time.Sleep(50 * time.Millisecond) @@ -1936,12 +1972,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Check how many peers each node should see for i, node := range allNodes { - peers, err := testData.State.ListPeers(node.n.ID) - if err != nil { - t.Errorf("Error listing peers for node %d: %v", i, err) - } else { - t.Logf("Node %d should see %d peers from state", i, peers.Len()) - } + peers := testData.State.ListPeers(node.n.ID) + t.Logf("Node %d should see %d peers from state", i, peers.Len()) } // Send a full update - this should generate full peer lists @@ -1957,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { foundFullUpdate := false // Read all available updates for each node - for i := range len(allNodes) { + for i := range allNodes { nodeUpdates := 0 t.Logf("Reading updates for node %d:", i) @@ -2047,7 +2079,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) { t.Logf("=== WORK QUEUE TRACING TEST ===") // Connect first node - 
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(nodes[0].n.ID, nodes[0].ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d", nodes[0].n.ID) // Wait for initial NodeCameOnline to be processed @@ -2102,14 +2134,10 @@ func TestBatcherWorkQueueTracing(t *testing.T) { } // Check if there should be peers available - peers, err := testData.State.ListPeers(nodes[0].n.ID) - if err != nil { - t.Errorf("Error getting peers from state: %v", err) - } else { - t.Logf("State shows %d peers available for this node", peers.Len()) - if peers.Len() > 0 && len(data.Peers) == 0 { - t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len()) - } + peers := testData.State.ListPeers(nodes[0].n.ID) + t.Logf("State shows %d peers available for this node", peers.Len()) + if peers.Len() > 0 && len(data.Peers) == 0 { + t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len()) } } else { t.Errorf("Response data is nil") diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 28bca095..4072d33c 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "net/netip" "sort" "time" @@ -8,11 +9,12 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/multierr" ) -// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse +// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse. type MapResponseBuilder struct { resp *tailcfg.MapResponse mapper *mapper @@ -21,7 +23,7 @@ type MapResponseBuilder struct { errs []error } -// NewMapResponseBuilder creates a new builder with basic fields set +// NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() return &MapResponseBuilder{ @@ -35,37 +37,44 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder } } -// addError adds an error to the builder's error list +// addError adds an error to the builder's error list. func (b *MapResponseBuilder) addError(err error) { if err != nil { b.errs = append(b.errs, err) } } -// hasErrors returns true if the builder has accumulated any errors +// hasErrors returns true if the builder has accumulated any errors. func (b *MapResponseBuilder) hasErrors() bool { return len(b.errs) > 0 } -// WithCapabilityVersion sets the capability version for the response +// WithCapabilityVersion sets the capability version for the response. func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { b.capVer = capVer return b } -// WithSelfNode adds the requesting node to the response +// WithSelfNode adds the requesting node to the response. 
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } + // Always use batcher's view of online status for self node + // The batcher respects grace periods for logout scenarios + node := nodeView.AsStruct() + if b.mapper.batcher != nil { + node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID)) + } + _, matchers := b.mapper.state.Filter() tailnode, err := tailNode( - node, b.capVer, b.mapper.state, + node.View(), b.capVer, b.mapper.state, func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers) + return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) }, b.mapper.cfg) if err != nil { @@ -74,29 +83,30 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { } b.resp.Node = tailnode + return b } -// WithDERPMap adds the DERP map to the response +// WithDERPMap adds the DERP map to the response. func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { b.resp.DERPMap = b.mapper.state.DERPMap() return b } -// WithDomain adds the domain configuration +// WithDomain adds the domain configuration. func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { b.resp.Domain = b.mapper.cfg.Domain() return b } -// WithCollectServicesDisabled sets the collect services flag to false +// WithCollectServicesDisabled sets the collect services flag to false. func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { b.resp.CollectServices.Set(false) return b } // WithDebugConfig adds debug configuration -// It disables log tailing if the mapper's LogTail is not enabled +// It disables log tailing if the mapper's LogTail is not enabled. func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, @@ -104,11 +114,11 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { return b } -// WithSSHPolicy adds SSH policy configuration for the requesting node +// WithSSHPolicy adds SSH policy configuration for the requesting node. func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } @@ -119,38 +129,41 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { } b.resp.SSHPolicy = sshPolicy + return b } -// WithDNSConfig adds DNS configuration for the requesting node +// WithDNSConfig adds DNS configuration for the requesting node. func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) + return b } -// WithUserProfiles adds user profiles for the requesting node and given peers +// WithUserProfiles adds user profiles for the requesting node and given peers. 
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.UserProfiles = generateUserProfiles(node, peers) + return b } -// WithPacketFilters adds packet filter rules based on policy +// WithPacketFilters adds packet filter rules based on policy. func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } @@ -167,9 +180,8 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { return b } -// WithPeers adds full peer list with policy filtering (for full map response) +// WithPeers adds full peer list with policy filtering (for full map response). func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder { - tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -177,12 +189,12 @@ func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapRe } b.resp.Peers = tailPeers + return b } -// WithPeerChanges adds changed peers with policy filtering (for incremental updates) +// WithPeerChanges adds changed peers with policy filtering (for incremental updates). func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder { - tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -190,14 +202,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) } b.resp.PeersChanged = tailPeers + return b } -// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting +// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting. func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - return nil, err + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + return nil, errors.New("node not found") } filter, matchers := b.mapper.state.Filter() @@ -229,24 +242,24 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( return tailPeers, nil } -// WithPeerChangedPatch adds peer change patches +// WithPeerChangedPatch adds peer change patches. func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { b.resp.PeersChangedPatch = changes return b } -// WithPeersRemoved adds removed peer IDs +// WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { - var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } b.resp.PeersRemoved = tailscaleIDs + return b } -// Build finalizes the response and returns marshaled bytes +// Build finalizes the response and returns marshaled bytes. func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) 
diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go index c8ff59ec..92082cf7 100644 --- a/hscontrol/mapper/builder_test.go +++ b/hscontrol/mapper/builder_test.go @@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) { Enabled: true, }, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID) - + // Test basic builder creation assert.NotNil(t, builder) assert.Equal(t, nodeID, builder.nodeID) @@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) capVer := tailcfg.CapabilityVersion(42) - + builder := m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer) - + assert.Equal(t, capVer, builder.capVer) assert.False(t, builder.hasErrors()) } @@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) { ServerURL: "https://test.example.com", BaseDomain: domain, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithDomain() - + assert.Equal(t, domain, builder.resp.Domain) assert.False(t, builder.hasErrors()) } @@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithCollectServicesDisabled() - + value, isSet := builder.resp.CollectServices.Get() assert.True(t, isSet) assert.False(t, value) @@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { tests := []struct { - name string + name string logTailEnabled bool - expected bool + expected bool }{ { - name: "LogTail enabled", + name: "LogTail enabled", logTailEnabled: true, - expected: false, // DisableLogTail should be false when LogTail is enabled + expected: false, // DisableLogTail should be false when LogTail is enabled }, { - name: "LogTail disabled", + name: "LogTail disabled", logTailEnabled: false, - expected: true, // DisableLogTail should be true when LogTail is disabled + expected: true, // DisableLogTail should be true when LogTail is disabled }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &types.Config{ @@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithDebugConfig() - + require.NotNil(t, builder.resp.Debug) assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail) assert.False(t, builder.hasErrors()) @@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) changes := []*tailcfg.PeerChange{ { - NodeID: 123, + NodeID: 123, DERPRegion: 1, }, { - NodeID: 456, + NodeID: 456, DERPRegion: 2, }, } - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch(changes) - + assert.Equal(t, changes, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) removedID1 := types.NodeID(123) removedID2 := types.NodeID(456) - + builder := m.NewMapResponseBuilder(nodeID). 
WithPeersRemoved(removedID1, removedID2) - + expected := []tailcfg.NodeID{ removedID1.NodeID(), removedID2.NodeID(), @@ -197,23 +197,23 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + // Simulate an error in the builder builder := m.NewMapResponseBuilder(nodeID) builder.addError(assert.AnError) - + // All subsequent calls should continue to work and accumulate errors result := builder. WithDomain(). WithCollectServicesDisabled(). WithDebugConfig() - + assert.True(t, result.hasErrors()) assert.Len(t, result.errs, 1) assert.Equal(t, assert.AnError, result.errs[0]) - + // Build should return the error data, err := result.Build("none") assert.Nil(t, data) @@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) { Enabled: false, }, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) capVer := tailcfg.CapabilityVersion(99) - + builder := m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). WithDomain(). WithCollectServicesDisabled(). WithDebugConfig() - + // Verify all fields are set correctly assert.Equal(t, capVer, builder.capVer) assert.Equal(t, domain, builder.resp.Domain) @@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) removedID1 := types.NodeID(100) removedID2 := types.NodeID(200) - + // Test calling WithPeersRemoved multiple times builder := m.NewMapResponseBuilder(nodeID). WithPeersRemoved(removedID1). WithPeersRemoved(removedID2) - + // Second call should overwrite the first expected := []tailcfg.NodeID{removedID2.NodeID()} assert.Equal(t, expected, builder.resp.PeersRemoved) @@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch([]*tailcfg.PeerChange{}) - + assert.Empty(t, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch(nil) - + assert.Nil(t, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + // Create a builder and add multiple errors builder := m.NewMapResponseBuilder(nodeID) builder.addError(assert.AnError) builder.addError(assert.AnError) builder.addError(nil) // This should be ignored - + // All subsequent calls should continue to work result := builder. WithDomain(). 
WithCollectServicesDisabled() - + assert.True(t, result.hasErrors()) assert.Len(t, result.errs, 2) // nil error should be ignored - + // Build should return a multierr data, err := result.Build("none") assert.Nil(t, data) assert.Error(t, err) - + // The error should contain information about multiple errors assert.Contains(t, err.Error(), "multiple errors") -} \ No newline at end of file +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 7ffe2ede..78fc9b09 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -18,6 +18,7 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -49,6 +50,37 @@ type mapper struct { created time.Time } +// addOnlineStatusToPeers adds fresh online status from batcher to peer nodes. +// +// We do a last-minute copy-and-write on the NodeView to inject current online status +// from the batcher's connection map. Online status is not populated upstream in NodeStore +// for consistency reasons - it's runtime connection state that should come from the +// connection manager (batcher) to ensure map responses have the freshest data. +func (m *mapper) addOnlineStatusToPeers(peers views.Slice[types.NodeView]) views.Slice[types.NodeView] { + if peers.Len() == 0 || m.batcher == nil { + return peers + } + + result := make([]types.NodeView, 0, peers.Len()) + for _, peer := range peers.All() { + if !peer.Valid() { + result = append(result, peer) + continue + } + + // Get online status from batcher connection map + // The batcher respects grace periods for logout scenarios + isOnline := m.batcher.IsConnected(peer.ID()) + + // Create a mutable copy and set online status + peerCopy := peer.AsStruct() + peerCopy.IsOnline = ptr.To(isOnline) + result = append(result, peerCopy.View()) + } + + return views.SliceOf(result) +} + type patch struct { timestamp time.Time change *tailcfg.PeerChange @@ -140,10 +172,10 @@ func (m *mapper) fullMapResponse( capVer tailcfg.CapabilityVersion, messages ...string, ) (*tailcfg.MapResponse, error) { - peers, err := m.state.ListPeers(nodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID) + + // Add fresh online status to peers from batcher connection state + peersWithOnlineStatus := m.addOnlineStatusToPeers(peers) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). @@ -154,9 +186,9 @@ func (m *mapper) fullMapResponse( WithDebugConfig(). WithSSHPolicy(). WithDNSConfig(). - WithUserProfiles(peers). + WithUserProfiles(peersWithOnlineStatus). WithPacketFilters(). - WithPeers(peers). + WithPeers(peersWithOnlineStatus). Build(messages...) } @@ -185,16 +217,16 @@ func (m *mapper) peerChangeResponse( capVer tailcfg.CapabilityVersion, changedNodeID types.NodeID, ) (*tailcfg.MapResponse, error) { - peers, err := m.state.ListPeers(nodeID, changedNodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID, changedNodeID) + + // Add fresh online status to peers from batcher connection state + peersWithOnlineStatus := m.addOnlineStatusToPeers(peers) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). WithSelfNode(). - WithUserProfiles(peers). - WithPeerChanges(peers). + WithUserProfiles(peersWithOnlineStatus). + WithPeerChanges(peersWithOnlineStatus). 
Build() } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9729301d..24491e22 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -133,11 +133,15 @@ func tailNode( tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } - if !node.IsOnline().Valid() || !node.IsOnline().Get() { - // LastSeen is only set when node is - // not connected to the control server. - if node.LastSeen().Valid() { - lastSeen := node.LastSeen().Get() + // Always set LastSeen if it's valid, regardless of online status + // This ensures that during logout grace periods (when IsOnline might be true + // for DNS preservation), other nodes can still see when this node disconnected + if node.LastSeen().Valid() { + lastSeen := node.LastSeen().Get() + // Only set LastSeen if the node is offline OR if LastSeen is recent + // (indicating it disconnected recently but might be in grace period) + if !node.IsOnline().Valid() || !node.IsOnline().Get() || + time.Since(lastSeen) < 60*time.Second { tNode.LastSeen = &lastSeen } } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 9dd42468..bb59fea6 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -13,7 +13,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" - "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" @@ -296,12 +295,9 @@ func (ns *noiseServer) NoiseRegistrationHandler( // getAndValidateNode retrieves the node from the database using the NodeKey // and validates that it matches the MachineKey from the Noise session. func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { - nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) - } - return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil) + nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) + if !ok { + return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) } // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 52457c9b..c377ce4f 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "github.com/samber/lo" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -138,39 +139,61 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf return ret } -// AutoApproveRoutes approves any route that can be autoapproved from -// the nodes perspective according to the given policy. -// It reports true if any routes were approved. -// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes. -func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool { +// ApproveRoutesWithPolicy checks if the node can approve the announced routes +// and returns the new list of approved routes. +// The approved routes will include: +// 1. ALL previously approved routes (regardless of whether they're still advertised) +// 2. 
New routes from announcedRoutes that can be auto-approved by policy +// This ensures that: +// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes) +// - New routes can be auto-approved according to policy +// - Routes can only be removed by explicit admin action (not by auto-approval) +func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) { if pm == nil { - return false + return currentApproved, false } - nodeView := node.View() - var newApproved []netip.Prefix - for _, route := range nodeView.AnnouncedRoutes() { - if pm.NodeCanApproveRoute(nodeView, route) { + + // Start with ALL currently approved routes - we never remove approved routes + newApproved := make([]netip.Prefix, len(currentApproved)) + copy(newApproved, currentApproved) + + // Then, check for new routes that can be auto-approved + for _, route := range announcedRoutes { + // Skip if already approved + if slices.Contains(newApproved, route) { + continue + } + + // Check if this new route can be auto-approved by policy + canApprove := pm.NodeCanApproveRoute(nv, route) + if canApprove { newApproved = append(newApproved, route) } + + log.Trace(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Str("route", route.String()). + Bool("can_approve", canApprove). + Msg("Evaluating route for auto-approval") } - // Only modify ApprovedRoutes if we have new routes to approve. - // This prevents clearing existing approved routes when nodes - // temporarily don't have announced routes during policy changes. - if len(newApproved) > 0 { - combined := append(newApproved, node.ApprovedRoutes...) - tsaddr.SortPrefixes(combined) - combined = slices.Compact(combined) - combined = lo.Filter(combined, func(route netip.Prefix, index int) bool { - return route.IsValid() - }) + // Sort and deduplicate + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) + newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { + return route.IsValid() + }) - // Only update if the routes actually changed - if !slices.Equal(node.ApprovedRoutes, combined) { - node.ApprovedRoutes = combined - return true - } + // Sort the current approved for comparison + sortedCurrent := make([]netip.Prefix, len(currentApproved)) + copy(sortedCurrent, currentApproved) + tsaddr.SortPrefixes(sortedCurrent) + + // Only update if the routes actually changed + if !slices.Equal(sortedCurrent, newApproved) { + return newApproved, true } - return false + return newApproved, false } diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go new file mode 100644 index 00000000..67fa4c96 --- /dev/null +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -0,0 +1,339 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/types/key" + "tailscale.com/types/ptr" + "tailscale.com/types/views" +) + +func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { + user1 := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser@", + } + user2 := types.User{ + Model: gorm.Model{ID: 2}, + Name: "otheruser@", + } + users := []types.User{user1, user2} + + node1 := &types.Node{ + ID: 1, + 
MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test-node", + UserID: user1.ID, + User: user1, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ForcedTags: []string{"tag:test"}, + } + + node2 := &types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "other-node", + UserID: user2.ID, + User: user2, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } + + // Create a policy that auto-approves specific routes + policyJSON := `{ + "groups": { + "group:test": ["testuser@"] + }, + "tagOwners": { + "tag:test": ["testuser@"] + }, + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + } + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["testuser@", "tag:test"], + "10.1.0.0/24": ["testuser@"], + "10.2.0.0/24": ["testuser@"], + "192.168.0.0/24": ["tag:test"] + } + } + }` + + pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) + assert.NoError(t, err) + + tests := []struct { + name string + node *types.Node + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + description string + }{ + { + name: "previously_approved_route_no_longer_advertised_should_remain", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Should still be here! 
+ }, + wantChanged: false, + description: "Previously approved routes should never be removed even when no longer advertised", + }, + { + name: "add_new_auto_approved_route_keeps_old_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8) + netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept + }, + wantChanged: true, + description: "New auto-approved routes should be added while keeping old approved routes", + }, + { + name: "no_announced_routes_keeps_all_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + description: "All approved routes should remain when no routes are announced", + }, + { + name: "no_changes_when_announced_equals_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + description: "No changes should occur when announced routes match approved routes", + }, + { + name: "auto_approve_multiple_new_routes", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) + netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved + netip.MustParsePrefix("172.16.0.0/24"), // Original kept + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + }, + wantChanged: true, + description: "Multiple new routes should be auto-approved while keeping existing approved routes", + }, + { + name: "node_without_permission_no_auto_approval", + node: node2, // Different node without the tag + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route + }, + wantChanged: false, + description: "Routes should not be auto-approved for nodes without proper permissions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) + + // Sort for comparison since ApproveRoutesWithPolicy sorts the results + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) + + // Verify that all previously approved routes are still 
present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should never happen", prevRoute) + } + }) + } +} + +func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { + // Create a basic policy for edge case testing + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.1.0.0/24": ["test@"], + }, + }, +}` + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_policy_manager", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "nil_announced_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: nil, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "duplicate_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_slices", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{}, + wantApproved: []netip.Prefix{}, + wantChanged: false, + }, + } + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + // Create policy manager or use nil if specified + var pm PolicyManager + var err error + if tt.name != "nil_policy_manager" { + pm, err = pmf(users, nodes.ViewSlice()) + assert.NoError(t, err) + } else { + pm = nil + } + + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch") + + // Handle nil vs empty slice comparison + if tt.wantApproved == nil { + assert.Nil(t, gotApproved, "expected nil approved routes") + } else { + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") + } + }) + } + } +} \ No newline at end of file diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go new file mode 100644 index 00000000..b6e54e7b --- 
/dev/null +++ b/hscontrol/policy/policy_route_approval_test.go @@ -0,0 +1,361 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { + // Test policy that allows specific routes to be auto-approved + aclPolicy := ` +{ + "groups": { + "group:admins": ["test@"], + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/24": ["test@"], + "192.168.0.0/24": ["group:admins"], + "172.16.0.0/16": ["tag:approved"], + }, + }, + "tagOwners": { + "tag:approved": ["test@"], + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + nodeHostname string + nodeUser string + nodeTags []string + wantApproved []netip.Prefix + wantChanged bool + wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result + }{ + { + name: "previously_approved_route_no_longer_advertised_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Should remain! + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed + }, + { + name: "add_new_auto_approved_route_keeps_existing", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Still advertised + netip.MustParsePrefix("192.168.0.0/24"), // New route + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group + }, + wantChanged: true, + }, + { + name: "no_announced_routes_keeps_all_approved", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced anymore + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + }, + { + name: "manually_approved_route_not_in_policy_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved + }, + wantChanged: true, + }, + { + name: "tagged_node_gets_tag_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route + }, + nodeUser: "test", + nodeTags: 
[]string{"tag:approved"}, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved + netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + }, + wantChanged: true, + }, + { + name: "complex_scenario_multiple_changes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised + netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable + netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) + netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised + }, + wantChanged: true, + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: tt.nodeUser, + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: tt.nodeHostname, + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + ForcedTags: tt.nodeTags, + } + nodes := types.Nodes{&node} + + // Create policy manager + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, pm) + + // Test ApproveRoutesWithPolicy + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + // Check change flag + assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch") + + // Check approved routes match expected + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Logf("Want: %v", tt.wantApproved) + t.Logf("Got: %v", gotApproved) + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + + // Verify all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should NEVER happen", prevRoute) + } + + // Verify no routes were incorrectly removed + for _, removedRoute := range tt.wantRemovedRoutes { + assert.NotContains(t, gotApproved, removedRoute, + "route %s should have been removed but wasn't", removedRoute) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["test@"], + }, + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: 
"empty_current_approved", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "duplicate_routes_handled", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, // Duplicates are removed, so it's a change + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + assert.Equal(t, tt.wantChanged, gotChanged) + + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + + currentApproved := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + } + announcedRoutes := []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + } + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: currentApproved, + } + + // With nil policy manager, should return current approved unchanged + gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes) + + assert.False(t, gotChanged) + assert.Equal(t, currentApproved, gotApproved) +} \ No newline at end of file diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 5e332fd3..1e6fabf3 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, canApprove: false, }, + { + name: "policy-without-autoApprovers-section", + node: normalNode, + route: p("10.33.0.0/16"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admin"], + "dst": ["group:admin:*"] + }, + { + "action": "accept", + "src": ["group:admin"], + "dst": ["10.33.0.0/16:*"] + } + ] + }`, + canApprove: false, + }, } for _, tt := range tests { diff --git 
a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index de839770..5e7aa34b 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // The fast path is that a node requests to approve a prefix // where there is an exact entry, e.g. 10.0.0.0/8, then // check and return quickly - if _, ok := pm.autoApproveMap[route]; ok { - if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) { + if approvers, ok := pm.autoApproveMap[route]; ok { + canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains) + if canApprove { return true } } @@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // Check if prefix is larger (so containing) and then overlaps // the route to see if the node can approve a subset of an autoapprover if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { - if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) { + canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains) + if canApprove { return true } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 15de78d3..6cfa8528 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -10,7 +10,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" @@ -112,6 +111,15 @@ func (m *mapSession) serve() { // This is the mechanism where the node gives us information about its // current configuration. // + // Process the MapRequest to update node state (endpoints, hostinfo, etc.) + c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + httpError(m.w, err) + return + } + + m.h.Change(c) + // If OmitPeers is true and Stream is false // then the server will let clients update their endpoints without // breaking existing long-polling (Stream == true) connections. @@ -122,14 +130,6 @@ func (m *mapSession) serve() { // the response and just wants a 200. // !req.stream && req.OmitPeers if m.isEndpointUpdate() { - c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req) - if err != nil { - httpError(m.w, err) - return - } - - m.h.Change(c) - m.w.WriteHeader(http.StatusOK) mapResponseEndpointUpdates.WithLabelValues("ok").Inc() } @@ -142,6 +142,8 @@ func (m *mapSession) serve() { func (m *mapSession) serveLongPoll() { m.beforeServeLongPoll() + log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("starting long poll session chan(%p)", m.ch) + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -149,18 +151,26 @@ func (m *mapSession) serveLongPoll() { close(m.cancelCh) m.cancelChMu.Unlock() - // TODO(kradalby): This can likely be made more effective, but likely most - // nodes has access to the same routes, so it might not be a big deal. - disconnectChange, err := m.h.state.Disconnect(m.node) - if err != nil { - m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removing session from batcher chan(%p)", m.ch) + + // Validate if we are actually closing the current session or + // if the connection has been replaced. If the connection has been replaced, + // do not run the rest of the disconnect logic. 
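The boolean returned by RemoveNode in the block below is what makes fast reconnects safe: only the session that still owns the registered channel runs the disconnect logic. The batcher implementation is not part of this excerpt (and the real one is lock-free), so this is only a sketch of the ownership check, using a hypothetical mutex-guarded demoBatcher:

package demo

import (
	"sync"

	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

// demoBatcher is a hypothetical stand-in for the real mapBatcher; it exists
// only to illustrate the channel-ownership check used by serveLongPoll.
type demoBatcher struct {
	mu    sync.Mutex
	nodes map[types.NodeID]chan *tailcfg.MapResponse
}

// RemoveNode reports true only when ch is still the channel registered for
// the node. A session that has been replaced by a newer connection gets
// false and therefore skips the offline/disconnect handling.
func (b *demoBatcher) RemoveNode(id types.NodeID, ch chan *tailcfg.MapResponse) bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	if cur, ok := b.nodes[id]; !ok || cur != ch {
		return false
	}

	delete(b.nodes, id)

	return true
}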
+ if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) { + log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removed from batcher chan(%p)", m.ch) + // First update NodeStore to mark the node as offline + // This ensures the state is consistent before notifying the batcher + disconnectChange, err := m.h.state.Disconnect(m.node.ID) + if err != nil { + m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + } + + // Send the disconnect change notification + m.h.Change(disconnectChange) + + m.afterServeLongPoll() + m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) } - m.h.Change(disconnectChange) - - m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter()) - - m.afterServeLongPoll() - m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) }() // Set up the client stream @@ -172,25 +182,37 @@ func (m *mapSession) serveLongPoll() { m.keepAliveTicker = time.NewTicker(m.keepAlive) - // Add node to batcher BEFORE sending Connect change to prevent race condition - // where the change is sent before the node is in the batcher's node map - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil { + // Add node to batcher so it can receive updates, + // adding this before connecting it to the state ensure that + // it does not miss any updates that might be sent in the split + // time between the node connecting and the batcher being ready. + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { m.errf(err, "failed to add node to batcher") - // Send empty response to client to fail fast for invalid/non-existent nodes - select { - case m.ch <- &tailcfg.MapResponse{}: - default: - // Channel might be closed - } + return } - // Now send the Connect change - the batcher handles NodeCameOnline internally - // but we still need to update routes and other state-level changes - connectChange := m.h.state.Connect(m.node) - if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline { - m.h.Change(connectChange) + // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.) + // CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly + // synchronized. When nodes reconnect, they send their hostinfo with announced routes + // in the MapRequest. We need this data in NodeStore before Connect() sets up the + // primary routes, otherwise SubnetRoutes() returns empty and the node is removed + // from AvailableRoutes. + mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + m.errf(err, "failed to update node from initial MapRequest") + return } + m.h.Change(mapReqChange) + + // Connect the node after its state has been updated. + // We send two separate change notifications because these are distinct operations: + // 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo) + // 2. Connect: marks the node online and recalculates primary routes based on the updated state + // While this results in two notifications, it ensures route data is synchronized before + // primary route selection occurs, which is critical for proper HA subnet router failover. 
+ connectChange := m.h.state.Connect(m.node.ID) + m.h.Change(connectChange) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 902d72ba..b319dcd9 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -1,11 +1,15 @@ package state import ( + "fmt" "maps" + "strings" "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" "tailscale.com/types/key" "tailscale.com/types/views" ) @@ -21,6 +25,56 @@ const ( update = 3 ) +const prometheusNamespace = "headscale" + +var ( + nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operations_total", + Help: "Total number of NodeStore operations", + }, []string{"operation"}) + nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operation_duration_seconds", + Help: "Duration of NodeStore operations", + Buckets: prometheus.DefBuckets, + }, []string{"operation"}) + nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_size", + Help: "Size of NodeStore write batches", + Buckets: []float64{1, 2, 5, 10, 20, 50, 100}, + }) + nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_duration_seconds", + Help: "Duration of NodeStore batch processing", + Buckets: prometheus.DefBuckets, + }) + nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_snapshot_build_duration_seconds", + Help: "Duration of NodeStore snapshot building from nodes", + Buckets: prometheus.DefBuckets, + }) + nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_nodes_total", + Help: "Total number of nodes in the NodeStore", + }) + nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_peers_calculation_duration_seconds", + Help: "Duration of peers calculation in NodeStore", + Buckets: prometheus.DefBuckets, + }) + nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_queue_depth", + Help: "Current depth of NodeStore write queue", + }) +) + // NodeStore is a thread-safe store for nodes. // It is a copy-on-write structure, replacing the "snapshot" // when a change to the structure occurs. It is optimised for reads, @@ -29,13 +83,14 @@ const ( // changes rapidly. // // Writes will block until committed, while reads are never -// blocked. +// blocked. This means that the caller of a write operation +// is responsible for ensuring an update depending on a write +// is not issued before the write is complete. type NodeStore struct { data atomic.Pointer[Snapshot] peersFunc PeersFunc writeQueue chan work - // TODO: metrics } func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore { @@ -50,9 +105,17 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore { } store.data.Store(&snap) + // Initialize node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + return store } +// Snapshot is the representation of the current state of the NodeStore. 
+// It contains all nodes and their relationships. +// It is a copy-on-write structure, meaning that when a write occurs, +// a new Snapshot is created with the updated state, +// and replaces the old one atomically. type Snapshot struct { // nodesByID is the main source of truth for nodes. nodesByID map[types.NodeID]types.Node @@ -64,15 +127,19 @@ type Snapshot struct { allNodes []types.NodeView } +// PeersFunc is a function that takes a list of nodes and returns a map +// with the relationships between nodes and their peers. +// This will typically be used to calculate which nodes can see each other +// based on the current policy. type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView +// work represents a single operation to be performed on the NodeStore. type work struct { - op int - nodeID types.NodeID - node types.Node - updateFn UpdateNodeFunc - result chan struct{} - immediate bool // For operations that need immediate processing + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} } // PutNode adds or updates a node in the store. @@ -80,6 +147,9 @@ type work struct { // If the node does not exist, it will be added. // This is a blocking operation that waits for the write to complete. func (s *NodeStore) PutNode(n types.Node) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) + defer timer.ObserveDuration() + work := work{ op: put, nodeID: n.ID, @@ -87,8 +157,12 @@ func (s *NodeStore) PutNode(n types.Node) { result: make(chan struct{}), } + nodeStoreQueueDepth.Inc() s.writeQueue <- work <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("put").Inc() } // UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. @@ -96,7 +170,21 @@ type UpdateNodeFunc func(n *types.Node) // UpdateNode applies a function to modify a specific node in the store. // This is a blocking operation that waits for the write to complete. +// This is analogous to a database "transaction", or, the caller should +// rather collect all data they want to change, and then call this function. +// Fewer calls are better. +// +// TODO(kradalby): Technically we could have a version of this that modifies the node +// in the current snapshot if _we know_ that the change will not affect the peer relationships. +// This is because the main nodesByID map contains the struct, and every other map is using a +// pointer to the underlying struct. The gotcha with this is that we will need to introduce +// a lock around the nodesByID map to ensure that no other writes are happening +// while we are modifying the node. Which mean we would need to implement read-write locks +// on all read operations. func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) + defer timer.ObserveDuration() + work := work{ op: update, nodeID: nodeID, @@ -104,48 +192,47 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) result: make(chan struct{}), } + nodeStoreQueueDepth.Inc() s.writeQueue <- work <-work.result -} + nodeStoreQueueDepth.Dec() -// UpdateNodeImmediate applies a function to modify a specific node in the store -// with immediate processing (bypassing normal batching delays). -// Use this for time-sensitive updates like online status changes. 
-func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) { - work := work{ - op: update, - nodeID: nodeID, - updateFn: updateFn, - result: make(chan struct{}), - immediate: true, - } - - s.writeQueue <- work - <-work.result + nodeStoreOperations.WithLabelValues("update").Inc() } // DeleteNode removes a node from the store by its ID. // This is a blocking operation that waits for the write to complete. func (s *NodeStore) DeleteNode(id types.NodeID) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete")) + defer timer.ObserveDuration() + work := work{ op: del, nodeID: id, result: make(chan struct{}), } + nodeStoreQueueDepth.Inc() s.writeQueue <- work <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("delete").Inc() } +// Start initializes the NodeStore and starts processing the write queue. func (s *NodeStore) Start() { s.writeQueue = make(chan work) go s.processWrite() } +// Stop stops the NodeStore and closes the write queue. func (s *NodeStore) Stop() { close(s.writeQueue) } +// processWrite processes the write queue in batches. +// It collects writes into batches and applies them periodically. func (s *NodeStore) processWrite() { c := time.NewTicker(batchTimeout) batch := make([]work, 0, batchSize) @@ -157,13 +244,7 @@ func (s *NodeStore) processWrite() { c.Stop() return } - - // Handle immediate operations right away - if w.immediate { - s.applyBatch([]work{w}) - continue - } - + batch = append(batch, w) if len(batch) >= batchSize { s.applyBatch(batch) @@ -181,7 +262,22 @@ } } +// applyBatch applies a batch of work to the node store. +// It takes a copy of the current nodes, +// then applies the batch of operations to that copy, +// runs any precomputation needed (like calculating peers), +// and finally replaces the snapshot in the store with the new one. +// The replacement of the snapshot is atomic, ensuring that reads +// are never blocked by writes. +// Each write item is blocked until the batch is applied, so that +// the caller knows the operation is complete and does not send any +// updates that depend on state that has not yet been written. func (s *NodeStore) applyBatch(batch []work) { + timer := prometheus.NewTimer(nodeStoreBatchDuration) + defer timer.ObserveDuration() + + nodeStoreBatchSize.Observe(float64(len(batch))) + nodes := make(map[types.NodeID]types.Node) maps.Copy(nodes, s.data.Load().nodesByID) @@ -201,15 +297,25 @@ func (s *NodeStore) applyBatch(batch []work) { } newSnap := snapshotFromNodes(nodes, s.peersFunc) - s.data.Store(&newSnap) + s.data.Store(&newSnap) + + // Update node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + for _, w := range batch { close(w.result) } } +// snapshotFromNodes creates a new Snapshot from the provided nodes. +// It builds a number of "indexes" to make lookups fast for data that +// is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser. +// This is not a fast operation, it is the "slow" part of our copy-on-write +// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot { + timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration) + defer timer.ObserveDuration() + allNodes := make([]types.NodeView, 0, len(nodes)) for _, n := range nodes { allNodes = append(allNodes, n.View()) @@ -219,8 +325,17 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S nodesByID: nodes, allNodes: allNodes, nodesByNodeKey: make(map[key.NodePublic]types.NodeView), - peersByNode: peersFunc(allNodes), - nodesByUser: make(map[types.UserID][]types.NodeView), + + // peersByNode is likely the most expensive operation; + // it combines the list of all nodes with the + // current policy to precalculate which nodes are peers and + // can see each other. + peersByNode: func() map[types.NodeID][]types.NodeView { + peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) + defer peersTimer.ObserveDuration() + return peersFunc(allNodes) + }(), + nodesByUser: make(map[types.UserID][]types.NodeView), } // Build nodesByUser and nodesByNodeKey maps @@ -234,9 +349,21 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S } // GetNode retrieves a node by its ID. -func (s *NodeStore) GetNode(id types.NodeID) types.NodeView { - n := s.data.Load().nodesByID[id] - return n.View() +// The returned bool reports whether the node exists; false is the equivalent of a "not found" error. +// The returned NodeView may still be invalid, so callers must also check .Valid() to +// guard against a zero or otherwise broken node. +func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get").Inc() + + n, exists := s.data.Load().nodesByID[id] + if !exists { + return types.NodeView{}, false + } + + return n.View(), true } // GetNodeByNodeKey retrieves a node by its NodeKey. @@ -306,15 +433,30 @@ func (s *NodeStore) DebugString() string { // ListNodes returns a slice of all nodes in the store. func (s *NodeStore) ListNodes() views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list").Inc() + return views.SliceOf(s.data.Load().allNodes) } // ListPeers returns a slice of all peers for a given node ID. func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_peers").Inc() + return views.SliceOf(s.data.Load().peersByNode[id]) } // ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_by_user").Inc() + return views.SliceOf(s.data.Load().nodesByUser[uid]) } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 7af07b38..9666e5db 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -24,6 +24,7 @@ func TestSnapshotFromNodes(t *testing.T) { peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { return make(map[types.NodeID][]types.NodeView) } + return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -61,6 +62,7 @@ func TestSnapshotFromNodes(t *testing.T) { 1: createTestNode(1, 1, "user1", "node1"), 2: createTestNode(2, 1, "user1", "node2"), } + return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -85,6 +87,7 @@ func TestSnapshotFromNodes(t *testing.T) { 2: createTestNode(2, 2, "user2", "node2"), 3: createTestNode(3, 1, "user1", "node3"), } + return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -113,6 +116,7 @@ func TestSnapshotFromNodes(t *testing.T) { 4: createTestNode(4, 4, "user4", "node4"), } peersFunc := oddEvenPeersFunc + return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -191,6 +195,7 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView } ret[node.ID()] = peers } + return ret } @@ -214,6 +219,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView } ret[node.ID()] = peers } + return ret } @@ -329,6 +335,7 @@ func TestNodeStoreOperations(t *testing.T) { node2 := createTestNode(2, 1, "user1", "node2") node3 := createTestNode(3, 2, "user2", "node3") initialNodes := types.Nodes{&node1, &node2, &node3} + return NewNodeStore(initialNodes, allowAllPeersFunc) }, steps: []testStep{ diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index a63dad22..6bad903f 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1,5 +1,6 @@ // Package state provides core state management for Headscale, coordinating // between subsystems like database, IP allocation, policy management, and DERP routing. 
+ package state import ( @@ -9,6 +10,8 @@ import ( "io" "net/netip" "os" + "slices" + "sync" "time" hsdb "github.com/juanfont/headscale/hscontrol/db" @@ -21,7 +24,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" - xslices "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -96,6 +99,12 @@ func NewState(cfg *types.Config) (*State, error) { if err != nil { return nil, fmt.Errorf("loading nodes: %w", err) } + + // On startup, all nodes should be marked as offline until they reconnect + // This ensures we don't have stale online status from previous runs + for _, node := range nodes { + node.IsOnline = ptr.To(false) + } users, err := db.ListUsers() if err != nil { return nil, fmt.Errorf("loading users: %w", err) @@ -190,30 +199,34 @@ func (s *State) DERPMap() *tailcfg.DERPMap { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. -func (s *State) ReloadPolicy() (bool, error) { +func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { pol, err := policyBytes(s.db, s.cfg) if err != nil { - return false, fmt.Errorf("loading policy: %w", err) + return nil, fmt.Errorf("loading policy: %w", err) } - changed, err := s.polMan.SetPolicy(pol) + _, err = s.polMan.SetPolicy(pol) if err != nil { - return false, fmt.Errorf("setting policy: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } - if changed { - err := s.autoApproveNodes() - if err != nil { - return false, fmt.Errorf("auto approving nodes: %w", err) - } + cs := []change.ChangeSet{change.PolicyChange()} + + // Always call autoApproveNodes during policy reload, regardless of whether + // the policy content has changed. This ensures that routes are re-evaluated + // when they might have been manually disabled but could now be auto-approved + // with the current policy. + rcs, err := s.autoApproveNodes() + if err != nil { + return nil, fmt.Errorf("auto approving nodes: %w", err) } - return changed, nil -} + // TODO(kradalby): These changes can probably be safely ignored. + // If the PolicyChange is happening, that will lead to a full update + // meaning that we do not need to send individual route changes. + cs = append(cs, rcs...) -// AutoApproveNodes processes pending nodes and auto-approves those meeting policy criteria. -func (s *State) AutoApproveNodes() error { - return s.autoApproveNodes() + return cs, nil } // CreateUser creates a new user and updates the policy manager. @@ -237,16 +250,14 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro // might now be resolvable when they weren't before. If there are existing // nodes, we should send a policy change to ensure they get updated SSH policies. 
if c.Empty() { - nodes, err := s.ListNodes() - if err == nil && nodes.Len() > 0 { + nodes := s.ListNodes() + if nodes.Len() > 0 { c = change.PolicyChange() } } log.Info().Str("user", user.Name).Bool("policyChanged", !c.Empty()).Msg("User created, policy manager updated") - // TODO(kradalby): implement the user in-memory cache - return &user, c, nil } @@ -282,7 +293,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err) } - // TODO(kradalby): implement the user in-memory cache + // TODO(kradalby): We might want to update nodestore with the user data return user, c, nil } @@ -326,32 +337,11 @@ func (s *State) ListAllUsers() ([]types.User, error) { return s.db.ListUsers() } -// CreateNode creates a new node and updates the policy manager. -// Returns the created node, change set, and any error. -func (s *State) CreateNode(node *types.Node) (types.NodeView, change.ChangeSet, error) { - s.nodeStore.PutNode(*node) - - if err := s.db.DB.Save(node).Error; err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("creating node: %w", err) - } - - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node creation: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - if c.Empty() { - c = change.NodeAdded(node.ID) - } - - return node.View(), c, nil -} - // updateNodeTx performs a database transaction to update a node and refresh the policy manager. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (types.NodeView, change.ChangeSet, error) { +// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore +// BEFORE calling this function with the EXACT same changes that the database update will make. +// This ensures the NodeStore is the source of truth for the batcher and maintains consistency. +func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (types.NodeView, error) { s.mu.Lock() defer s.mu.Unlock() @@ -372,32 +362,25 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err return node, nil }) if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, err } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - if c.Empty() { - // Basic node change without specific details since this is a generic update - c = change.NodeAdded(node.ID) - } - - return node.View(), c, nil + return node.View(), nil } -// SaveNode persists an existing node to the database and updates the policy manager. -func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) { +// persistNodeToDB saves the current state of a node from NodeStore to the database. +// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency. 
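The updateNodeTx contract above (update the NodeStore first, then have the database make the exact same change) can be captured as a small write-through helper. The sketch below is illustrative only, with hypothetical store and persist functions; the point is that a failed database write surfaces as an error, so the caller never announces a change that was not persisted.

package main

import (
	"errors"
	"fmt"
)

type Node struct {
	ID     uint64
	Expiry string
}

// applyThenPersist mutates the in-memory copy first (the source of truth for
// notifications), then asks the database layer to make the exact same change.
// If persisting fails, the error is returned and the caller must not notify peers.
func applyThenPersist(store map[uint64]*Node, id uint64, mutate func(*Node), persist func(Node) error) error {
	n, ok := store[id]
	if !ok {
		return fmt.Errorf("node %d not found", id)
	}
	mutate(n)
	if err := persist(*n); err != nil {
		return fmt.Errorf("persisting node %d: %w", id, err)
	}
	return nil
}

func main() {
	store := map[uint64]*Node{1: {ID: 1}}

	ok := applyThenPersist(store, 1, func(n *Node) { n.Expiry = "2030-01-01" },
		func(Node) error { return nil })
	failed := applyThenPersist(store, 1, func(n *Node) { n.Expiry = "2031-01-01" },
		func(Node) error { return errors.New("db unavailable") })

	fmt.Println(ok, failed)
}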
+func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() + // CRITICAL: Always get the latest node from NodeStore to ensure we save the current state + node, found := s.nodeStore.GetNode(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + nodePtr := node.AsStruct() - s.nodeStore.PutNode(*nodePtr) if err := s.db.DB.Save(nodePtr).Error; err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err) @@ -409,8 +392,6 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) } - // TODO(kradalby): implement the node in-memory cache - if c.Empty() { c = change.NodeAdded(node.ID()) } @@ -418,6 +399,16 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, return node, c, nil } +func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) { + // Update NodeStore first + nodePtr := node.AsStruct() + + s.nodeStore.PutNode(*nodePtr) + + // Then save to database + return s.persistNodeToDB(node.ID()) +} + // DeleteNode permanently removes a node and cleans up associated resources. // Returns whether policies changed and any error. This operation is irreversible. func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { @@ -443,61 +434,95 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { return c, nil } -func (s *State) Connect(node *types.Node) change.ChangeSet { - c := change.NodeOnline(node.ID) - routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) +// Connect marks a node as connected and updates its primary routes in the state. +func (s *State) Connect(id types.NodeID) change.ChangeSet { + c := change.NodeOnline(id) - if routeChange { - c = change.NodeAdded(node.ID) - } - - // Update nodestore with online status - node is connecting so it's online - // Use immediate update to ensure online status changes are not delayed by batching - s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) { - // Set the online status in the node's ephemeral field + // Update the online status in NodeStore + s.nodeStore.UpdateNode(id, func(n *types.Node) { n.IsOnline = ptr.To(true) }) + // Get fresh node data from NodeStore after the online status update + node, found := s.GetNodeByID(id) + if !found { + return change.EmptySet + } + + // Use the node's current routes for primary route update + // SubnetRoutes() returns only the intersection of announced AND approved routes + // We MUST use SubnetRoutes() to maintain the security model + routeChange := s.primaryRoutes.SetRoutes(id, node.SubnetRoutes()...) 
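The Connect path above recomputes primary routes from SubnetRoutes(), which the surrounding comments describe as the intersection of announced and approved routes. A minimal sketch of that intersection, with illustrative prefixes and helper names:

package main

import (
	"fmt"
	"net/netip"
)

// subnetRoutes returns only those announced prefixes that have also been
// approved, mirroring the security model described in the hunk above.
func subnetRoutes(announced, approved []netip.Prefix) []netip.Prefix {
	approvedSet := make(map[netip.Prefix]bool, len(approved))
	for _, p := range approved {
		approvedSet[p] = true
	}
	var out []netip.Prefix
	for _, p := range announced {
		if approvedSet[p] {
			out = append(out, p)
		}
	}
	return out
}

func main() {
	announced := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.1.0/24"),
	}
	approved := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}

	// Only the approved half of the announcement is served to peers.
	fmt.Println(subnetRoutes(announced, approved)) // [10.0.0.0/24]
}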
+ + log.Trace().Msg("===============================================") + log.Trace().Bool("route-change", routeChange).Str("nid", id.String()).Msgf("NODE CONNECTING, SubR: %v, AnnoR: %v, ApprR: %v", node.SubnetRoutes(), node.AnnouncedRoutes(), node.ApprovedRoutes()) + log.Trace().Msg("===============================================") + + if routeChange { + c = change.NodeAdded(id) + } + return c } -func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) { - c := change.NodeOffline(node.ID) - - // Update nodestore with offline status - // Use immediate update to ensure online status changes are not delayed by batching - s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) { - // Set the online status to false in the node's ephemeral field - n.IsOnline = ptr.To(false) +// Disconnect marks a node as disconnected and updates its primary routes in the state. +func (s *State) Disconnect(id types.NodeID) (change.ChangeSet, error) { + now := time.Now() + s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.LastSeen = ptr.To(now) + // Do NOT mark as offline here - let the batcher's grace period handle it + // This ensures DNS continues to work during logout }) - _, _, err := s.SetLastSeen(node.ID, time.Now()) + _, err := s.updateNodeTx(id, func(tx *gorm.DB) error { + return hsdb.SetLastSeen(tx, id, now) + }) if err != nil { - return c, fmt.Errorf("disconnecting node: %w", err) + return change.EmptySet, fmt.Errorf("setting last seen: %w", err) } - if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange { + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } + + // The node is disconnecting so make sure that none of the routes it + // announced are served to any nodes. + routeChange := s.primaryRoutes.SetRoutes(id) + + // If we have a policy change or route change, return that as it's more comprehensive + // Otherwise, return the NodeOffline change to ensure nodes are notified + if c.IsFull() || routeChange { c = change.PolicyChange() + } else { + c = change.NodeOffline(id) } - // TODO(kradalby): This node should update the in memory state return c, nil } // GetNodeByID retrieves a node by ID. -func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, error) { - return s.nodeStore.GetNode(nodeID), nil +// GetNodeByID retrieves a node by its ID. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, bool) { + return s.nodeStore.GetNode(nodeID) } // GetNodeByNodeKey retrieves a node by its Tailscale public key. -func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) { - return s.nodeStore.GetNodeByNodeKey(nodeKey), nil +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) { + return s.nodeStore.GetNodeByNodeKey(nodeKey) } // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. 
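GetNodeByID and GetNodeByNodeKey above switch from an error return to Go's comma-ok convention, with a separate Valid() check guarding against a zero view. A rough sketch of how a caller is expected to use that shape; the view type here is a hypothetical stand-in for types.NodeView:

package main

import "fmt"

// nodeView is a hypothetical stand-in for types.NodeView: the zero value is
// "invalid" and must not be used.
type nodeView struct {
	id   uint64
	name string
}

func (v nodeView) Valid() bool { return v.id != 0 }

type store struct{ nodes map[uint64]nodeView }

// getNodeByID follows the comma-ok convention: the bool reports presence
// ("not found" is not an error), and Valid() guards against a broken entry.
func (s *store) getNodeByID(id uint64) (nodeView, bool) {
	v, ok := s.nodes[id]
	return v, ok
}

func main() {
	s := &store{nodes: map[uint64]nodeView{1: {id: 1, name: "node1"}}}

	if v, ok := s.getNodeByID(1); ok && v.Valid() {
		fmt.Println("found", v.name)
	}
	if _, ok := s.getNodeByID(42); !ok {
		fmt.Println("node 42 not registered")
	}
}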
-func (s *State) ListNodes(nodeIDs ...types.NodeID) (views.Slice[types.NodeView], error) { +func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] { if len(nodeIDs) == 0 { - return s.nodeStore.ListNodes(), nil + return s.nodeStore.ListNodes() } // Filter nodes by the requested IDs @@ -514,18 +539,18 @@ func (s *State) ListNodes(nodeIDs ...types.NodeID) (views.Slice[types.NodeView], } } - return views.SliceOf(filteredNodes), nil + return views.SliceOf(filteredNodes) } // ListNodesByUser retrieves all nodes belonging to a specific user. -func (s *State) ListNodesByUser(userID types.UserID) (views.Slice[types.NodeView], error) { - return s.nodeStore.ListNodesByUser(userID), nil +func (s *State) ListNodesByUser(userID types.UserID) views.Slice[types.NodeView] { + return s.nodeStore.ListNodesByUser(userID) } // ListPeers retrieves nodes that can communicate with the specified node based on policy. -func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (views.Slice[types.NodeView], error) { +func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) views.Slice[types.NodeView] { if len(peerIDs) == 0 { - return s.nodeStore.ListPeers(nodeID), nil + return s.nodeStore.ListPeers(nodeID) } // For specific peerIDs, filter from all nodes @@ -542,11 +567,11 @@ func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (views.S } } - return views.SliceOf(filteredNodes), nil + return views.SliceOf(filteredNodes) } // ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system. -func (s *State) ListEphemeralNodes() (views.Slice[types.NodeView], error) { +func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { allNodes := s.nodeStore.ListNodes() var ephemeralNodes []types.NodeView @@ -557,22 +582,33 @@ func (s *State) ListEphemeralNodes() (views.Slice[types.NodeView], error) { } } - return views.SliceOf(ephemeralNodes), nil + return views.SliceOf(ephemeralNodes) } // SetNodeExpiry updates the expiration time for a node. func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + // If the database update fails, the NodeStore change will remain, but since we return + // an error, no change notification will be sent to the batcher. + expiryPtr := expiry + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.Expiry = &expiryPtr + }) + + n, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.NodeSetExpiry(tx, nodeID, expiry) }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) } - // Update nodestore with the same change - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.Expiry = &expiry - }) + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } if !c.IsFull() { c = change.KeyExpiry(nodeID) @@ -583,17 +619,25 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node // SetNodeTags assigns tags to a node for use in access control policies. 
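SetNodeExpiry above funnels its mutation through nodeStore.UpdateNode(id, func(*types.Node)), the callback-based write path of the snapshot store. A compact copy-on-write sketch of that shape follows; the types are hypothetical, a real store would also serialize writers, and this version upserts a zero node when the ID is missing, purely for brevity.

package main

import (
	"fmt"
	"sync/atomic"
)

type Node struct {
	ID       uint64
	Hostname string
}

type snapshot struct {
	nodes map[uint64]Node
}

type Store struct {
	data atomic.Pointer[snapshot]
}

func NewStore() *Store {
	s := &Store{}
	s.data.Store(&snapshot{nodes: map[uint64]Node{}})
	return s
}

// UpdateNode applies fn to a copy of the node and publishes a new snapshot,
// so concurrent readers always see either the old or the new state, never a mix.
func (s *Store) UpdateNode(id uint64, fn func(*Node)) {
	old := s.data.Load()
	next := &snapshot{nodes: make(map[uint64]Node, len(old.nodes))}
	for k, v := range old.nodes {
		next.nodes[k] = v
	}
	n := next.nodes[id]
	fn(&n)
	next.nodes[id] = n
	s.data.Store(next)
}

func (s *Store) GetNode(id uint64) (Node, bool) {
	n, ok := s.data.Load().nodes[id]
	return n, ok
}

func main() {
	st := NewStore()
	st.UpdateNode(1, func(n *Node) { n.ID = 1; n.Hostname = "node1" })
	n, ok := st.GetNode(1)
	fmt.Println(n.Hostname, ok) // node1 true
}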
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ForcedTags = tags + }) + + n, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetTags(tx, nodeID, tags) }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err) } - // Update nodestore with the same change - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.ForcedTags = tags - }) + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } if !c.IsFull() { c = change.NodeAdded(nodeID) @@ -604,20 +648,34 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, // SetApprovedRoutes sets the network routes that a node is approved to advertise. func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { + // TODO(kradalby): In principle we should call the AutoApprove logic here + // because even if the CLI removes an auto-approved route, it will be added + // back automatically. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ApprovedRoutes = routes + }) + + n, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetApprovedRoutes(tx, nodeID, routes) }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) } - // Update nodestore with the same change - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.ApprovedRoutes = routes - }) + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } - // Update primary routes after changing approved routes - routeChange := s.primaryRoutes.SetRoutes(nodeID, n.AsStruct().SubnetRoutes()...) + // Get the node from NodeStore to ensure we have the latest state + nodeView, ok := s.GetNodeByID(nodeID) + if !ok { + return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID) + } + // Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set + // primary routes for routes that are both announced AND approved + routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) if routeChange || !c.IsFull() { c = change.PolicyChange() @@ -628,39 +686,42 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // RenameNode changes the display name of a node. 
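SetApprovedRoutes above re-derives primary routes from the announced-and-approved set whenever approvals change. The toy table below illustrates the idea of one serving node per prefix and a changed flag; it is not Headscale's PrimaryRoutes implementation, only a sketch of the concept.

package main

import (
	"fmt"
	"net/netip"
)

// primaryRoutes remembers, for each prefix, which node currently serves it.
type primaryRoutes struct {
	owner map[netip.Prefix]uint64
}

// setRoutes replaces the set of prefixes served by nodeID and reports whether
// the table changed (a prefix gained or lost its serving node).
func (pr *primaryRoutes) setRoutes(nodeID uint64, prefixes ...netip.Prefix) bool {
	changed := false

	// Drop prefixes this node no longer serves.
	keep := make(map[netip.Prefix]bool, len(prefixes))
	for _, p := range prefixes {
		keep[p] = true
	}
	for p, owner := range pr.owner {
		if owner == nodeID && !keep[p] {
			delete(pr.owner, p)
			changed = true
		}
	}

	// Claim prefixes that have no serving node yet.
	for _, p := range prefixes {
		if _, ok := pr.owner[p]; !ok {
			pr.owner[p] = nodeID
			changed = true
		}
	}

	return changed
}

func main() {
	pr := &primaryRoutes{owner: map[netip.Prefix]uint64{}}
	p := netip.MustParsePrefix("10.0.0.0/24")

	fmt.Println(pr.setRoutes(1, p)) // true: node 1 becomes primary for 10.0.0.0/24
	fmt.Println(pr.setRoutes(1, p)) // false: nothing changed
	fmt.Println(pr.setRoutes(1))    // true: node 1 withdrew the route
}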
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { + // Validate the new name before making any changes + if err := util.CheckForFQDNRules(newName); err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + // Check name uniqueness + nodes, err := s.db.ListNodes() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err) + } + for _, node := range nodes { + if node.ID != nodeID && node.GivenName == newName { + return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) + } + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.GivenName = newName + }) + + n, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.RenameNode(tx, nodeID, newName) }) if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) } - // Update nodestore with the same change - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.GivenName = newName - }) - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil -} - -// SetLastSeen updates when a node was last seen, used for connectivity monitoring. -func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (types.NodeView, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetLastSeen(tx, nodeID, lastSeen) - }) + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting last seen: %w", err) + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } - // Update nodestore with the same change - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.LastSeen = &lastSeen - }) - if !c.IsFull() { c = change.NodeAdded(nodeID) } @@ -670,39 +731,98 @@ func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (types.Node // AssignNodeToUser transfers a node to a different user. func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) { - node, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { + // Validate that both node and user exist + _, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found: %d", nodeID) + } + + user, err := s.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("user not found: %w", err) + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. 
+ s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { + n.User = *user + n.UserID = uint(userID) + }) + + n, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.AssignNodeToUser(tx, nodeID, userID) }) if err != nil { return types.NodeView{}, change.EmptySet, err } - // Update nodestore with the same change - // Get the updated user information from the database - user, err := s.GetUserByID(userID) - if err == nil { - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { - n.UserID = uint(userID) - n.User = *user - }) + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { c = change.NodeAdded(nodeID) } - return node, c, nil + return n, c, nil } // BackfillNodeIPs assigns IP addresses to nodes that don't have them. func (s *State) BackfillNodeIPs() ([]string, error) { - return s.db.BackfillNodeIPs(s.ipAlloc) + changes, err := s.db.BackfillNodeIPs(s.ipAlloc) + if err != nil { + return nil, err + } + + // Refresh NodeStore after IP changes to ensure consistency + if len(changes) > 0 { + nodes, err := s.db.ListNodes() + if err != nil { + return changes, fmt.Errorf("failed to refresh NodeStore after IP backfill: %w", err) + } + + for _, node := range nodes { + // Preserve online status when refreshing from database + existingNode, exists := s.nodeStore.GetNode(node.ID) + if exists && existingNode.Valid() { + node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + } + s.nodeStore.PutNode(*node) + } + } + + return changes, nil } // ExpireExpiredNodes finds and processes expired nodes since the last check. // Returns next check time, state update with expired nodes, and whether any were found. func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) { - return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck) + // Why capture start time: We need to ensure we don't miss nodes that expire + // while this function is running by using a consistent timestamp for the next check + started := time.Now() + + var updates []change.ChangeSet + + for _, node := range s.nodeStore.ListNodes().All() { + if !node.Valid() { + continue + } + + // Why check After(lastCheck): We only want to notify about nodes that + // expired since the last check to avoid duplicate notifications + if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { + updates = append(updates, change.KeyExpiry(node.ID())) + } + } + + if len(updates) > 0 { + return started, updates, true + } + + return started, nil, false } // SSHPolicy returns the SSH access policy for a node. @@ -726,21 +846,24 @@ func (s *State) SetPolicy(pol []byte) (bool, error) { } // AutoApproveRoutes checks if a node's routes should be auto-approved. -func (s *State) AutoApproveRoutes(node types.NodeView) bool { - nodePtr := node.AsStruct() - changed := policy.AutoApproveRoutes(s.polMan, nodePtr) +// AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them. +func (s *State) AutoApproveRoutes(nv types.NodeView) bool { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) if changed { - s.nodeStore.PutNode(*nodePtr) - // Update primaryRoutes manager with the newly approved routes - // This is essential for actual packet forwarding to work - s.primaryRoutes.SetRoutes(nodePtr.ID, nodePtr.SubnetRoutes()...) 
- } - return changed -} + // Persist the auto-approved routes to database and NodeStore via SetApprovedRoutes + // This ensures consistency between database and NodeStore + _, _, err := s.SetApprovedRoutes(nv.ID(), approved) + if err != nil { + log.Error(). + Uint64("node.id", nv.ID().Uint64()). + Err(err). + Msg("Failed to persist auto-approved routes") -// PolicyDebugString returns a debug representation of the current policy. -func (s *State) PolicyDebugString() string { - return s.polMan.DebugString() + return false + } + } + + return changed } // GetPolicy retrieves the current policy from the database. @@ -846,25 +969,67 @@ func (s *State) HandleNodeFromAuthPath( expiry *time.Time, registrationMethod string, ) (types.NodeView, change.ChangeSet, error) { - ipv4, ipv6, err := s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, err + // Get the registration entry to check the machine key + var ipv4, ipv6 *netip.Addr + var err error + + // Check if we have the registration entry to determine if we should reuse IPs + if regEntry, ok := s.GetRegistrationCacheEntry(registrationID); ok { + // Check if node already exists with same machine key and user + // to avoid allocating new IPs unnecessarily + existingNode, _ := s.db.GetNodeByMachineKey(regEntry.Node.MachineKey) + + // Only allocate new IPs if: + // 1. No existing node found, OR + // 2. Existing node belongs to a different user + if existingNode == nil || existingNode.UserID != uint(userID) { + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + } + // If existing node found for same user, HandleNodeFromAuthPath will reuse its IPs + } else { + // If no registration entry found, allocate new IPs (shouldn't happen in normal flow) + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, err + } } node, nodeChange, err := s.db.HandleNodeFromAuthPath( registrationID, userID, expiry, - util.RegisterMethodOIDC, + registrationMethod, ipv4, ipv6, ) if err != nil { return types.NodeView{}, change.EmptySet, err } - // Update nodestore with the newly registered/updated node + // Update NodeStore to ensure it has the latest node data + // For re-registrations (key expiry), mark as offline since node is reconnecting + // For new registrations, leave IsOnline as nil to let batcher manage connection state + if nodeChange.Change == change.NodeKeyExpiry { + // This is a re-registration/key refresh - node was disconnected and is coming back + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + } + // For new registrations (NodeNewOrUpdate), don't set IsOnline - batcher manages it s.nodeStore.PutNode(*node) + // Update policy manager with the new node if needed + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return node.View(), nodeChange, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err) + } + + // If policy manager detected changes, use that instead + if !nodesChange.Empty() { + nodeChange = nodesChange + } + return node.View(), nodeChange, nil } @@ -905,10 +1070,21 @@ func (s *State) HandleNodeFromPreAuthKey( nodeToRegister.Expiry = ®Req.Expiry } - ipv4, ipv6, err := s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + // Check if node already exists with same machine key and user + // to avoid allocating new IPs unnecessarily + existingNode, _ := s.db.GetNodeByMachineKey(machineKey) + var 
ipv4, ipv6 *netip.Addr + + // Only allocate new IPs if: + // 1. No existing node found, OR + // 2. Existing node belongs to a different user + if existingNode == nil || existingNode.UserID != pak.User.ID { + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } } + // If existing node found for same user, RegisterNode will reuse its IPs node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { node, err := hsdb.RegisterNode(tx, @@ -939,9 +1115,23 @@ func (s *State) HandleNodeFromPreAuthKey( if err != nil { return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) } + return types.NodeView{}, c, nil } + // Update NodeStore BEFORE updating policy manager so it has the latest node data + // CRITICAL: For re-registration of existing nodes, we must update NodeStore + // to ensure it has the latest state from the database transaction + // For re-registrations of existing nodes, mark as offline since they're reconnecting + // For new registrations, leave IsOnline as nil to let batcher manage connection state + if existingNode != nil && existingNode.UserID == pak.User.ID { + // This is a re-registration of existing node - was disconnected and is coming back + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + } + // For new registrations, don't set IsOnline - batcher manages it + s.nodeStore.PutNode(*node) + // Check if policy manager needs updating // This is necessary because we just created a new node. // We need to ensure that the policy manager is aware of this new node. @@ -963,9 +1153,6 @@ func (s *State) HandleNodeFromPreAuthKey( c = change.NodeAdded(node.ID) } - // Update nodestore with the newly registered node - s.nodeStore.PutNode(*node) - return node.View(), c, nil } @@ -993,6 +1180,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { if changed { return change.PolicyChange(), nil } + return change.EmptySet, nil } @@ -1003,10 +1191,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { // the policy manager could have a remove or add list for nodes. // updatePolicyManagerNodes refreshes the policy manager with current node data. func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { - nodes, err := s.ListNodes() - if err != nil { - return change.EmptySet, fmt.Errorf("listing nodes for policy update: %w", err) - } + nodes := s.ListNodes() changed, err := s.polMan.SetNodes(nodes) if err != nil { @@ -1016,6 +1201,7 @@ func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { if changed { return change.PolicyChange(), nil } + return change.EmptySet, nil } @@ -1030,150 +1216,212 @@ func (s *State) PingDB(ctx context.Context) error { // TODO(kradalby): This is kind of messy, maybe this is another +1 // for an event bus. See example comments here. // autoApproveNodes automatically approves nodes based on policy rules. -func (s *State) autoApproveNodes() error { - err := s.db.Write(func(tx *gorm.DB) error { - nodes, err := hsdb.ListNodes(tx) - if err != nil { - return err - } +func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { + nodes := s.ListNodes() - for _, node := range nodes { - // TODO(kradalby): This change should probably be sent to the rest of the system. - changed := policy.AutoApproveRoutes(s.polMan, node) + // Approve routes concurrently, this should make it likely + // that the writes end in the same batch in the nodestore write. 
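autoApproveNodes evaluates every node's routes concurrently in the hunk that follows, joining the goroutines with errgroup and collecting the resulting change sets behind a mutex. A self-contained sketch of that fan-out/collect pattern; the evaluate function here is a placeholder, not the policy approver:

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

// approveAll runs one goroutine per node ID and gathers the non-empty results.
// errgroup.Wait returns the first error, if any; the mutex protects the shared
// results slice, mirroring the cs/mu pair in the function below.
func approveAll(ids []uint64, evaluate func(uint64) (string, bool, error)) ([]string, error) {
	var (
		errg    errgroup.Group
		mu      sync.Mutex
		results []string
	)

	for _, id := range ids {
		id := id // capture loop variable for the goroutine
		errg.Go(func() error {
			result, changed, err := evaluate(id)
			if err != nil {
				return err
			}
			if changed {
				mu.Lock()
				results = append(results, result)
				mu.Unlock()
			}
			return nil
		})
	}

	if err := errg.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}

func main() {
	out, err := approveAll([]uint64{1, 2, 3}, func(id uint64) (string, bool, error) {
		return fmt.Sprintf("route-approved(node %d)", id), id%2 == 1, nil
	})
	fmt.Println(out, err)
}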
+ var errg errgroup.Group + var cs []change.ChangeSet + var mu sync.Mutex + for _, nv := range nodes.All() { + errg.Go(func() error { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) if changed { - // Update nodestore first if available - s.nodeStore.PutNode(*node) - - err = tx.Save(node).Error + _, c, err := s.SetApprovedRoutes(nv.ID(), approved) if err != nil { return err } - // TODO(kradalby): This should probably be done outside of the transaction, - // and the result of this should be propagated to the system. - s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) + mu.Lock() + cs = append(cs, c) + mu.Unlock() } - } - return nil - }) - if err != nil { - return fmt.Errorf("auto approving routes for nodes: %w", err) + return nil + }) } - return nil + err := errg.Wait() + if err != nil { + return nil, err + } + + return cs, nil } -// TODO(kradalby): This should just take the node ID? -func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) { - // TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, - // which means we could shortcut the whole change thing if there are no other important updates. - peerChange := node.PeerChangeFromMapRequest(req) +// UpdateNodeFromMapRequest processes a MapRequest and updates the node. +// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, +// which means we could shortcut the whole change thing if there are no other important updates. +// When a field is added to this function, remember to also add it to: +// - node.PeerChangeFromMapRequest +// - node.ApplyPeerChange +// - logTracePeerChange in poll.go. +func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) { + var routeChange bool + var hostinfoChanged bool + var needsRouteApproval bool + // We need to ensure we update the node as it is in the NodeStore at + // the time of the request. + s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + peerChange := currentNode.PeerChangeFromMapRequest(req) + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) - node.ApplyPeerChange(&peerChange) + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(peerChange) && !hostinfoChanged { + return + } - sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo) + // Calculate route approval before NodeStore update to avoid calling View() inside callback + var autoApprovedRoutes []netip.Prefix + hasNewRoutes := req.Hostinfo != nil && len(req.Hostinfo.RoutableIPs) > 0 + needsRouteApproval = hostinfoChanged && (routesChanged(currentNode.View(), req.Hostinfo) || (hasNewRoutes && len(currentNode.ApprovedRoutes) == 0)) + if needsRouteApproval { + autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( + s.polMan, + currentNode.View(), + // We need to preserve currently approved routes to ensure + // routes outside of the policy approver is persisted. + currentNode.ApprovedRoutes, + // However, the node has updated its routable IPs, so we + // need to approve them using that as a context. + req.Hostinfo.RoutableIPs, + ) + } + + // Log when routes change but approval doesn't + if hostinfoChanged && req.Hostinfo != nil && routesChanged(currentNode.View(), req.Hostinfo) && !routeChange { + log.Debug(). + Uint64("node.id", id.Uint64()). 
+ Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). + Strs("newAnnouncedRoutes", util.PrefixesToString(req.Hostinfo.RoutableIPs)). + Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Bool("routeChange", routeChange). + Msg("announced routes changed but approved routes did not") + } - // The node might not set NetInfo if it has not changed and if - // the full HostInfo object is overwritten, the information is lost. - // If there is no NetInfo, keep the previous one. - // From 1.66 the client only sends it if changed: - // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 - // TODO(kradalby): evaluate if we need better comparing of hostinfo - // before we take the changes. - if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil { - req.Hostinfo.NetInfo = node.Hostinfo.NetInfo + currentNode.ApplyPeerChange(&peerChange) + + if hostinfoChanged { + // The node might not set NetInfo if it has not changed and if + // the full HostInfo object is overwritten, the information is lost. + // If there is no NetInfo, keep the previous one. + // From 1.66 the client only sends it if changed: + // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 + // TODO(kradalby): evaluate if we need better comparing of hostinfo + // before we take the changes. + // Preserve NetInfo only if the existing node actually has valid NetInfo + // This prevents copying nil NetInfo which would lose DERP relay assignments + if req.Hostinfo.NetInfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { + log.Debug(). + Uint64("node.id", id.Uint64()). + Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). + Msg("preserving NetInfo from previous Hostinfo") + req.Hostinfo.NetInfo = currentNode.Hostinfo.NetInfo + } + currentNode.Hostinfo = req.Hostinfo + currentNode.ApplyHostnameFromHostInfo(req.Hostinfo) + + if routeChange { + // Apply pre-calculated route approval + // Always apply the route approval result to ensure consistency, + // regardless of whether the policy evaluation detected changes. + // This fixes the bug where routes weren't properly cleared when + // auto-approvers were removed from the policy. + log.Info(). + Uint64("node.id", id.Uint64()). + Strs("oldApprovedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Strs("newApprovedRoutes", util.PrefixesToString(autoApprovedRoutes)). + Bool("routeChanged", routeChange). + Msg("applying route approval results") + currentNode.ApprovedRoutes = autoApprovedRoutes + } + } + }) + + nodeRouteChange := change.EmptySet + + // Handle route changes after NodeStore update + // We need to update node routes if either: + // 1. The approved routes changed (routeChange is true), OR + // 2. The announced routes changed (even if approved routes stayed the same) + // This is because SubnetRoutes is the intersection of announced AND approved routes. + needsRouteUpdate := false + routesChangedButNotApproved := hostinfoChanged && req.Hostinfo != nil && needsRouteApproval && !routeChange + if routeChange { + needsRouteUpdate = true + log.Debug(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because approved routes changed") + } else if routesChangedButNotApproved { + needsRouteUpdate = true + log.Debug(). + Uint64("node.id", id.Uint64()). 
+ Msg("updating routes because announced routes changed but approved routes did not") } - node.Hostinfo = req.Hostinfo + + if needsRouteUpdate { + // Get the updated node to access its subnet routes + updatedNode, exists := s.GetNodeByID(id) + if !exists { + return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id) + } - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(peerChange) && !sendUpdate { - // mapResponseEndpointUpdates.WithLabelValues("noop").Inc() - return change.EmptySet, nil + // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() + // which returns only the intersection of announced AND approved routes. + // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. + log.Debug(). + Uint64("node.id", id.Uint64()). + Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())). + Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())). + Strs("subnetRoutes", util.PrefixesToString(updatedNode.SubnetRoutes())). + Msg("updating node routes for distribution") + nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) } - c := change.EmptySet - - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change to - // the routable IPs of the host and update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if routesChanged { - // Auto approve any routes that have been defined in policy as - // auto approved. Check if this actually changed the node. - _ = s.AutoApproveRoutes(node.View()) - - // Update the routes of the given node in the route manager to - // see if an update needs to be sent. - c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...) - } - - // Check if there has been a change to Hostname and update them - // in the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the hostname change. - node.ApplyHostnameFromHostInfo(req.Hostinfo) - - _, policyChange, err := s.SaveNode(node.View()) + _, policyChange, err := s.persistNodeToDB(id) if err != nil { - return change.EmptySet, err + return change.EmptySet, fmt.Errorf("saving to database: %w", err) } if policyChange.IsFull() { - c = policyChange + return policyChange, nil + } + if !nodeRouteChange.Empty() { + return nodeRouteChange, nil } - if c.Empty() { - c = change.NodeAdded(node.ID) - } - - return c, nil + return change.NodeAdded(id), nil } -// hostInfoChanged reports if hostInfo has changed in two ways, -// - first bool reports if an update needs to be sent to nodes -// - second reports if there has been changes to routes -// the caller can then use this info to save and update nodes -// and routes as needed. 
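hostInfoChanged is replaced just below by hostinfoEqual and routesChanged; the latter sorts both prefix lists before comparing, so ordering differences are not reported as a route change. A stand-alone sketch of that sort-then-compare idea using only the standard library (the real code sorts with tsaddr.SortPrefixes; String()-based ordering is good enough for an illustration because equal sets sort identically either way):

package main

import (
	"fmt"
	"net/netip"
	"slices"
	"strings"
)

// sameRoutes reports whether two announced-route sets are equal, ignoring order.
func sameRoutes(a, b []netip.Prefix) bool {
	a, b = slices.Clone(a), slices.Clone(b)
	byString := func(x, y netip.Prefix) int { return strings.Compare(x.String(), y.String()) }
	slices.SortFunc(a, byString)
	slices.SortFunc(b, byString)
	return slices.Equal(a, b)
}

func main() {
	r1 := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("0.0.0.0/0"),
	}
	r2 := []netip.Prefix{
		netip.MustParsePrefix("0.0.0.0/0"),
		netip.MustParsePrefix("10.0.0.0/24"),
	}

	fmt.Println(sameRoutes(r1, r2))     // true: same set, different order
	fmt.Println(sameRoutes(r1, r2[:1])) // false: one route missing
}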
-func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { - if old.Equal(new) { - return false, false +func hostinfoEqual(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + if !oldNode.Valid() && new == nil { + return true + } + if !oldNode.Valid() || new == nil { + return false + } + old := oldNode.AsStruct().Hostinfo + + return old.Equal(new) +} + +func routesChanged(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + var oldRoutes []netip.Prefix + if oldNode.Valid() && oldNode.AsStruct().Hostinfo != nil { + oldRoutes = oldNode.AsStruct().Hostinfo.RoutableIPs } - if old == nil && new != nil { - return true, true - } - - // Routes - oldRoutes := make([]netip.Prefix, 0) - if old != nil { - oldRoutes = old.RoutableIPs - } newRoutes := new.RoutableIPs + if newRoutes == nil { + newRoutes = []netip.Prefix{} + } tsaddr.SortPrefixes(oldRoutes) tsaddr.SortPrefixes(newRoutes) - if !xslices.Equal(oldRoutes, newRoutes) { - return true, true - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to each other. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. - if len(old.Services) != len(new.Services) { - return true, false - } - - return false, false + return !slices.Equal(oldRoutes, newRoutes) } func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 3301cb35..e38a98f6 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool { case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate: return true } + return false } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index fa315bf5..2761789b 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -104,6 +104,7 @@ type Node struct { // headscale. It is best effort and not persisted. LastSeen *time.Time `gorm:"column:last_seen"` + // ApprovedRoutes is a list of routes that the node is allowed to announce // as a subnet router. They are not necessarily the routes that the node // announces at the moment. @@ -420,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix { } // SubnetRoutes returns the list of routes that the node announces and are approved. +// +// IMPORTANT: This method is used for internal data structures and should NOT be used +// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually +// with PrimaryRoutes to ensure it includes only routes actively served by the node. +// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto. func (node *Node) SubnetRoutes() []netip.Prefix { var routes []netip.Prefix @@ -525,7 +531,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } node.Hostname = hostInfo.Hostname - + log.Trace(). Str("node_id", node.ID.String()). Str("new_hostname", node.Hostname).