diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index caac986c..6d6476fb 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -551,13 +551,12 @@ be assigned to nodes.`, } } - if confirm || force { ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force }) + changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force}) if err != nil { ErrorOutput( err, diff --git a/hscontrol/app.go b/hscontrol/app.go index 774aec46..47b38c83 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -137,9 +137,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { // Initialize ephemeral garbage collector ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { - node, err := app.state.GetNodeByID(ni) - if err != nil { - log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion") + node, ok := app.state.GetNodeByID(ni) + if !ok { + log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") + log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") return } @@ -379,15 +380,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writer http.ResponseWriter, req *http.Request, ) { - log.Trace(). - Caller(). - Str("client_address", req.RemoteAddr). - Msg("HTTP authentication invoked") - - authHeader := req.Header.Get("authorization") - - if !strings.HasPrefix(authHeader, AuthPrefix) { - log.Error(). + if err := func() error { + log.Trace(). Caller(). Str("client_address", req.RemoteAddr). Msg(`missing "Bearer " prefix in "Authorization" header`) @@ -501,11 +495,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { + var err error capver.CanOldCodeBeCleanedUp() if profilingEnabled { if profilingPath != "" { - err := os.MkdirAll(profilingPath, os.ModePerm) + err = os.MkdirAll(profilingPath, os.ModePerm) if err != nil { log.Fatal().Err(err).Msg("failed to create profiling directory") } @@ -559,12 +554,9 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() - ephmNodes, err := h.state.ListEphemeralNodes() - if err != nil { - return fmt.Errorf("failed to list ephemeral nodes: %w", err) - } - for _, node := range ephmNodes { - h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout) + ephmNodes := h.state.ListEphemeralNodes() + for _, node := range ephmNodes.All() { + h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) } if h.cfg.DNSConfig.ExtraRecordsPath != "" { @@ -794,23 +786,14 @@ func (h *Headscale) Serve() error { continue } - changed, err := h.state.ReloadPolicy() + changes, err := h.state.ReloadPolicy() if err != nil { log.Error().Err(err).Msgf("reloading policy") continue } - if changed { - log.Info(). - Msg("ACL policy successfully reloaded, notifying nodes of change") + h.Change(changes...) - err = h.state.AutoApproveNodes() - if err != nil { - log.Error().Err(err).Msg("failed to approve routes after new policy") - } - - h.Change(change.PolicySet) - } default: info := func(msg string) { log.Info().Msg(msg) } log.Info(). 
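// A minimal, self-contained sketch (not part of the diff) of the variadic
// ChangeSet plumbing that the next hunk introduces: Change(cs ...change.ChangeSet)
// forwards straight into the batcher's AddWork, so the slice returned by
// state.ReloadPolicy() above can be passed as h.Change(changes...) and an
// empty slice simply enqueues nothing. The stand-in types below (ChangeSet
// with a Desc field, the toy batcher) are assumptions for illustration, not
// headscale's real definitions.
package main

import "fmt"

type ChangeSet struct{ Desc string }

type batcher struct{ queue []ChangeSet }

// AddWork accepts zero or more changesets; calling it with none is a no-op,
// which is what "empty will be automatically ignored" amounts to here.
func (b *batcher) AddWork(cs ...ChangeSet) { b.queue = append(b.queue, cs...) }

type Headscale struct{ mapBatcher *batcher }

// Change mirrors the variadic signature added in the hunk below.
func (h *Headscale) Change(cs ...ChangeSet) { h.mapBatcher.AddWork(cs...) }

func main() {
	h := &Headscale{mapBatcher: &batcher{}}

	// A slice result (e.g. from a policy reload) is forwarded as-is.
	changes := []ChangeSet{{Desc: "policy"}, {Desc: "node added"}}
	h.Change(changes...)

	// A nil slice forwards zero items, so nothing is enqueued.
	var none []ChangeSet
	h.Change(none...)

	fmt.Println(len(h.mapBatcher.queue)) // prints 2
}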
@@ -1020,6 +1003,6 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { // Change is used to send changes to nodes. // All change should be enqueued here and empty will be automatically // ignored. -func (h *Headscale) Change(c change.ChangeSet) { - h.mapBatcher.AddWork(c) +func (h *Headscale) Change(cs ...change.ChangeSet) { + h.mapBatcher.AddWork(cs...) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index cb284173..81032640 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -12,7 +12,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/rs/zerolog/log" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -29,28 +28,10 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, err := h.state.GetNodeByNodeKey(regReq.NodeKey) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("looking up node in database: %w", err) - } + node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey) - if node != nil { - // If an existing node is trying to register with an auth key, - // we need to validate the auth key even for existing nodes - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) - if err != nil { - // Preserve HTTPError types so they can be handled properly by the HTTP layer - var httpErr HTTPError - if errors.As(err, &httpErr) { - return nil, httpErr - } - return nil, fmt.Errorf("handling register with auth key for existing node: %w", err) - } - return resp, nil - } - - resp, err := h.handleExistingNode(node, regReq, machineKey) + if ok { + resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey) if err != nil { return nil, fmt.Errorf("handling existing node: %w", err) } @@ -70,6 +51,7 @@ func (h *Headscale) handleRegister( if errors.As(err, &httpErr) { return nil, httpErr } + return nil, fmt.Errorf("handling register with auth key: %w", err) } @@ -89,13 +71,22 @@ func (h *Headscale) handleExistingNode( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - if node.MachineKey != machineKey { return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() + // If the node is expired and this is not a re-authentication attempt, + // force the client to re-authenticate + if expired && regReq.Auth == nil { + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + AuthURL: "", // Client will need to re-authenticate + }, nil + } + if !expired && !regReq.Expiry.IsZero() { requestExpiry := regReq.Expiry @@ -107,7 +98,7 @@ func (h *Headscale) handleExistingNode( // If the request expiry is in the past, we consider it a logout. 
if requestExpiry.Before(time.Now()) { if node.IsEphemeral() { - c, err := h.state.DeleteNode(node) + c, err := h.state.DeleteNode(node.View()) if err != nil { return nil, fmt.Errorf("deleting ephemeral node: %w", err) } @@ -118,15 +109,19 @@ func (h *Headscale) handleExistingNode( } } - _, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) + updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } h.Change(c) - } - return nodeToRegisterResponse(node), nil + // CRITICAL: Use the updated node view for the response + // The original node object has stale expiry information + node = updatedNode.AsStruct() + } + + return nodeToRegisterResponse(node), nil } func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { @@ -177,7 +172,7 @@ func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey( + node, changed, err := h.state.HandleNodeFromPreAuthKey( regReq, machineKey, ) @@ -193,8 +188,8 @@ func (h *Headscale) handleRegisterWithAuthKey( return nil, err } - // If node is nil, it means an ephemeral node was deleted during logout - if node == nil { + // If node is not valid, it means an ephemeral node was deleted during logout + if !node.Valid() { h.Change(changed) return nil, nil } @@ -213,26 +208,30 @@ func (h *Headscale) handleRegisterWithAuthKey( // TODO(kradalby): This needs to be ran as part of the batcher maybe? // now since we dont update the node/pol here anymore routeChange := h.state.AutoApproveRoutes(node) + if _, _, err := h.state.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } if routeChange && changed.Empty() { - changed = change.NodeAdded(node.ID) + changed = change.NodeAdded(node.ID()) } h.Change(changed) - // If policy changed due to node registration, send a separate policy change - if policyChanged { - policyChange := change.PolicyChange() - h.Change(policyChange) - } + // TODO(kradalby): I think this is covered above, but we need to validate that. + // // If policy changed due to node registration, send a separate policy change + // if policyChanged { + // policyChange := change.PolicyChange() + // h.Change(policyChange) + // } + + user := node.User() return &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), - User: *node.User.TailscaleUser(), - Login: *node.User.TailscaleLogin(), + User: *user.TailscaleUser(), + Login: *user.TailscaleLogin(), }, nil } @@ -266,6 +265,7 @@ func (h *Headscale) handleRegisterInteractive( ) log.Info().Msgf("Starting node registration using key: %s", registrationId) + return &tailcfg.RegisterResponse{ AuthURL: h.authProvider.AuthURL(registrationId), }, nil diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 83d62d3d..3531fc49 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { } // RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. If the name is not unique, it will return an error. +// and renames it. Validation should be done in the state layer before calling this function. 
func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - err := util.CheckForFQDNRules( - newName, - ) - if err != nil { - return fmt.Errorf("renaming node: %w", err) + // Check if the new name is unique + var count int64 + if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to check name uniqueness: %w", err) } - uniq, err := isUniqueName(tx, newName) - if err != nil { - return fmt.Errorf("checking if name is unique: %w", err) - } - - if !uniq { - return fmt.Errorf("name is not unique: %s", newName) + if count > 0 { + return errors.New("name is not unique") } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -333,108 +327,19 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( }) } -// HandleNodeFromAuthPath is called from the OIDC or CLI auth path -// with a registrationID to register or reauthenticate a node. -// If the node found in the registration cache is not already registered, -// it will be registered with the user and the node will be removed from the cache. -// If the node is already registered, the expiry will be updated. -// The node, and a boolean indicating if it was a new node or not, will be returned. -func (hsdb *HSDatabase) HandleNodeFromAuthPath( - registrationID types.RegistrationID, - userID types.UserID, - nodeExpiry *time.Time, - registrationMethod string, - ipv4 *netip.Addr, - ipv6 *netip.Addr, -) (*types.Node, change.ChangeSet, error) { - var nodeChange change.ChangeSet - node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - if reg, ok := hsdb.regCache.Get(registrationID); ok { - if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { - user, err := GetUserByID(tx, userID) - if err != nil { - return nil, fmt.Errorf( - "failed to find user in register node from auth callback, %w", - err, - ) - } +// RegisterNodeForTest is used only for testing purposes to register a node directly in the database. +// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey. +func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { + if !testing.Testing() { + panic("RegisterNodeForTest can only be called during tests") + } - log.Debug(). - Str("registration_id", registrationID.String()). - Str("username", user.Username()). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). - Msg("Registering node from API/CLI or auth callback") - - // TODO(kradalby): This looks quite wrong? why ID 0? - // Why not always? - // Registration of expired node with different user - if reg.Node.ID != 0 && - reg.Node.UserID != user.ID { - return nil, ErrDifferentRegisteredUser - } - - reg.Node.UserID = user.ID - reg.Node.User = *user - reg.Node.RegisterMethod = registrationMethod - - if nodeExpiry != nil { - reg.Node.Expiry = nodeExpiry - } - - node, err := RegisterNode( - tx, - reg.Node, - ipv4, ipv6, - ) - - if err == nil { - hsdb.regCache.Delete(registrationID) - } - - // Signal to waiting clients that the machine has been registered. - select { - case reg.Registered <- node: - default: - } - close(reg.Registered) - - nodeChange = change.NodeAdded(node.ID) - - return node, err - } else { - // If the node is already registered, this is a refresh. 
- err := NodeSetExpiry(tx, node.ID, *nodeExpiry) - if err != nil { - return nil, err - } - - nodeChange = change.KeyExpiry(node.ID) - - return node, nil - } - } - - return nil, ErrNodeNotFoundRegistrationCache - }) - - return node, nodeChange, err -} - -func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - return RegisterNode(tx, node, ipv4, ipv6) - }) -} - -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. -func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). Str("user", node.User.Username()). - Msg("Registering node") + Msg("Registering test node") // If the a new node is registered with the same machine key, to the same user, // update the existing node. @@ -445,8 +350,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.ID = oldNode.ID node.GivenName = oldNode.GivenName node.ApprovedRoutes = oldNode.ApprovedRoutes - ipv4 = oldNode.IPv4 - ipv6 = oldNode.IPv6 + // Don't overwrite the provided IPs with old ones when they exist + if ipv4 == nil { + ipv4 = oldNode.IPv4 + } + if ipv6 == nil { + ipv6 = oldNode.IPv6 + } } // If the node exists and it already has IP(s), we just save it @@ -463,7 +373,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). Str("user", node.User.Username()). - Msg("Node authorized again") + Msg("Test node authorized again") return &node, nil } @@ -472,7 +382,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.IPv6 = ipv6 if node.GivenName == "" { - givenName, err := ensureUniqueGivenName(tx, node.Hostname) + givenName, err := EnsureUniqueGivenName(tx, node.Hostname) if err != nil { return nil, fmt.Errorf("failed to ensure unique given name: %w", err) } @@ -487,7 +397,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad log.Trace(). Caller(). Str("node", node.Hostname). - Msg("Node registered with the database") + Msg("Test node registered with the database") return &node, nil } @@ -560,7 +470,8 @@ func isUniqueName(tx *gorm.DB, name string) (bool, error) { return len(nodes) == 0, nil } -func ensureUniqueGivenName( +// EnsureUniqueGivenName generates a unique given name for a node based on its hostname. +func EnsureUniqueGivenName( tx *gorm.DB, name string, ) (string, error) { @@ -781,19 +692,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . node := hsdb.CreateNodeForTest(user, hostname...) 
- err := hsdb.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, *node, nil, nil) + // Allocate IPs for the test node using the database's IP allocator + // This is a simplified allocation for testing - in production this would use State.ipAlloc + ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID) + if err != nil { + panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err)) + } + + var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { + var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) return err }) if err != nil { panic(fmt.Sprintf("failed to register test node: %v", err)) } - registeredNode, err := hsdb.GetNodeByID(node.ID) - if err != nil { - panic(fmt.Sprintf("failed to get registered test node: %v", err)) - } - return registeredNode } @@ -842,3 +757,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int return nodes } + +// allocateTestIPs allocates sequential test IPs for nodes during testing. +func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) { + if !testing.Testing() { + panic("allocateTestIPs can only be called during tests") + } + + // Use simple sequential allocation for tests + // IPv4: 100.64.0.x (where x is nodeID) + // IPv6: fd7a:115c:a1e0::x (where x is nodeID) + + if nodeID > 254 { + return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID) + } + + ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)}) + ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)}) + + return &ipv4, &ipv6, nil +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 8819fbcf..84e30e0a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) { func TestAutoApproveRoutes(t *testing.T) { tests := []struct { - name string - acl string - routes []netip.Prefix - want []netip.Prefix - want2 []netip.Prefix + name string + acl string + routes []netip.Prefix + want []netip.Prefix + want2 []netip.Prefix + expectChange bool // whether to expect route changes }{ + { + name: "no-auto-approvers-empty-policy", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ] +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - no auto-approvers + want2: []netip.Prefix{}, // Should be empty - no auto-approvers + expectChange: false, // No changes expected + }, + { + name: "no-auto-approvers-explicit-empty", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ], + "autoApprovers": { + "routes": {}, + "exitNode": [] + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + expectChange: false, // No changes expected + }, { name: "2068-approve-issue-sub-kube", acl: ` @@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) { } } }`, - routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, - want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, 
+ want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + expectChange: true, // Routes should be approved }, { name: "2068-approve-issue-sub-exit-tag", @@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) { tsaddr.AllIPv4(), tsaddr.AllIPv6(), }, + expectChange: true, // Routes should be approved }, } @@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) require.NotNil(t, pm) - changed1 := policy.AutoApproveRoutes(pm, &node) - assert.True(t, changed1) + newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes) + assert.Equal(t, tt.expectChange, changed1) - err = adb.DB.Save(&node).Error - require.NoError(t, err) + if changed1 { + err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1) + require.NoError(t, err) + } - _ = policy.AutoApproveRoutes(pm, &nodeTagged) - - err = adb.DB.Save(&nodeTagged).Error - require.NoError(t, err) + newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes) + if changed2 { + err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2) + require.NoError(t, err) + } node1ByID, err := adb.GetNodeByID(1) require.NoError(t, err) - if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { + // For empty auto-approvers tests, handle nil vs empty slice comparison + expectedRoutes1 := tt.want + if len(expectedRoutes1) == 0 { + expectedRoutes1 = nil + } + if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } node2ByID, err := adb.GetNodeByID(2) require.NoError(t, err) - if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { + expectedRoutes2 := tt.want2 + if len(expectedRoutes2) == 0 { + expectedRoutes2 = nil + } + if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } }) @@ -620,11 +679,11 @@ func TestRenameNode(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node, nil, nil) + _, err := RegisterNodeForTest(tx, node, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -721,11 +780,11 @@ func TestListPeers(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node1, nil, nil) + _, err := RegisterNodeForTest(tx, node1, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) { // No parameter means no filter, should return all peers nodes, err = db.ListPeers(1) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Empty node list should return all peers nodes, err = db.ListPeers(1, types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // No match in IDs should return empty list and no error @@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) { // Partial match in IDs nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) 
require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs, but node ID is still filtered out nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) } @@ -806,11 +865,11 @@ func TestListNodes(t *testing.T) { require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { - _, err := RegisterNode(tx, node1, nil, nil) + _, err := RegisterNodeForTest(tx, node1, nil, nil) if err != nil { return err } - _, err = RegisterNode(tx, node2, nil, nil) + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err }) @@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) { // No parameter means no filter, should return all nodes nodes, err = db.ListNodes() require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) // Empty node list should return all nodes nodes, err = db.ListNodes(types.NodeIDs{}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) @@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) { // Partial match in IDs nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) require.NoError(t, err) - assert.Equal(t, 1, len(nodes)) + assert.Len(t, nodes, 1) assert.Equal(t, "test2", nodes[0].Hostname) // Several matched IDs nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) require.NoError(t, err) - assert.Equal(t, 2, len(nodes)) + assert.Len(t, nodes, 2) assert.Equal(t, "test1", nodes[0].Hostname) assert.Equal(t, "test2", nodes[1].Hostname) } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 2e60de2e..a36c1f13 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "slices" "strings" "time" @@ -47,8 +48,9 @@ func CreatePreAuthKey( return nil, err } - // Remove duplicates + // Remove duplicates and sort for consistency aclTags = set.SetOf(aclTags).Slice() + slices.Sort(aclTags) // TODO(kradalby): factor out and create a reusable tag validation, // check if there is one in Tailscale's lib. diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 1b333792..26d10060 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { } // AssignNodeToUser assigns a Node to a user. +// Note: Validation should be done in the state layer before calling this function. 
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error { - node, err := GetNodeByID(tx, nodeID) - if err != nil { - return err + // Check if the user exists + var userExists bool + if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil { + return fmt.Errorf("failed to check if user exists: %w", err) } - user, err := GetUserByID(tx, uid) - if err != nil { - return err + + if !userExists { + return ErrUserNotFound } - node.User = *user - node.UserID = user.ID - if result := tx.Save(&node); result.Error != nil { - return result.Error + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil { + return fmt.Errorf("failed to assign node to user: %w", err) } return nil diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 60676a1d..c2b478b1 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { } sshPol := make(map[string]*tailcfg.SSHPolicy) - for _, node := range nodes { - pol, err := h.state.SSHPolicy(node.View()) + for _, node := range nodes.All() { + pol, err := h.state.SSHPolicy(node) if err != nil { httpError(w, err) return } - sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol + sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol } sshJSON, err := json.MarshalIndent(sshPol, "", " ") diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 722f8421..1b1a22e2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "github.com/samber/lo" "google.golang.org/grpc/codes" @@ -25,6 +24,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/views" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/state" @@ -59,9 +59,10 @@ func (api headscaleV1APIServer) CreateUser( return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) } - c := change.UserAdded(types.UserID(user.ID)) - if policyChanged { + + // TODO(kradalby): Both of these might be policy changes, find a better way to merge. + if !policyChanged.Empty() { c.Change = change.Policy } @@ -79,15 +80,13 @@ func (api headscaleV1APIServer) RenameUser( return nil, err } - _, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) + _, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } // Send policy update notifications if needed - if policyChanged { - api.h.Change(change.PolicyChange()) - } + api.h.Change(c) newUser, err := api.h.state.GetUserByName(request.GetNewName()) if err != nil { @@ -288,17 +287,13 @@ func (api headscaleV1APIServer) GetNode( ctx context.Context, request *v1.GetNodeRequest, ) (*v1.GetNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } resp := node.Proto() - // Populate the online field based on - // currently connected nodes. 
- resp.Online = api.h.mapBatcher.IsConnected(node.ID) - return &v1.GetNodeResponse{Node: resp}, nil } @@ -323,7 +318,8 @@ func (api headscaleV1APIServer) SetTags( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Strs("tags", request.GetTags()). Msg("Changing tags of node") @@ -334,7 +330,13 @@ func (api headscaleV1APIServer) SetApprovedRoutes( ctx context.Context, request *v1.SetApprovedRoutesRequest, ) (*v1.SetApprovedRoutesResponse, error) { - var routes []netip.Prefix + log.Debug(). + Caller(). + Uint64("node.id", request.GetNodeId()). + Strs("requestedRoutes", request.GetRoutes()). + Msg("gRPC SetApprovedRoutes called") + + var newApproved []netip.Prefix for _, route := range request.GetRoutes() { prefix, err := netip.ParsePrefix(route) if err != nil { @@ -344,31 +346,35 @@ func (api headscaleV1APIServer) SetApprovedRoutes( // If the prefix is an exit route, add both. The client expect both // to annotate the node as an exit node. if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() { - routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6()) + newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6()) } else { - routes = append(routes, prefix) + newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(routes) - routes = slices.Compact(routes) + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) - node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes) + node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) if err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } - routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) - // Always propagate node changes from SetApprovedRoutes api.h.Change(nodeChange) - // If routes changed, propagate those changes too - if !routeChange.Empty() { - api.h.Change(routeChange) - } - proto := node.Proto() - proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID)) + // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the + // routes that are actively served from the node (per architectural requirement in types/node.go) + primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID()) + proto.SubnetRoutes = util.PrefixesToString(primaryRoutes) + + log.Debug(). + Caller(). + Uint64("node.id", node.ID().Uint64()). + Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)). + Strs("finalSubnetRoutes", proto.SubnetRoutes). + Msg("gRPC SetApprovedRoutes completed") return &v1.SetApprovedRoutesResponse{Node: proto}, nil } @@ -390,9 +396,9 @@ func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, ) (*v1.DeleteNodeResponse, error) { - node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } nodeChange, err := api.h.state.DeleteNode(node) @@ -420,8 +426,9 @@ func (api headscaleV1APIServer) ExpireNode( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). - Time("expiry", *node.Expiry). + Caller(). + Str("node", node.Hostname()). + Time("expiry", *node.AsStruct().Expiry). 
Msg("node expired") return &v1.ExpireNodeResponse{Node: node.Proto()}, nil @@ -440,7 +447,8 @@ func (api headscaleV1APIServer) RenameNode( api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Str("new_name", request.GetNewName()). Msg("node renamed") @@ -455,58 +463,45 @@ func (api headscaleV1APIServer) ListNodes( // the filtering of nodes by user, vs nodes as a whole can // probably be done once. // TODO(kradalby): This should be done in one tx. - - IsConnected := api.h.mapBatcher.ConnectedMap() if request.GetUser() != "" { user, err := api.h.state.GetUserByName(request.GetUser()) if err != nil { return nil, err } - nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID)) - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodesByUser(types.UserID(user.ID)) - response := nodesToProto(api.h.state, IsConnected, nodes) + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, err - } + nodes := api.h.state.ListNodes() - sort.Slice(nodes, func(i, j int) bool { - return nodes[i].ID < nodes[j].ID - }) - - response := nodesToProto(api.h.state, IsConnected, nodes) + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } -func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node { - response := make([]*v1.Node, len(nodes)) - for index, node := range nodes { +func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node { + response := make([]*v1.Node, nodes.Len()) + for index, node := range nodes.All() { resp := node.Proto() - // Populate the online field based on - // currently connected nodes. - if val, ok := IsConnected.Load(node.ID); ok && val { - resp.Online = true - } - var tags []string for _, tag := range node.RequestTags() { - if state.NodeCanHaveTag(node.View(), tag) { + if state.NodeCanHaveTag(node, tag) { tags = append(tags, tag) } } - resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...)) - resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...)) + resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...)) + + resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) response[index] = resp } + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) + return response } @@ -674,17 +669,15 @@ func (api headscaleV1APIServer) SetPolicy( // a scenario where they might be allowed if the server has no nodes // yet, but it should help for the general case and for hot reloading // configurations. - nodes, err := api.h.state.ListNodes() - if err != nil { - return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) - } - changed, err := api.h.state.SetPolicy([]byte(p)) + nodes := api.h.state.ListNodes() + + _, err := api.h.state.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } - if len(nodes) > 0 { - _, err = api.h.state.SSHPolicy(nodes[0].View()) + if nodes.Len() > 0 { + _, err = api.h.state.SSHPolicy(nodes.At(0)) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -695,14 +688,20 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - // Only send update if the packet filter has changed. 
- if changed { - err = api.h.state.AutoApproveNodes() - if err != nil { - return nil, err - } + // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed. + // This ensures that routes are re-evaluated for auto-approval in cases where routes + // were manually disabled but could now be auto-approved with the current policy. + cs, err := api.h.state.ReloadPolicy() + if err != nil { + return nil, fmt.Errorf("reloading policy: %w", err) + } - api.h.Change(change.PolicyChange()) + if len(cs) > 0 { + api.h.Change(cs...) + } else { + log.Debug(). + Caller(). + Msg("No policy changes to distribute because ReloadPolicy returned empty changeset") } response := &v1.SetPolicyResponse{ diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 2d664104..cac4ff0f 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -94,13 +94,19 @@ func (h *Headscale) handleVerifyRequest( return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) } - nodes, err := h.state.ListNodes() - if err != nil { - return fmt.Errorf("cannot list nodes: %w", err) + nodes := h.state.ListNodes() + + // Check if any node has the requested NodeKey + var nodeKeyFound bool + for _, node := range nodes.All() { + if node.NodeKey() == derpAdmitClientRequest.NodePublic { + nodeKeyFound = true + break + } } resp := &tailcfg.DERPAdmitClientResponse{ - Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), + Allow: nodeKeyFound, } return json.NewEncoder(writer).Encode(resp) diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index bb69eac2..1299ed54 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "fmt" "time" @@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher type Batcher interface { Start() Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error - RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] AddWork(c change.ChangeSet) @@ -120,7 +121,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. 
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { if nc == nil { - return fmt.Errorf("nodeConnection is nil") + return errors.New("nodeConnection is nil") } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index e733e29a..7476b72f 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -21,8 +21,7 @@ type LockFreeBatcher struct { mapper *mapper workers int - // Lock-free concurrent maps - nodes *xsync.Map[types.NodeID, *nodeConn] + nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] connected *xsync.Map[types.NodeID, *time.Time] // Work queue channel @@ -32,7 +31,6 @@ type LockFreeBatcher struct { // Batching state pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] - batchMutex sync.RWMutex // Metrics totalNodes atomic.Int64 @@ -45,65 +43,63 @@ type LockFreeBatcher struct { // AddNode registers a new node connection with the batcher and sends an initial map response. // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. -// TODO(kradalby): See if we can move the isRouter argument somewhere else. -func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error { - // First validate that we can generate initial map before doing anything else - fullSelfChange := change.FullSelf(id) +func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + addNodeStart := time.Now() - // TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange. - // This currently means that the goroutine for the node connection will do the processing - // which means that we might have uncontrolled concurrency. - // When we use MapResponseFromChange, it will be processed by the same worker pool, causing - // it to be processed in a more controlled manner. 
- initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange) - if err != nil { - return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + // Generate connection ID + connID := generateConnectionID() + + // Create new connection entry + now := time.Now() + newEntry := &connectionEntry{ + id: connID, + c: c, + version: version, + created: now, } // Only after validation succeeds, create or update node connection newConn := newNodeConn(id, c, version, b.mapper) - var conn *nodeConn - if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded { - // Update existing connection - existing.updateConnection(c, version) - conn = existing - } else { + if !loaded { b.totalNodes.Add(1) conn = newConn } - // Mark as connected only after validation succeeds b.connected.Store(id, nil) // nil = connected - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher") + if err != nil { + log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + } - // Send the validated initial map - if initialMap != nil { - if err := conn.send(initialMap); err != nil { - // Clean up the connection state on send failure - b.nodes.Delete(id) - b.connected.Delete(id) - return fmt.Errorf("failed to send initial map to node %d: %w", id, err) - } - - // Notify other nodes that this node came online - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter}) + // Use a blocking send with timeout for initial map since the channel should be ready + // and we want to avoid the race condition where the receiver isn't ready yet + select { + case c <- initialMap: + // Success + case <-time.After(5 * time.Second): + log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). + Msg("Initial map send timed out because channel was blocked or receiver not ready") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to send initial map to node %d: timeout", id) } return nil } // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. -// It validates the connection channel matches the current one, closes the connection, -// and notifies other nodes that this node has gone offline. -func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) { - // Check if this is the current connection and mark it as closed - if existing, ok := b.nodes.Load(id); ok { - if !existing.matchesChannel(c) { - log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring") - return // Not the current connection, not an error - } +// It validates the connection channel matches one of the current connections, closes that specific connection, +// and keeps the node entry alive for rapid reconnections instead of aggressive deletion. +// Reports if the node still has active connections after removal. 
+func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + nodeConn, exists := b.nodes.Load(id) + if !exists { + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher") + return false + } // Mark the connection as closed to prevent further sends if connData := existing.connData.Load(); connData != nil { @@ -111,15 +107,20 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo } } - log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline") + // Check if node has any remaining active connections + if nodeConn.hasActiveConnections() { + log.Debug().Caller().Uint64("node.id", id.Uint64()). + Int("active.connections", nodeConn.getActiveConnectionCount()). + Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections + } // Remove node and mark disconnected atomically b.nodes.Delete(id) b.connected.Store(id, ptr.To(time.Now())) b.totalNodes.Add(-1) - // Notify other nodes that this node went offline - b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter}) + return false } // AddWork queues a change to be processed by the batcher. @@ -205,15 +206,6 @@ func (b *LockFreeBatcher) worker(workerID int) { return } - duration := time.Since(startTime) - if duration > 100*time.Millisecond { - log.Warn(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Dur("duration", duration). - Msg("slow synchronous work processing") - } continue } @@ -221,16 +213,8 @@ func (b *LockFreeBatcher) worker(workerID int) { // that should be processed and sent to the node instead of // returned to the caller. if nc, exists := b.nodes.Load(w.nodeID); exists { - // Check if this connection is still active before processing - if connData := nc.connData.Load(); connData != nil && connData.closed.Load() { - log.Debug(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Msg("skipping work for closed connection") - continue - } - + // Apply change to node - this will handle offline nodes gracefully + // and queue work for when they reconnect err := nc.change(w.c) if err != nil { b.workErrors.Add(1) @@ -240,52 +224,18 @@ func (b *LockFreeBatcher) worker(workerID int) { Str("change", w.c.Change.String()). Msg("failed to apply change") } - } else { - log.Debug(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Msg("node not found for asynchronous work - node may have disconnected") } - - duration := time.Since(startTime) - if duration > 100*time.Millisecond { - log.Warn(). - Int("workerID", workerID). - Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). - Dur("duration", duration). 
- Msg("slow asynchronous work processing") - } - case <-b.ctx.Done(): return } } } -func (b *LockFreeBatcher) addWork(c change.ChangeSet) { - // For critical changes that need immediate processing, send directly - if b.shouldProcessImmediately(c) { - if c.SelfUpdateOnly { - b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil}) - return - } - b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool { - if c.NodeID == nodeID && !c.AlsoSelf() { - return true - } - b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) - return true - }) - return - } - - // For non-critical changes, add to batch - b.addToBatch(c) +func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) { + b.addToBatch(c...) } -// queueWork safely queues work +// queueWork safely queues work. func (b *LockFreeBatcher) queueWork(w work) { b.workQueuedCount.Add(1) @@ -298,26 +248,21 @@ func (b *LockFreeBatcher) queueWork(w work) { } } -// shouldProcessImmediately determines if a change should bypass batching -func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool { - // Process these changes immediately to avoid delaying critical functionality - switch c.Change { - case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy: - return true - default: - return false +// addToBatch adds a change to the pending batch. +func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) { + // Short circuit if any of the changes is a full update, which + // means we can skip sending individual changes. + if change.HasFull(c) { + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}}) + + return true + }) + return } } -// addToBatch adds a change to the pending batch -func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if c.SelfUpdateOnly { - changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{}) - changes = append(changes, c) - b.pendingChanges.Store(c.NodeID, changes) return } @@ -329,15 +274,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) { changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) changes = append(changes, c) b.pendingChanges.Store(nodeID, changes) + return true }) } -// processBatchedChanges processes all pending batched changes +// processBatchedChanges processes all pending batched changes. func (b *LockFreeBatcher) processBatchedChanges() { - b.batchMutex.Lock() - defer b.batchMutex.Unlock() - if b.pendingChanges == nil { return } @@ -355,16 +298,31 @@ func (b *LockFreeBatcher) processBatchedChanges() { // Clear the pending changes for this node b.pendingChanges.Delete(nodeID) + return true }) } // IsConnected is lock-free read. 
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { - if val, ok := b.connected.Load(id); ok { - // nil means connected - return val == nil + // First check if we have active connections for this node + if nodeConn, exists := b.nodes.Load(id); exists { + if nodeConn.hasActiveConnections() { + return true + } } + + // Check disconnected timestamp with grace period + val, ok := b.connected.Load(id) + if !ok { + return false + } + + // nil means connected + if val == nil { + return true + } + return false } @@ -372,9 +330,26 @@ func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret := xsync.NewMap[types.NodeID, bool]() + // First, add all nodes with active connections + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + if nodeConn.hasActiveConnections() { + ret.Store(id, true) + } + return true + }) + + // Then add all entries from the connected map b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // nil means connected - ret.Store(id, val == nil) + // Only add if not already added as connected above + if _, exists := ret.Load(id); !exists { + if val == nil { + // nil means connected + ret.Store(id, true) + } else { + // timestamp means disconnected + ret.Store(id, false) + } + } return true }) @@ -482,12 +457,21 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error { return fmt.Errorf("node %d: connection closed", nc.id) } - // TODO(kradalby): We might need some sort of timeout here if the client is not reading - // the channel. That might mean that we are sending to a node that has gone offline, but - // the channel is still open. - connData.c <- data - nc.updateCount.Add(1) - return nil + // Add all entries from the connected map to capture both connected and disconnected nodes + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already processed above + if _, exists := result[id]; !exists { + // Use immediate connection status for debug (no grace period) + connected := (val == nil) // nil means connected, timestamp means disconnected + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + } + return true + }) + + return result } func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 12bb37be..6cf63dca 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -27,6 +27,60 @@ type batcherTestCase struct { fn batcherFunc } +// testBatcherWrapper wraps a real batcher to add online/offline notifications +// that would normally be sent by poll.go in production. 
+type testBatcherWrapper struct { + Batcher + state *state.State +} + +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + // Mark node as online in state before AddNode to match production behavior + // This ensures the NodeStore has correct online status for change processing + if t.state != nil { + // Use Connect to properly mark node online in NodeStore but don't send its changes + _ = t.state.Connect(id) + } + + // First add the node to the real batcher + err := t.Batcher.AddNode(id, c, version) + if err != nil { + return err + } + + // Send the online notification that poll.go would normally send + // This ensures other nodes get notified about this node coming online + t.AddWork(change.NodeOnline(id)) + + return nil +} + +func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + // Mark node as offline in state BEFORE removing from batcher + // This ensures the NodeStore has correct offline status when the change is processed + if t.state != nil { + // Use Disconnect to properly mark node offline in NodeStore but don't send its changes + _, _ = t.state.Disconnect(id) + } + + // Send the offline notification that poll.go would normally send + // Do this BEFORE removing from batcher so the change can be processed + t.AddWork(change.NodeOffline(id)) + + // Finally remove from the real batcher + removed := t.Batcher.RemoveNode(id, c) + if !removed { + return false + } + + return true +} + +// wrapBatcherForTest wraps a batcher with test-specific behavior. +func wrapBatcherForTest(b Batcher, state *state.State) Batcher { + return &testBatcherWrapper{Batcher: b, state: state} +} + // allBatcherFunctions contains all batcher implementations to test. 
var allBatcherFunctions = []batcherTestCase{ {"LockFree", NewBatcherAndMapper}, @@ -183,8 +237,8 @@ func setupBatcherWithTestData( "acls": [ { "action": "accept", - "users": ["*"], - "ports": ["*:*"] + "src": ["*"], + "dst": ["*:*"] } ] }` @@ -194,8 +248,8 @@ func setupBatcherWithTestData( t.Fatalf("Failed to set allow-all policy: %v", err) } - // Create batcher with the state - batcher := bf(cfg, state) + // Create batcher with the state and wrap it for testing + batcher := wrapBatcherForTest(bf(cfg, state), state) batcher.Start() testData := &TestData{ @@ -462,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) time.Sleep(100 * time.Millisecond) // Let connection settle // Generate some work @@ -566,7 +620,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) for i := range allNodes { node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullSet) @@ -614,7 +668,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Disconnect all nodes for i := range allNodes { node := &allNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for final updates to process @@ -732,7 +786,8 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, false, 100) + batcher.AddNode(tn.n.ID, tn.ch, 100) + if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") } @@ -752,14 +807,14 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) // Add the second node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, false, 100) + batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. select { case data := <-tn.ch: assertOnlineMapResponse(t, data, true) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -778,14 +833,14 @@ func TestBatcherBasicOperations(t *testing.T) { } // Disconnect the second node - batcher.RemoveNode(tn2.n.ID, tn2.ch, false) - assert.False(t, batcher.IsConnected(tn2.n.ID)) + batcher.RemoveNode(tn2.n.ID, tn2.ch) + // Note: IsConnected may return true during grace period for DNS resolution // First node should get update that second has disconnected. 
select { case data := <-tn.ch: assertOnlineMapResponse(t, data, false) - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): t.Error("Did not receive expected Online response update") } @@ -811,10 +866,9 @@ func TestBatcherBasicOperations(t *testing.T) { // } // Test RemoveNode - batcher.RemoveNode(tn.n.ID, tn.ch, false) - if batcher.IsConnected(tn.n.ID) { - t.Error("Node should be disconnected after RemoveNode") - } + batcher.RemoveNode(tn.n.ID, tn.ch) + // Note: IsConnected may return true during grace period for DNS resolution + // The node is actually removed from active connections but grace period allows DNS lookups }) } } @@ -957,7 +1011,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1053,7 +1107,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100)) + + batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }() // Add real work during connection chaos @@ -1067,7 +1122,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(1 * time.Microsecond) - batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }() // Remove second connection @@ -1075,7 +1130,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { go func() { defer wg.Done() time.Sleep(2 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch2, false) + batcher.RemoveNode(testNode.n.ID, ch2) }() wg.Wait() @@ -1150,7 +1205,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPSet) // Consumer goroutine to validate data and detect channel issues @@ -1192,7 +1247,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Rapid removal creates race between worker and removal time.Sleep(time.Duration(i%3) * 100 * time.Microsecond) - batcher.RemoveNode(testNode.n.ID, ch, false) + batcher.RemoveNode(testNode.n.ID, ch) // Give workers time to process and close channels time.Sleep(5 * time.Millisecond) @@ -1262,7 +1317,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for _, node := range stableNodes { ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1320,7 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Lock() churningChannels[nodeID] = ch churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1357,7 +1412,7 @@ func TestBatcherConcurrentClients(t *testing.T) { ch, exists := 
churningChannels[nodeID] churningChannelsMutex.Unlock() if exists { - batcher.RemoveNode(nodeID, ch, false) + batcher.RemoveNode(nodeID, ch) } }(node.n.ID) } @@ -1608,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) { var connectedNodesMutex sync.RWMutex for i := range testNodes { node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true connectedNodesMutex.Unlock() @@ -1675,7 +1730,7 @@ func XTestBatcherScalability(t *testing.T) { connectedNodesMutex.RUnlock() if isConnected { - batcher.RemoveNode(nodeID, channel, false) + batcher.RemoveNode(nodeID, channel) connectedNodesMutex.Lock() connectedNodes[nodeID] = false connectedNodesMutex.Unlock() @@ -1800,7 +1855,7 @@ func XTestBatcherScalability(t *testing.T) { // Now disconnect all nodes from batcher to stop new updates for i := range testNodes { node := &testNodes[i] - batcher.RemoveNode(node.n.ID, node.ch, false) + batcher.RemoveNode(node.n.ID, node.ch) } // Give time for enhanced tracking goroutines to process any remaining data in channels @@ -1934,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time to avoid overwhelming the work queue for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100)) + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Small delay between connections to allow NodeCameOnline processing time.Sleep(50 * time.Millisecond) @@ -1946,12 +2001,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Check how many peers each node should see for i, node := range allNodes { - peers, err := testData.State.ListPeers(node.n.ID) - if err != nil { - t.Errorf("Error listing peers for node %d: %v", i, err) - } else { - t.Logf("Node %d should see %d peers from state", i, len(peers)) - } + peers := testData.State.ListPeers(node.n.ID) + t.Logf("Node %d should see %d peers from state", i, peers.Len()) } // Send a full update - this should generate full peer lists @@ -1967,7 +2018,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { foundFullUpdate := false // Read all available updates for each node - for i := range len(allNodes) { + for i := range allNodes { nodeUpdates := 0 t.Logf("Reading updates for node %d:", i) @@ -2056,9 +2107,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) { t.Logf("=== WORK QUEUE TRACING TEST ===") - // Connect first node - batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100)) - t.Logf("Connected node %d", nodes[0].n.ID) + time.Sleep(100 * time.Millisecond) // Let connections settle // Wait for initial NodeCameOnline to be processed time.Sleep(200 * time.Millisecond) @@ -2111,14 +2160,172 @@ func TestBatcherWorkQueueTracing(t *testing.T) { t.Errorf("ERROR: Received unknown update type!") } - // Check if there should be peers available - peers, err := testData.State.ListPeers(nodes[0].n.ID) - if err != nil { - t.Errorf("Error getting peers from state: %v", err) - } else { - t.Logf("State shows %d peers available for this node", len(peers)) - if len(peers) > 0 && len(data.Peers) == 0 { - t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers)) + batcher := testData.Batcher + node1 := testData.Nodes[0] + node2 := testData.Nodes[1] + + t.Logf("=== MULTI-CONNECTION TEST ===") + + // Phase 1: Connect first node with initial 
connection + t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node1: %v", err) + } + + // Connect second node for comparison + err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node2: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 2: Add second connection for node1 (multi-connection scenario) + t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add second connection for node1: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 3: Add third connection for node1 + t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add third connection for node1: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Phase 4: Verify debug status shows correct connection count + t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + + if info, exists := debugInfo[node1.n.ID]; exists { + t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 3 { + t.Errorf("Node1 should have 3 active connections, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 3 active connections") + } + } + if connected, ok := infoMap["connected"].(bool); ok && !connected { + t.Errorf("Node1 should show as connected with 3 active connections") + } + } + } + + if info, exists := debugInfo[node2.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 1 { + t.Errorf("Node2 should have 1 active connection, got %d", activeConnections) + } + } + } + } + } + + // Phase 5: Send update and verify ALL connections receive it + t.Logf("Phase 5: Testing update distribution to all connections...") + + // Clear any existing updates from all channels + clearChannel := func(ch chan *tailcfg.MapResponse) { + for { + select { + case <-ch: + // drain + default: + return + } + } + } + + clearChannel(node1.ch) + clearChannel(secondChannel) + clearChannel(thirdChannel) + clearChannel(node2.ch) + + // Send a change notification from node2 (so node1 should receive it on all connections) + testChangeSet := change.ChangeSet{ + NodeID: node2.n.ID, + Change: change.NodeNewOrUpdate, + SelfUpdateOnly: false, + } + + batcher.AddWork(testChangeSet) + + time.Sleep(100 * time.Millisecond) // Let updates propagate + + // Verify all three connections for node1 receive the update + connection1Received := false + connection2Received := false + connection3Received := false + + select { + case mapResp := <-node1.ch: + connection1Received = (mapResp != nil) + t.Logf("Node1 connection 1 received update: %t", connection1Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 1 did not receive update") + } + + select { + case mapResp := <-secondChannel: + 
connection2Received = (mapResp != nil) + t.Logf("Node1 connection 2 received update: %t", connection2Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 2 did not receive update") + } + + select { + case mapResp := <-thirdChannel: + connection3Received = (mapResp != nil) + t.Logf("Node1 connection 3 received update: %t", connection3Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 3 did not receive update") + } + + if connection1Received && connection2Received && connection3Received { + t.Logf("SUCCESS: All three connections for node1 received the update") + } else { + t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t", + connection1Received, connection2Received, connection3Received) + } + + // Phase 6: Test connection removal and verify remaining connections still work + t.Logf("Phase 6: Testing connection removal...") + + // Remove the second connection + removed := batcher.RemoveNode(node1.n.ID, secondChannel) + if !removed { + t.Errorf("Failed to remove second connection for node1") + } + + time.Sleep(50 * time.Millisecond) + + // Verify debug status shows 2 connections now + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + if info, exists := debugInfo[node1.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 2 { + t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal") + } } } } else { diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index dfe9d68d..dc43b933 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,6 +1,7 @@ package mapper import ( + "errors" "net/netip" "sort" "time" @@ -12,7 +13,7 @@ import ( "tailscale.com/util/multierr" ) -// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse +// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse. type MapResponseBuilder struct { resp *tailcfg.MapResponse mapper *mapper @@ -21,7 +22,17 @@ type MapResponseBuilder struct { errs []error } -// NewMapResponseBuilder creates a new builder with basic fields set +type debugType string + +const ( + fullResponseDebug debugType = "full" + patchResponseDebug debugType = "patch" + removeResponseDebug debugType = "remove" + changeResponseDebug debugType = "change" + derpResponseDebug debugType = "derp" +) + +// NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() return &MapResponseBuilder{ @@ -35,32 +46,39 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder } } -// addError adds an error to the builder's error list +// addError adds an error to the builder's error list. func (b *MapResponseBuilder) addError(err error) { if err != nil { b.errs = append(b.errs, err) } } -// hasErrors returns true if the builder has accumulated any errors +// hasErrors returns true if the builder has accumulated any errors. func (b *MapResponseBuilder) hasErrors() bool { return len(b.errs) > 0 } -// WithCapabilityVersion sets the capability version for the response +// WithCapabilityVersion sets the capability version for the response. 
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { b.capVer = capVer return b } -// WithSelfNode adds the requesting node to the response +// WithSelfNode adds the requesting node to the response. func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } + // Always use batcher's view of online status for self node + // The batcher respects grace periods for logout scenarios + node := nodeView.AsStruct() + // if b.mapper.batcher != nil { + // node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID)) + // } + _, matchers := b.mapper.state.Filter() tailnode, err := tailNode( node.View(), b.capVer, b.mapper.state, @@ -74,29 +92,38 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { } b.resp.Node = tailnode + return b } -// WithDERPMap adds the DERP map to the response +func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder { + if debugDumpMapResponsePath != "" { + b.debugType = t + } + + return b +} + +// WithDERPMap adds the DERP map to the response. func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct() return b } -// WithDomain adds the domain configuration +// WithDomain adds the domain configuration. func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { b.resp.Domain = b.mapper.cfg.Domain() return b } -// WithCollectServicesDisabled sets the collect services flag to false +// WithCollectServicesDisabled sets the collect services flag to false. func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { b.resp.CollectServices.Set(false) return b } // WithDebugConfig adds debug configuration -// It disables log tailing if the mapper's LogTail is not enabled +// It disables log tailing if the mapper's LogTail is not enabled. func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, @@ -104,53 +131,56 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { return b } -// WithSSHPolicy adds SSH policy configuration for the requesting node +// WithSSHPolicy adds SSH policy configuration for the requesting node. func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } - sshPolicy, err := b.mapper.state.SSHPolicy(node.View()) + sshPolicy, err := b.mapper.state.SSHPolicy(node) if err != nil { b.addError(err) return b } b.resp.SSHPolicy = sshPolicy + return b } -// WithDNSConfig adds DNS configuration for the requesting node +// WithDNSConfig adds DNS configuration for the requesting node. 
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) + return b } -// WithUserProfiles adds user profiles for the requesting node and given peers -func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) +// WithUserProfiles adds user profiles for the requesting node and given peers. +func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } b.resp.UserProfiles = generateUserProfiles(node, peers) + return b } -// WithPacketFilters adds packet filter rules based on policy +// WithPacketFilters adds packet filter rules based on policy. func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - b.addError(err) + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) return b } @@ -161,15 +191,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node.View(), filter), + "base": policy.ReduceFilterRules(node, filter), } return b } -// WithPeers adds full peer list with policy filtering (for full map response) -func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { - +// WithPeers adds full peer list with policy filtering (for full map response). +func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder { tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -177,12 +206,12 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder { } b.resp.Peers = tailPeers + return b } -// WithPeerChanges adds changed peers with policy filtering (for incremental updates) -func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder { - +// WithPeerChanges adds changed peers with policy filtering (for incremental updates). +func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder { tailPeers, err := b.buildTailPeers(peers) if err != nil { b.addError(err) @@ -190,14 +219,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil } b.resp.PeersChanged = tailPeers + return b } -// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting -func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - return nil, err +// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting. 
+func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + return nil, errors.New("node not found") } filter, matchers := b.mapper.state.Filter() @@ -206,15 +236,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, // access each-other at all and remove them from the peers. var changedViews views.Slice[types.NodeView] if len(filter) > 0 { - changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers) + changedViews = policy.ReduceNodes(node, peers, matchers) } else { - changedViews = peers.ViewSlice() + changedViews = peers } tailPeers, err := tailNodes( changedViews, b.capVer, b.mapper.state, func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) + return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers) }, b.mapper.cfg) if err != nil { @@ -229,19 +259,20 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, return tailPeers, nil } -// WithPeerChangedPatch adds peer change patches +// WithPeerChangedPatch adds peer change patches. func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { b.resp.PeersChangedPatch = changes return b } -// WithPeersRemoved adds removed peer IDs +// WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } b.resp.PeersRemoved = tailscaleIDs + return b } @@ -251,11 +282,7 @@ func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { return nil, multierr.New(b.errs...) 
} if debugDumpMapResponsePath != "" { - node, err := b.mapper.state.GetNodeByID(b.nodeID) - if err != nil { - return nil, err - } - writeDebugMapResponse(b.resp, node) + writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } return b.resp, nil diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 59c92e24..bb8340d0 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -19,6 +19,7 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/views" ) const ( @@ -69,16 +70,18 @@ func newMapper( } func generateUserProfiles( - node *types.Node, - peers types.Nodes, + node types.NodeView, + peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { userMap := make(map[uint]*types.User) ids := make([]uint, 0, len(userMap)) - userMap[node.User.ID] = &node.User - ids = append(ids, node.User.ID) - for _, peer := range peers { - userMap[peer.User.ID] = &peer.User - ids = append(ids, peer.User.ID) + user := node.User() + userMap[user.ID] = &user + ids = append(ids, user.ID) + for _, peer := range peers.All() { + peerUser := peer.User() + userMap[peerUser.ID] = &peerUser + ids = append(ids, peerUser.ID) } slices.Sort(ids) @@ -95,7 +98,7 @@ func generateUserProfiles( func generateDNSConfig( cfg *types.Config, - node *types.Node, + node types.NodeView, ) *tailcfg.DNSConfig { if cfg.TailcfgDNSConfig == nil { return nil @@ -115,12 +118,12 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ - "device_name": []string{node.Hostname}, - "device_model": []string{node.Hostinfo.OS}, + "device_name": []string{node.Hostname()}, + "device_model": []string{node.Hostinfo().OS()}, } if len(node.IPs()) > 0 { @@ -138,10 +141,7 @@ func (m *mapper) fullMapResponse( capVer tailcfg.CapabilityVersion, messages ...string, ) (*tailcfg.MapResponse, error) { - peers, err := m.listPeers(nodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). @@ -183,10 +183,7 @@ func (m *mapper) peerChangeResponse( capVer tailcfg.CapabilityVersion, changedNodeID types.NodeID, ) (*tailcfg.MapResponse, error) { - peers, err := m.listPeers(nodeID, changedNodeID) - if err != nil { - return nil, err - } + peers := m.state.ListPeers(nodeID, changedNodeID) return m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). @@ -208,7 +205,8 @@ func (m *mapper) peerRemovedResponse( func writeDebugMapResponse( resp *tailcfg.MapResponse, - node *types.Node, + t debugType, + nodeID types.NodeID, ) { body, err := json.MarshalIndent(resp, "", " ") if err != nil { @@ -236,25 +234,6 @@ func writeDebugMapResponse( } } -// listPeers returns peers of node, regardless of any Policy or if the node is expired. -// If no peer IDs are given, all peers are returned. -// If at least one peer ID is given, only these peer nodes will be returned. -func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - peers, err := m.state.ListPeers(nodeID, peerIDs...) - if err != nil { - return nil, err - } - - // TODO(kradalby): Add back online via batcher. 
This was removed - // to avoid a circular dependency between the mapper and the notification. - for _, peer := range peers { - online := m.batcher.IsConnected(peer.ID) - peer.IsOnline = &online - } - - return peers, nil -} - // routeFilterFunc is a function that takes a node ID and returns a list of // netip.Prefixes that are allowed for that node. It is used to filter routes // from the primary route manager to the node. diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 198ba6c4..b801f7dd 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) { &types.Config{ TailcfgDNSConfig: &dnsConfigOrig, }, - nodeInShared1, + nodeInShared1.View(), ) if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 9729301d..3a518d94 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -133,13 +133,12 @@ func tailNode( tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } - if !node.IsOnline().Valid() || !node.IsOnline().Get() { - // LastSeen is only set when node is - // not connected to the control server. - if node.LastSeen().Valid() { - lastSeen := node.LastSeen().Get() - tNode.LastSeen = &lastSeen - } + // Set LastSeen only for offline nodes to avoid confusing Tailscale clients + // during rapid reconnection cycles. Online nodes should not have LastSeen set + // as this can make clients interpret them as "not online" despite Online=true. + if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() { + lastSeen := node.LastSeen().Get() + tNode.LastSeen = &lastSeen } return &tNode, nil diff --git a/hscontrol/noise.go b/hscontrol/noise.go index db39992e..bb59fea6 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -13,7 +13,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" - "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" @@ -296,16 +295,11 @@ func (ns *noiseServer) NoiseRegistrationHandler( // getAndValidateNode retrieves the node from the database using the NodeKey // and validates that it matches the MachineKey from the Noise session. func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { - node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) - } - return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil) + nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) + if !ok { + return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) } - nv := node.View() - // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. 
if ns.machineKey != nv.MachineKey() { return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 68361cae..021a6272 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -281,7 +281,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( util.LogErr(err, "could not get userinfo; only using claims from id token") } - // The user claims are now updated from the the userinfo endpoint so we can verify the user a + // The user claims are now updated from the userinfo endpoint so we can verify the user // against allowed emails, email domains, and groups. if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { httpError(writer, err) @@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims) + user, c, err := a.createOrUpdateUserFromClaim(&claims) if err != nil { log.Error(). Err(err). @@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } // Send policy update notifications if needed - if policyChanged { - a.h.Change(change.PolicyChange()) - } + a.h.Change(c) // TODO(kradalby): Is this comment right? // If the node exists, then the node should be reauthenticated, @@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( claims *types.OIDCClaims, -) (*types.User, bool, error) { +) (*types.User, change.ChangeSet, error) { var user *types.User var err error var newUser bool - var policyChanged bool + var c change.ChangeSet user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, false, fmt.Errorf("creating or updating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err) } // if the user is still not found, create a new empty user. 
@@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	user.FromClaim(claims)
 	if newUser {
-		user, policyChanged, err = a.h.state.CreateUser(*user)
+		user, c, err = a.h.state.CreateUser(*user)
 		if err != nil {
-			return nil, false, fmt.Errorf("creating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
 		}
 	} else {
-		_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
+		_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
 			*u = *user
 			return nil
 		})
 		if err != nil {
-			return nil, false, fmt.Errorf("updating user: %w", err)
+			return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
 		}
 	}
-	return user, policyChanged, nil
+	return user, c, nil
 }
 
 func (a *AuthProviderOIDC) handleRegistration(
diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go
index 52457c9b..6a74e59f 100644
--- a/hscontrol/policy/policy.go
+++ b/hscontrol/policy/policy.go
@@ -7,6 +7,7 @@ import (
 	"github.com/juanfont/headscale/hscontrol/policy/matcher"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/rs/zerolog/log"
 	"github.com/samber/lo"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
@@ -138,39 +139,74 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 	return ret
 }
 
-// AutoApproveRoutes approves any route that can be autoapproved from
-// the nodes perspective according to the given policy.
-// It reports true if any routes were approved.
-// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
-func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
+// ApproveRoutesWithPolicy checks if the node can approve the announced routes
+// and returns the new list of approved routes.
+// The approved routes will include:
+// 1. ALL previously approved routes (regardless of whether they're still advertised)
+// 2. New routes from announcedRoutes that can be auto-approved by policy
+// This ensures that:
+// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
+// - New routes can be auto-approved according to policy
+// - Routes can only be removed by explicit admin action (not by auto-approval).
+func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
 	if pm == nil {
-		return false
+		return currentApproved, false
 	}
-	nodeView := node.View()
-	var newApproved []netip.Prefix
-	for _, route := range nodeView.AnnouncedRoutes() {
-		if pm.NodeCanApproveRoute(nodeView, route) {
+
+	// Start with ALL currently approved routes - we never remove approved routes
+	newApproved := make([]netip.Prefix, len(currentApproved))
+	copy(newApproved, currentApproved)
+
+	// Then, check for new routes that can be auto-approved
+	for _, route := range announcedRoutes {
+		// Skip if already approved
+		if slices.Contains(newApproved, route) {
+			continue
+		}
+
+		// Check if this new route can be auto-approved by policy
+		canApprove := pm.NodeCanApproveRoute(nv, route)
+		if canApprove {
 			newApproved = append(newApproved, route)
 		}
 	}
-	// Only modify ApprovedRoutes if we have new routes to approve.
-	// This prevents clearing existing approved routes when nodes
-	// temporarily don't have announced routes during policy changes.
-	if len(newApproved) > 0 {
-		combined := append(newApproved, node.ApprovedRoutes...)
- tsaddr.SortPrefixes(combined) - combined = slices.Compact(combined) - combined = lo.Filter(combined, func(route netip.Prefix, index int) bool { - return route.IsValid() - }) + // Sort and deduplicate + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) + newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { + return route.IsValid() + }) - // Only update if the routes actually changed - if !slices.Equal(node.ApprovedRoutes, combined) { - node.ApprovedRoutes = combined - return true + // Sort the current approved for comparison + sortedCurrent := make([]netip.Prefix, len(currentApproved)) + copy(sortedCurrent, currentApproved) + tsaddr.SortPrefixes(sortedCurrent) + + // Only update if the routes actually changed + if !slices.Equal(sortedCurrent, newApproved) { + // Log what changed + var added, kept []netip.Prefix + for _, route := range newApproved { + if !slices.Contains(sortedCurrent, route) { + added = append(added, route) + } else { + kept = append(kept, route) + } } + + if len(added) > 0 { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.added", util.PrefixesToString(added)). + Strs("routes.kept", util.PrefixesToString(kept)). + Int("routes.total", len(newApproved)). + Msg("Routes auto-approved by policy") + } + + return newApproved, true } - return false + return newApproved, false } diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go new file mode 100644 index 00000000..6c0908b9 --- /dev/null +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -0,0 +1,339 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/types/key" + "tailscale.com/types/ptr" + "tailscale.com/types/views" +) + +func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { + user1 := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser@", + } + user2 := types.User{ + Model: gorm.Model{ID: 2}, + Name: "otheruser@", + } + users := []types.User{user1, user2} + + node1 := &types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test-node", + UserID: user1.ID, + User: user1, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ForcedTags: []string{"tag:test"}, + } + + node2 := &types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "other-node", + UserID: user2.ID, + User: user2, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } + + // Create a policy that auto-approves specific routes + policyJSON := `{ + "groups": { + "group:test": ["testuser@"] + }, + "tagOwners": { + "tag:test": ["testuser@"] + }, + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + } + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["testuser@", "tag:test"], + "10.1.0.0/24": ["testuser@"], + "10.2.0.0/24": ["testuser@"], + "192.168.0.0/24": ["tag:test"] + } + } + }` + + pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) + assert.NoError(t, err) + + tests := []struct { + name string + node 
*types.Node + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + description string + }{ + { + name: "previously_approved_route_no_longer_advertised_should_remain", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Should still be here! + }, + wantChanged: false, + description: "Previously approved routes should never be removed even when no longer advertised", + }, + { + name: "add_new_auto_approved_route_keeps_old_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8) + netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept + }, + wantChanged: true, + description: "New auto-approved routes should be added while keeping old approved routes", + }, + { + name: "no_announced_routes_keeps_all_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + description: "All approved routes should remain when no routes are announced", + }, + { + name: "no_changes_when_announced_equals_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + description: "No changes should occur when announced routes match approved routes", + }, + { + name: "auto_approve_multiple_new_routes", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) + netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved + netip.MustParsePrefix("172.16.0.0/24"), // Original kept + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + }, + wantChanged: true, + description: "Multiple new routes should be auto-approved while keeping existing approved routes", + }, + { + name: "node_without_permission_no_auto_approval", + node: node2, // Different node without the tag + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route + }, + wantChanged: false, + description: "Routes 
should not be auto-approved for nodes without proper permissions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) + + // Sort for comparison since ApproveRoutesWithPolicy sorts the results + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) + + // Verify that all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should never happen", prevRoute) + } + }) + } +} + +func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { + // Create a basic policy for edge case testing + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.1.0.0/24": ["test@"], + }, + }, +}` + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_policy_manager", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "nil_announced_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: nil, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "duplicate_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_slices", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{}, + wantApproved: []netip.Prefix{}, + wantChanged: false, + }, + } + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + // Create policy manager or use nil if specified + var pm PolicyManager + var err error + if tt.name != "nil_policy_manager" { + pm, err = pmf(users, nodes.ViewSlice()) + assert.NoError(t, err) + } else { + pm = nil + } + + gotApproved, gotChanged := 
ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch") + + // Handle nil vs empty slice comparison + if tt.wantApproved == nil { + assert.Nil(t, gotApproved, "expected nil approved routes") + } else { + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") + } + }) + } + } +} diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go new file mode 100644 index 00000000..610ce7b1 --- /dev/null +++ b/hscontrol/policy/policy_route_approval_test.go @@ -0,0 +1,361 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { + // Test policy that allows specific routes to be auto-approved + aclPolicy := ` +{ + "groups": { + "group:admins": ["test@"], + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/24": ["test@"], + "192.168.0.0/24": ["group:admins"], + "172.16.0.0/16": ["tag:approved"], + }, + }, + "tagOwners": { + "tag:approved": ["test@"], + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + nodeHostname string + nodeUser string + nodeTags []string + wantApproved []netip.Prefix + wantChanged bool + wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result + }{ + { + name: "previously_approved_route_no_longer_advertised_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Should remain! 
+ netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed + }, + { + name: "add_new_auto_approved_route_keeps_existing", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Still advertised + netip.MustParsePrefix("192.168.0.0/24"), // New route + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group + }, + wantChanged: true, + }, + { + name: "no_announced_routes_keeps_all_approved", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced anymore + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + }, + { + name: "manually_approved_route_not_in_policy_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved + }, + wantChanged: true, + }, + { + name: "tagged_node_gets_tag_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route + }, + nodeUser: "test", + nodeTags: []string{"tag:approved"}, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved + netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + }, + wantChanged: true, + }, + { + name: "complex_scenario_multiple_changes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised + netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable + netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) + netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised + }, + wantChanged: true, + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: tt.nodeUser, + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: tt.nodeHostname, + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + 
ApprovedRoutes: tt.currentApproved, + ForcedTags: tt.nodeTags, + } + nodes := types.Nodes{&node} + + // Create policy manager + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, pm) + + // Test ApproveRoutesWithPolicy + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + // Check change flag + assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch") + + // Check approved routes match expected + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Logf("Want: %v", tt.wantApproved) + t.Logf("Got: %v", gotApproved) + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + + // Verify all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should NEVER happen", prevRoute) + } + + // Verify no routes were incorrectly removed + for _, removedRoute := range tt.wantRemovedRoutes { + assert.NotContains(t, gotApproved, removedRoute, + "route %s should have been removed but wasn't", removedRoute) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["test@"], + }, + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_current_approved", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "duplicate_routes_handled", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, // Duplicates are removed, so it's a change + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + assert.Equal(t, tt.wantChanged, gotChanged) + + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); 
diff != "" { + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + + currentApproved := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + } + announcedRoutes := []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + } + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + User: user, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: currentApproved, + } + + // With nil policy manager, should return current approved unchanged + gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes) + + assert.False(t, gotChanged) + assert.Equal(t, currentApproved, gotApproved) +} diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 5e332fd3..1e6fabf3 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, canApprove: false, }, + { + name: "policy-without-autoApprovers-section", + node: normalNode, + route: p("10.33.0.0/16"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admin"], + "dst": ["group:admin:*"] + }, + { + "action": "accept", + "src": ["group:admin"], + "dst": ["10.33.0.0/16:*"] + } + ] + }`, + canApprove: false, + }, } for _, tt := range tests { diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index de839770..5e7aa34b 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // The fast path is that a node requests to approve a prefix // where there is an exact entry, e.g. 10.0.0.0/8, then // check and return quickly - if _, ok := pm.autoApproveMap[route]; ok { - if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) { + if approvers, ok := pm.autoApproveMap[route]; ok { + canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains) + if canApprove { return true } } @@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr // Check if prefix is larger (so containing) and then overlaps // the route to see if the node can approve a subset of an autoapprover if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { - if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) { + canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains) + if canApprove { return true } } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 1833f060..4809257b 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -10,7 +10,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" @@ -112,6 +111,15 @@ func (m *mapSession) serve() { // This is the mechanism where the node gives us information about its // current configuration. 
// +  // Process the MapRequest to update node state (endpoints, hostinfo, etc.) + c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + httpError(m.w, err) + return + } + + m.h.Change(c) + + // If OmitPeers is true and Stream is false // then the server will let clients update their endpoints without // breaking existing long-polling (Stream == true) connections. @@ -122,14 +130,6 @@ func (m *mapSession) serve() { // the response and just wants a 200. // !req.stream && req.OmitPeers if m.isEndpointUpdate() { - c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req) - if err != nil { - httpError(m.w, err) - return - } - - m.h.Change(c) - m.w.WriteHeader(http.StatusOK) mapResponseEndpointUpdates.WithLabelValues("ok").Inc() } @@ -142,6 +142,8 @@ func (m *mapSession) serve() { func (m *mapSession) serveLongPoll() { m.beforeServeLongPoll() + log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected") + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -149,18 +151,38 @@ func (m *mapSession) serveLongPoll() { close(m.cancelCh) m.cancelChMu.Unlock() - // TODO(kradalby): This can likely be made more effective, but likely most - // nodes has access to the same routes, so it might not be a big deal. - disconnectChange, err := m.h.state.Disconnect(m.node) - if err != nil { - m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) + + // When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather). + // Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects. + // If it does reconnect, the existing mapSession will be replaced and the node remains online. + // If it doesn't reconnect within the timeout, we mark it as offline. + // + // This avoids flapping nodes in the UI and unnecessary churn in the network. + // This is not my favourite solution, but it kind of works in our eventually consistent world. + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + disconnected := true + // Wait up to 10 seconds for the node to reconnect. + // 10 seconds was arbitrarily chosen as a reasonable time to reconnect. + for range 10 { + if m.h.mapBatcher.IsConnected(m.node.ID) { + disconnected = false + break + } + <-ticker.C } - m.h.Change(disconnectChange) - m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter()) + if disconnected { + disconnectChanges, err := m.h.state.Disconnect(m.node.ID) + if err != nil { + m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + } - m.afterServeLongPoll() - m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) + m.h.Change(disconnectChanges...)
+ m.afterServeLongPoll() + m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) + } }() // Set up the client stream @@ -172,25 +194,25 @@ func (m *mapSession) serveLongPoll() { m.keepAliveTicker = time.NewTicker(m.keepAlive) - // Add node to batcher BEFORE sending Connect change to prevent race condition - // where the change is sent before the node is in the batcher's node map - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil { - m.errf(err, "failed to add node to batcher") - // Send empty response to client to fail fast for invalid/non-existent nodes - select { - case m.ch <- &tailcfg.MapResponse{}: - default: - // Channel might be closed - } + // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.) + // CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly + // synchronized. When nodes reconnect, they send their hostinfo with announced routes + // in the MapRequest. We need this data in NodeStore before Connect() sets up the + // primary routes, otherwise SubnetRoutes() returns empty and the node is removed + // from AvailableRoutes. + mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + m.errf(err, "failed to update node from initial MapRequest") return } - // Now send the Connect change - the batcher handles NodeCameOnline internally - // but we still need to update routes and other state-level changes - connectChange := m.h.state.Connect(m.node) - if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline { - m.h.Change(connectChange) - } + // Connect the node after its state has been updated. + // We send two separate change notifications because these are distinct operations: + // 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo) + // 2. Connect: marks the node online and recalculates primary routes based on the updated state + // While this results in two notifications, it ensures route data is synchronized before + // primary route selection occurs, which is critical for proper HA subnet router failover. 
+ connectChanges := m.h.state.Connect(m.node.ID) m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) @@ -235,6 +257,7 @@ func (m *mapSession) serveLongPoll() { mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) } mapResponseSent.WithLabelValues("ok", "keepalive").Inc() + m.resetKeepAlive() } } } diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go new file mode 100644 index 00000000..3fd50d26 --- /dev/null +++ b/hscontrol/state/node_store.go @@ -0,0 +1,403 @@ +package state + +import ( + "fmt" + "maps" + "strings" + "sync/atomic" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +const ( + batchSize = 10 + batchTimeout = 500 * time.Millisecond +) + +const ( + put = 1 + del = 2 + update = 3 +) + +const prometheusNamespace = "headscale" + +var ( + nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operations_total", + Help: "Total number of NodeStore operations", + }, []string{"operation"}) + nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operation_duration_seconds", + Help: "Duration of NodeStore operations", + Buckets: prometheus.DefBuckets, + }, []string{"operation"}) + nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_size", + Help: "Size of NodeStore write batches", + Buckets: []float64{1, 2, 5, 10, 20, 50, 100}, + }) + nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_duration_seconds", + Help: "Duration of NodeStore batch processing", + Buckets: prometheus.DefBuckets, + }) + nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_snapshot_build_duration_seconds", + Help: "Duration of NodeStore snapshot building from nodes", + Buckets: prometheus.DefBuckets, + }) + nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_nodes_total", + Help: "Total number of nodes in the NodeStore", + }) + nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_peers_calculation_duration_seconds", + Help: "Duration of peers calculation in NodeStore", + Buckets: prometheus.DefBuckets, + }) + nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_queue_depth", + Help: "Current depth of NodeStore write queue", + }) +) + +// NodeStore is a thread-safe store for nodes. +// It is a copy-on-write structure, replacing the "snapshot" +// when a change to the structure occurs. It is optimised for reads, +// and while batches are not fast, they are grouped together +// to do less of the expensive peer calculation if there are many +// changes rapidly. +// +// Writes will block until committed, while reads are never +// blocked. This means that the caller of a write operation +// is responsible for ensuring an update depending on a write +// is not issued before the write is complete. 
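The copy-on-write behaviour described in the comment above reduces to an atomic pointer swap: readers always dereference the current snapshot, and a write publishes a fully built replacement. Below is a minimal sketch of that pattern with simplified, made-up types (not the NodeStore API itself), assuming writes are serialized the way the write queue further down serializes them:

```go
package example

import "sync/atomic"

// snapshot is immutable once published; readers never see partial updates.
type snapshot struct {
	nodes map[int]string
}

type store struct {
	data atomic.Pointer[snapshot]
}

// get never blocks: it only loads the current snapshot pointer.
func (s *store) get(id int) (string, bool) {
	n, ok := s.data.Load().nodes[id]
	return n, ok
}

// put copies the current map, applies the change, and atomically swaps the
// pointer. Concurrent readers keep the old snapshot until the swap completes.
// Writes are assumed to be serialized by a single writer goroutine, as the
// real NodeStore does via its write queue.
func (s *store) put(id int, name string) {
	old := s.data.Load()
	next := &snapshot{nodes: make(map[int]string, len(old.nodes)+1)}
	for k, v := range old.nodes {
		next.nodes[k] = v
	}
	next.nodes[id] = name
	s.data.Store(next)
}
```

The trade-off is the one the doc comment names: every write pays for a full copy and rebuild, which is why writes are batched and reads are effectively free.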
+type NodeStore struct { + data atomic.Pointer[Snapshot] + + peersFunc PeersFunc + writeQueue chan work +} + +func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore { + nodes := make(map[types.NodeID]types.Node, len(allNodes)) + for _, n := range allNodes { + nodes[n.ID] = *n + } + snap := snapshotFromNodes(nodes, peersFunc) + + store := &NodeStore{ + peersFunc: peersFunc, + } + store.data.Store(&snap) + + // Initialize node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + return store +} + +// Snapshot is the representation of the current state of the NodeStore. +// It contains all nodes and their relationships. +// It is a copy-on-write structure, meaning that when a write occurs, +// a new Snapshot is created with the updated state, +// and replaces the old one atomically. +type Snapshot struct { + // nodesByID is the main source of truth for nodes. + nodesByID map[types.NodeID]types.Node + + // calculated from nodesByID + nodesByNodeKey map[key.NodePublic]types.NodeView + peersByNode map[types.NodeID][]types.NodeView + nodesByUser map[types.UserID][]types.NodeView + allNodes []types.NodeView +} + +// PeersFunc is a function that takes a list of nodes and returns a map +// with the relationships between nodes and their peers. +// This will typically be used to calculate which nodes can see each other +// based on the current policy. +type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView + +// work represents a single operation to be performed on the NodeStore. +type work struct { + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} +} + +// PutNode adds or updates a node in the store. +// If the node already exists, it will be replaced. +// If the node does not exist, it will be added. +// This is a blocking operation that waits for the write to complete. +func (s *NodeStore) PutNode(n types.Node) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) + defer timer.ObserveDuration() + + work := work{ + op: put, + nodeID: n.ID, + node: n, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("put").Inc() +} + +// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. +type UpdateNodeFunc func(n *types.Node) + +// UpdateNode applies a function to modify a specific node in the store. +// This is a blocking operation that waits for the write to complete. +// This is analogous to a database "transaction", or, the caller should +// rather collect all data they want to change, and then call this function. +// Fewer calls are better. +// +// TODO(kradalby): Technically we could have a version of this that modifies the node +// in the current snapshot if _we know_ that the change will not affect the peer relationships. +// This is because the main nodesByID map contains the struct, and every other map is using a +// pointer to the underlying struct. The gotcha with this is that we will need to introduce +// a lock around the nodesByID map to ensure that no other writes are happening +// while we are modifying the node. Which mean we would need to implement read-write locks +// on all read operations. 
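Because each UpdateNode call blocks until its batch commits, and every committed batch triggers a full snapshot rebuild, the comment above asks callers to fold related changes into a single mutation function. A hedged usage sketch follows; the wrapper interface, function name, and field choices are illustrative and not part of the change:

```go
package example

import (
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/types/ptr"
)

// nodeUpdater is the slice of the NodeStore API this sketch relies on.
type nodeUpdater interface {
	UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node))
}

// renameAndTouch applies several related field changes in one blocking write,
// so the store rebuilds its snapshot once instead of three times.
func renameAndTouch(store nodeUpdater, id types.NodeID, name string) {
	store.UpdateNode(id, func(n *types.Node) {
		n.Hostname = name
		n.GivenName = name
		n.LastSeen = ptr.To(time.Now())
	})
}
```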
+func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) + defer timer.ObserveDuration() + + work := work{ + op: update, + nodeID: nodeID, + updateFn: updateFn, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("update").Inc() +} + +// DeleteNode removes a node from the store by its ID. +// This is a blocking operation that waits for the write to complete. +func (s *NodeStore) DeleteNode(id types.NodeID) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete")) + defer timer.ObserveDuration() + + work := work{ + op: del, + nodeID: id, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("delete").Inc() +} + +// Start initializes the NodeStore and starts processing the write queue. +func (s *NodeStore) Start() { + s.writeQueue = make(chan work) + go s.processWrite() +} + +// Stop stops the NodeStore. +func (s *NodeStore) Stop() { + close(s.writeQueue) +} + +// processWrite processes the write queue in batches. +func (s *NodeStore) processWrite() { + c := time.NewTicker(batchTimeout) + defer c.Stop() + batch := make([]work, 0, batchSize) + + for { + select { + case w, ok := <-s.writeQueue: + if !ok { + // Channel closed, apply any remaining batch and exit + if len(batch) != 0 { + s.applyBatch(batch) + } + return + } + batch = append(batch, w) + if len(batch) >= batchSize { + s.applyBatch(batch) + batch = batch[:0] + c.Reset(batchTimeout) + } + case <-c.C: + if len(batch) != 0 { + s.applyBatch(batch) + batch = batch[:0] + } + c.Reset(batchTimeout) + } + } +} + +// applyBatch applies a batch of work to the node store. +// This means that it takes a copy of the current nodes, +// then applies the batch of operations to that copy, +// runs any precomputation needed (like calculating peers), +// and finally replaces the snapshot in the store with the new one. +// The replacement of the snapshot is atomic, ensuring that reads +// are never blocked by writes. +// Each write item is blocked until the batch is applied to ensure +// the caller knows the operation is complete and does not send any +// updates that depend on a write that has not yet been applied. +func (s *NodeStore) applyBatch(batch []work) { + timer := prometheus.NewTimer(nodeStoreBatchDuration) + defer timer.ObserveDuration() + + nodeStoreBatchSize.Observe(float64(len(batch))) + + nodes := make(map[types.NodeID]types.Node) + maps.Copy(nodes, s.data.Load().nodesByID) + + for _, w := range batch { + switch w.op { + case put: + nodes[w.nodeID] = w.node + case update: + // Update the specific node identified by nodeID + if n, exists := nodes[w.nodeID]; exists { + w.updateFn(&n) + nodes[w.nodeID] = n + } + case del: + delete(nodes, w.nodeID) + } + } + + newSnap := snapshotFromNodes(nodes, s.peersFunc) + s.data.Store(&newSnap) + + // Update node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + for _, w := range batch { + close(w.result) + } +} + +// snapshotFromNodes creates a new Snapshot from the provided nodes. +// It builds a number of "indexes" to make lookups fast for data that +// is used frequently, like nodesByNodeKey, peersByNode, and nodesByUser.
+// This is not a fast operation, it is the "slow" part of our copy-on-write +// structure, but it allows us to have fast reads and efficient lookups. +func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot { + timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration) + defer timer.ObserveDuration() + + allNodes := make([]types.NodeView, 0, len(nodes)) + for _, n := range nodes { + allNodes = append(allNodes, n.View()) + } + + newSnap := Snapshot{ + nodesByID: nodes, + allNodes: allNodes, + nodesByNodeKey: make(map[key.NodePublic]types.NodeView), + + // peersByNode is most likely the most expensive operation, + // it will use the list of all nodes, combined with the + // current policy to precalculate which nodes are peers and + // can see each other. + peersByNode: func() map[types.NodeID][]types.NodeView { + peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) + defer peersTimer.ObserveDuration() + return peersFunc(allNodes) + }(), + nodesByUser: make(map[types.UserID][]types.NodeView), + } + + // Build nodesByUser and nodesByNodeKey maps + for _, n := range nodes { + nodeView := n.View() + newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView) + newSnap.nodesByNodeKey[n.NodeKey] = nodeView + } + + return newSnap +} + +// GetNode retrieves a node by its ID. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get").Inc() + + n, exists := s.data.Load().nodesByID[id] + if !exists { + return types.NodeView{}, false + } + + return n.View(), true +} + +// GetNodeByNodeKey retrieves a node by its NodeKey. +func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView { + return s.data.Load().nodesByNodeKey[nodeKey] +} + +// ListNodes returns a slice of all nodes in the store. +func (s *NodeStore) ListNodes() views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list").Inc() + + return views.SliceOf(s.data.Load().allNodes) +} + +// ListPeers returns a slice of all peers for a given node ID. +func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_peers").Inc() + + return views.SliceOf(s.data.Load().peersByNode[id]) +} + +// ListNodesByUser returns a slice of all nodes for a given user ID. 
+func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_by_user").Inc() + + return views.SliceOf(s.data.Load().nodesByUser[uid]) +} diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go new file mode 100644 index 00000000..9666e5db --- /dev/null +++ b/hscontrol/state/node_store_test.go @@ -0,0 +1,501 @@ +package state + +import ( + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/key" +) + +func TestSnapshotFromNodes(t *testing.T) { + tests := []struct { + name string + setupFunc func() (map[types.NodeID]types.Node, PeersFunc) + validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) + }{ + { + name: "empty nodes", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := make(map[types.NodeID]types.Node) + peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + return make(map[types.NodeID][]types.NodeView) + } + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "single node", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + } + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers + assert.Len(t, snapshot.nodesByUser[1], 1) + assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID()) + }, + }, + { + name: "multiple nodes same user", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 1, "user1", "node2"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Each node sees the other as peer (but not itself) + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "multiple nodes different users", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 1, "user1", "node3"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes 
map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // Each node should have 2 peers (all others, but not itself) + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2 + }, + }, + { + name: "odd-even peers filtering", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 3, "user3", "node3"), + 4: createTestNode(4, 4, "user4", "node4"), + } + peersFunc := oddEvenPeersFunc + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 4) + assert.Len(t, snapshot.allNodes, 4) + assert.Len(t, snapshot.peersByNode, 4) + assert.Len(t, snapshot.nodesByUser, 4) + + // Odd nodes should only see other odd nodes as peers + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // Even nodes should only see other even nodes as peers + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nodes, peersFunc := tt.setupFunc() + snapshot := snapshotFromNodes(nodes, peersFunc) + tt.validate(t, nodes, snapshot) + }) + } +} + +// Helper functions + +func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node { + now := time.Now() + machineKey := key.NewMachine() + nodeKey := key.NewNode() + discoKey := key.NewDisco() + + ipv4 := netip.MustParseAddr("100.64.0.1") + ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1") + + return types.Node{ + ID: nodeID, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + DiscoKey: discoKey.Public(), + Hostname: hostname, + GivenName: hostname, + UserID: userID, + User: types.User{ + Name: username, + DisplayName: username, + }, + RegisterMethod: "test", + IPv4: &ipv4, + IPv6: &ipv6, + CreatedAt: now, + UpdatedAt: now, + } +} + +// Peer functions + +func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + for _, n := range nodes { + if n.ID() != node.ID() { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret +} + +func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 + + for _, n := range nodes { + if n.ID() == node.ID() { + continue + } + + peerIsOdd := n.ID()%2 == 1 + + // Only add peer if both are odd or both are even + if nodeIsOdd == peerIsOdd { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret 
+} + +func TestNodeStoreOperations(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) *NodeStore + steps []testStep + }{ + { + name: "create empty store and add single node", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify empty store", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "add first node", + action: func(store *NodeStore) { + node := createTestNode(1, 1, "user1", "node1") + store.PutNode(node) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, node.ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no peers yet + assert.Len(t, snapshot.nodesByUser[1], 1) + }, + }, + }, + }, + { + name: "create store with initial node and add more", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + initialNodes := types.Nodes{&node1} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + assert.Empty(t, snapshot.peersByNode[1]) + }, + }, + { + name: "add second node same user", + action: func(store *NodeStore) { + node2 := createTestNode(2, 1, "user1", "node2") + store.PutNode(node2) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Now both nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "add third node different user", + action: func(store *NodeStore) { + node3 := createTestNode(3, 2, "user2", "node3") + store.PutNode(node3) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // All nodes should see the other 2 as peers + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3 + }, + }, + }, + }, + { + name: "test node deletion", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + node3 := createTestNode(3, 2, "user2", "node3") + initialNodes := types.Nodes{&node1, &node2, &node3} + + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial 3 nodes", + action: func(store *NodeStore) 
{ + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + }, + }, + { + name: "delete middle node", + action: func(store *NodeStore) { + store.DeleteNode(2) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 2) + + // Node 2 should be gone + assert.NotContains(t, snapshot.nodesByID, types.NodeID(2)) + + // Remaining nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // User groupings updated + assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3 + }, + }, + { + name: "delete all remaining nodes", + action: func(store *NodeStore) { + store.DeleteNode(1) + store.DeleteNode(3) + + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + }, + }, + { + name: "test node updates", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial hostnames", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "update node hostname", + action: func(store *NodeStore) { + store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "updated-node1" + n.GivenName = "updated-node1" + }) + + snapshot := store.data.Load() + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged + + // Peers should still work correctly + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Len(t, snapshot.peersByNode[2], 1) + }, + }, + }, + }, + { + name: "test with odd-even peers filtering", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, oddEvenPeersFunc) + }, + steps: []testStep{ + { + name: "add nodes with odd-even filtering", + action: func(store *NodeStore) { + // Add nodes in sequence + store.PutNode(createTestNode(1, 1, "user1", "node1")) + store.PutNode(createTestNode(2, 2, "user2", "node2")) + store.PutNode(createTestNode(3, 3, "user3", "node3")) + store.PutNode(createTestNode(4, 4, "user4", "node4")) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 4) + + // Verify odd-even peer relationships + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) 
+ }, + }, + { + name: "delete odd node and verify even nodes unaffected", + action: func(store *NodeStore) { + store.DeleteNode(1) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + + // Node 3 (odd) should now have no peers + assert.Empty(t, snapshot.peersByNode[3]) + + // Even nodes should still see each other + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := tt.setupFunc(t) + store.Start() + defer store.Stop() + + for _, step := range tt.steps { + t.Run(step.name, func(t *testing.T) { + step.action(store) + }) + } + }) + } +} + +type testStep struct { + name string + action func(store *NodeStore) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 0a743184..958a2f52 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1,5 +1,6 @@ // Package state provides core state management for Headscale, coordinating // between subsystems like database, IP allocation, policy management, and DERP routing. + package state import ( @@ -9,6 +10,8 @@ import ( "io" "net/netip" "os" + "slices" + "sync" "sync/atomic" "time" @@ -21,12 +24,13 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" - xslices "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" + "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -49,6 +53,9 @@ type State struct { // cfg holds the current Headscale configuration cfg *types.Config + // nodeStore provides an in-memory cache for nodes. + nodeStore *NodeStore + // subsystem keeping state // db provides persistent storage and database operations db *hsdb.HSDatabase @@ -90,6 +97,12 @@ func NewState(cfg *types.Config) (*State, error) { if err != nil { return nil, fmt.Errorf("loading nodes: %w", err) } + + // On startup, all nodes should be marked as offline until they reconnect + // This ensures we don't have stale online status from previous runs + for _, node := range nodes { + node.IsOnline = ptr.To(false) + } users, err := db.ListUsers() if err != nil { return nil, fmt.Errorf("loading users: %w", err) @@ -105,7 +118,13 @@ func NewState(cfg *types.Config) (*State, error) { return nil, fmt.Errorf("init policy manager: %w", err) } - s := &State{ + nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + _, matchers := polMan.Filter() + return policy.BuildPeerMap(views.SliceOf(nodes), matchers) + }) + nodeStore.Start() + + return &State{ cfg: cfg, db: db, @@ -113,13 +132,14 @@ func NewState(cfg *types.Config) (*State, error) { polMan: polMan, registrationCache: registrationCache, primaryRoutes: routes.New(), - } - - return s, nil + nodeStore: nodeStore, + }, nil } // Close gracefully shuts down the State instance and releases all resources. func (s *State) Close() error { + s.nodeStore.Stop() + if err := s.db.Close(); err != nil { return fmt.Errorf("closing database: %w", err) } @@ -180,69 +200,78 @@ func (s *State) DERPMap() tailcfg.DERPMapView { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. 
-func (s *State) ReloadPolicy() (bool, error) { +func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { pol, err := policyBytes(s.db, s.cfg) if err != nil { - return false, fmt.Errorf("loading policy: %w", err) + return nil, fmt.Errorf("loading policy: %w", err) } - changed, err := s.polMan.SetPolicy(pol) + policyChanged, err := s.polMan.SetPolicy(pol) if err != nil { - return false, fmt.Errorf("setting policy: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } - if changed { - err := s.autoApproveNodes() - if err != nil { - return false, fmt.Errorf("auto approving nodes: %w", err) - } + cs := []change.ChangeSet{change.PolicyChange()} + + // Always call autoApproveNodes during policy reload, regardless of whether + // the policy content has changed. This ensures that routes are re-evaluated + // when they might have been manually disabled but could now be auto-approved + // with the current policy. + rcs, err := s.autoApproveNodes() + if err != nil { + return nil, fmt.Errorf("auto approving nodes: %w", err) } - return changed, nil -} + // TODO(kradalby): These changes can probably be safely ignored. + // If the PolicyChange is happening, that will lead to a full update + // meaning that we do not need to send individual route changes. + cs = append(cs, rcs...) -// AutoApproveNodes processes pending nodes and auto-approves those meeting policy criteria. -func (s *State) AutoApproveNodes() error { - return s.autoApproveNodes() + if len(rcs) > 0 || policyChanged { + log.Info(). + Bool("policy.changed", policyChanged). + Int("route.changes", len(rcs)). + Int("total.changes", len(cs)). + Msg("Policy reload completed with changes") + } + + return cs, nil } // CreateUser creates a new user and updates the policy manager. -// Returns the created user, whether policies changed, and any error. -func (s *State) CreateUser(user types.User) (*types.User, bool, error) { +// Returns the created user, change set, and any error. +func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() if err := s.db.DB.Save(&user).Error; err != nil { - return nil, false, fmt.Errorf("creating user: %w", err) + return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerUsers() + c, err := s.updatePolicyManagerUsers() if err != nil { // Log the error but don't fail the user creation - return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err) + return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err) } // Even if the policy manager doesn't detect a filter change, SSH policies // might now be resolvable when they weren't before. If there are existing // nodes, we should send a policy change to ensure they get updated SSH policies. - if !policyChanged { - nodes, err := s.ListNodes() - if err == nil && len(nodes) > 0 { - policyChanged = true - } + // TODO(kradalby): detect this, or rebuild all SSH policies so we can determine + // this upstream. 
+ if c.Empty() { + c = change.PolicyChange() } - log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated") + log.Info().Str("user.name", user.Name).Msg("User created") - // TODO(kradalby): implement the user in-memory cache - - return &user, policyChanged, nil + return &user, c, nil } // UpdateUser modifies an existing user using the provided update function within a transaction. -// Returns the updated user, whether policies changed, and any error. -func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) { +// Returns the updated user, change set, and any error. +func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) { s.mu.Lock() defer s.mu.Unlock() @@ -263,18 +292,18 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return user, nil }) if err != nil { - return nil, false, err + return nil, change.EmptySet, err } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerUsers() + c, err := s.updatePolicyManagerUsers() if err != nil { - return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err) + return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err) } - // TODO(kradalby): implement the user in-memory cache + // TODO(kradalby): We might want to update nodestore with the user data - return user, policyChanged, nil + return user, c, nil } // DeleteUser permanently removes a user and all associated data (nodes, API keys, etc). @@ -284,7 +313,7 @@ func (s *State) DeleteUser(userID types.UserID) error { } // RenameUser changes a user's name. The new name must be unique. -func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) { +func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) { return s.UpdateUser(userID, func(user *types.User) error { user.Name = newName return nil @@ -316,33 +345,16 @@ func (s *State) ListAllUsers() ([]types.User, error) { return s.db.ListUsers() } -// CreateNode creates a new node and updates the policy manager. -// Returns the created node, whether policies changed, and any error. -func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if err := s.db.DB.Save(node).Error; err != nil { - return nil, false, fmt.Errorf("creating node: %w", err) - } - - // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() - if err != nil { - return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - return node, policyChanged, nil -} - // updateNodeTx performs a database transaction to update a node and refresh the policy manager. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) { +// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore +// BEFORE calling this function with the EXACT same changes that the database update will make. +// This ensures the NodeStore is the source of truth for the batcher and maintains consistency. +// Returns error only; callers should get the updated NodeView from NodeStore to maintain consistency. 
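Condensed, the contract above is the same three-step shape the setters later in this file follow: mutate the NodeStore first, mirror the exact same change in the database transaction, then read the node back from the NodeStore before emitting any notification. A rough sketch of that ordering is shown below; the function name is made up, error wrapping and policy-manager handling are trimmed, and SetNodeExpiry further down is the real implementation:

```go
package state

import (
	"fmt"
	"time"

	hsdb "github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
	"gorm.io/gorm"
)

// setExpirySketch illustrates the NodeStore-before-database ordering only.
func (s *State) setExpirySketch(id types.NodeID, expiry time.Time) (types.NodeView, error) {
	// 1. NodeStore first: this blocking write is what the batcher will observe.
	s.nodeStore.UpdateNode(id, func(n *types.Node) {
		n.Expiry = &expiry
	})

	// 2. Persist the exact same change to the database.
	if err := s.updateNodeTx(id, func(tx *gorm.DB) error {
		return hsdb.NodeSetExpiry(tx, id, expiry)
	}); err != nil {
		return types.NodeView{}, err
	}

	// 3. Read back from the NodeStore, the source of truth for notifications.
	n, ok := s.nodeStore.GetNode(id)
	if !ok {
		return types.NodeView{}, fmt.Errorf("node %d not found in NodeStore", id)
	}

	return n, nil
}
```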
+func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) error { s.mu.Lock() defer s.mu.Unlock() - node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + _, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { if err := updateFn(tx); err != nil { return nil, err } @@ -358,166 +370,283 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err return node, nil }) - if err != nil { - return nil, change.EmptySet, err + return err +} + +// persistNodeToDB saves the current state of a node from NodeStore to the database. +// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency. +func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // CRITICAL: Always get the latest node from NodeStore to ensure we save the current state + node, found := s.nodeStore.GetNode(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + nodePtr := node.AsStruct() + + if err := s.db.DB.Save(nodePtr).Error; err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err) } // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() + c, err := s.updatePolicyManagerNodes() if err != nil { - return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) } - // TODO(kradalby): implement the node in-memory cache - - var c change.ChangeSet - if policyChanged { - c = change.PolicyChange() - } else { - // Basic node change without specific details since this is a generic update - c = change.NodeAdded(node.ID) + if c.Empty() { + c = change.NodeAdded(node.ID()) } return node, c, nil } -// SaveNode persists an existing node to the database and updates the policy manager. -func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) { + // Update NodeStore first + nodePtr := node.AsStruct() - if err := s.db.DB.Save(node).Error; err != nil { - return nil, change.EmptySet, fmt.Errorf("saving node: %w", err) - } + s.nodeStore.PutNode(*nodePtr) - // Check if policy manager needs updating - policyChanged, err := s.updatePolicyManagerNodes() - if err != nil { - return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) - } - - // TODO(kradalby): implement the node in-memory cache - - if policyChanged { - return node, change.PolicyChange(), nil - } - - return node, change.EmptySet, nil + // Then save to database + return s.persistNodeToDB(node.ID()) } // DeleteNode permanently removes a node and cleans up associated resources. // Returns whether policies changed and any error. This operation is irreversible. 
-func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) { - err := s.db.DeleteNode(node) +func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { + s.nodeStore.DeleteNode(node.ID()) + + err := s.db.DeleteNode(node.AsStruct()) if err != nil { return change.EmptySet, err } - c := change.NodeRemoved(node.ID) + c := change.NodeRemoved(node.ID()) // Check if policy manager needs updating after node deletion - policyChanged, err := s.updatePolicyManagerNodes() + policyChange, err := s.updatePolicyManagerNodes() if err != nil { return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err) } - if policyChanged { - c = change.PolicyChange() + if !policyChange.Empty() { + c = policyChange } return c, nil } -func (s *State) Connect(node *types.Node) change.ChangeSet { - c := change.NodeOnline(node.ID) - routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) +// Connect marks a node as connected and updates its primary routes in the state. +func (s *State) Connect(id types.NodeID) []change.ChangeSet { + // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification + // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, + // the NodeStore already reflects the correct online status for full map generation. + // now := time.Now() + s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.IsOnline = ptr.To(true) + // n.LastSeen = ptr.To(now) + }) + c := []change.ChangeSet{change.NodeOnline(id)} + + // Get fresh node data from NodeStore after the online status update + node, found := s.GetNodeByID(id) + if !found { + return nil + } + + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") + + // Use the node's current routes for primary route update + // SubnetRoutes() returns only the intersection of announced AND approved routes + // We MUST use SubnetRoutes() to maintain the security model + routeChange := s.primaryRoutes.SetRoutes(id, node.SubnetRoutes()...) if routeChange { - c = change.NodeAdded(node.ID) + c = append(c, change.NodeAdded(id)) } return c } -func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) { - c := change.NodeOffline(node.ID) +// Disconnect marks a node as disconnected and updates its primary routes in the state. +func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { + now := time.Now() - _, _, err := s.SetLastSeen(node.ID, time.Now()) + // Get node info before updating for logging + node, found := s.GetNodeByID(id) + var nodeName string + if found { + nodeName = node.Hostname() + } + + s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.LastSeen = ptr.To(now) + // NodeStore is the source of truth for all node state including online status. 
+ n.IsOnline = ptr.To(false) + }) + + if found { + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Node disconnected") + } + + err := s.updateNodeTx(id, func(tx *gorm.DB) error { + // Update last_seen in the database + // Note: IsOnline is managed only in NodeStore (marked with gorm:"-"), not persisted to database + return hsdb.SetLastSeen(tx, id, now) + }) if err != nil { - return c, fmt.Errorf("disconnecting node: %w", err) + // Log error but don't fail the disconnection - NodeStore is already updated + // and we need to send change notifications to peers + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update last seen in database") } - if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange { - c = change.PolicyChange() + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + // Log error but continue - disconnection must proceed + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update policy manager after node disconnect") + c = change.EmptySet } - // TODO(kradalby): This node should update the in memory state - return c, nil + // The node is disconnecting so make sure that none of the routes it + // announced are served to any nodes. + routeChange := s.primaryRoutes.SetRoutes(id) + + cs := []change.ChangeSet{change.NodeOffline(id), c} + + // If we have a policy change or route change, return that as it's more comprehensive + // Otherwise, return the NodeOffline change to ensure nodes are notified + if c.IsFull() || routeChange { + cs = append(cs, change.PolicyChange()) + } + + return cs, nil } // GetNodeByID retrieves a node by ID. -func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) { - return s.db.GetNodeByID(nodeID) -} - -// GetNodeViewByID retrieves a node view by ID. -func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) { - node, err := s.db.GetNodeByID(nodeID) - if err != nil { - return types.NodeView{}, err - } - - return node.View(), nil +// GetNodeByID retrieves a node by its ID. +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, bool) { + return s.nodeStore.GetNode(nodeID) } // GetNodeByNodeKey retrieves a node by its Tailscale public key. -func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { - return s.db.GetNodeByNodeKey(nodeKey) +// The bool indicates if the node exists or is available (like "err not found"). +// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) { + return s.nodeStore.GetNodeByNodeKey(nodeKey) } -// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key. -func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) { - node, err := s.db.GetNodeByNodeKey(nodeKey) - if err != nil { - return types.NodeView{}, err - } - - return node.View(), nil +// GetNodeByMachineKey retrieves a node by its machine key. +// The bool indicates if the node exists or is available (like "err not found"). 
+// The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure +// it isn't an invalid node (this is more of a node error or node is broken). +func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) { + return s.nodeStore.GetNodeByMachineKey(machineKey) } // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. -func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { +func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] { if len(nodeIDs) == 0 { - return s.db.ListNodes() + return s.nodeStore.ListNodes() } - return s.db.ListNodes(nodeIDs...) + // Filter nodes by the requested IDs + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs)) + for _, id := range nodeIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) } // ListNodesByUser retrieves all nodes belonging to a specific user. -func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) { - return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) { - return hsdb.ListNodesByUser(rx, userID) - }) +func (s *State) ListNodesByUser(userID types.UserID) views.Slice[types.NodeView] { + return s.nodeStore.ListNodesByUser(userID) } // ListPeers retrieves nodes that can communicate with the specified node based on policy. -func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - return s.db.ListPeers(nodeID, peerIDs...) +func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) views.Slice[types.NodeView] { + if len(peerIDs) == 0 { + return s.nodeStore.ListPeers(nodeID) + } + + // For specific peerIDs, filter from all nodes + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs)) + for _, id := range peerIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) } // ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system. -func (s *State) ListEphemeralNodes() (types.Nodes, error) { - return s.db.ListEphemeralNodes() +func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { + allNodes := s.nodeStore.ListNodes() + var ephemeralNodes []types.NodeView + + for _, node := range allNodes.All() { + // Check if node is ephemeral by checking its AuthKey + if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + ephemeralNodes = append(ephemeralNodes, node) + } + } + + return views.SliceOf(ephemeralNodes) } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. 
+ // If the database update fails, the NodeStore change will remain, but since we return + // an error, no change notification will be sent to the batcher. + expiryPtr := expiry + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.Expiry = &expiryPtr + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.NodeSetExpiry(tx, nodeID, expiry) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -528,12 +657,32 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Nod } // SetNodeTags assigns tags to a node for use in access control policies. -func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ForcedTags = tags + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetTags(tx, nodeID, tags) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -544,16 +693,42 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, ch } // SetApprovedRoutes sets the network routes that a node is approved to advertise. -func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) { + // TODO(kradalby): In principle we should call the AutoApprove logic here + // because even if the CLI removes an auto-approved route, it will be added + // back automatically. 
+ s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ApprovedRoutes = routes + }) + + err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.SetApprovedRoutes(tx, nodeID, routes) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) } - // Update primary routes after changing approved routes - routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...) + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + } + + // Get the node from NodeStore to ensure we have the latest state + nodeView, ok := s.GetNodeByID(nodeID) + if !ok { + return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID) + } + // Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set + // primary routes for routes that are both announced AND approved + routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) if routeChange || !c.IsFull() { c = change.PolicyChange() @@ -563,12 +738,48 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (* } // RenameNode changes the display name of a node. -func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) { + // Validate the new name before making any changes + if err := util.CheckForFQDNRules(newName); err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + // Check name uniqueness + nodes, err := s.db.ListNodes() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err) + } + for _, node := range nodes { + if node.ID != nodeID && node.GivenName == newName { + return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) + } + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. 
+ s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.GivenName = newName + }) + + err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.RenameNode(tx, nodeID, newName) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -578,20 +789,45 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, ch return n, c, nil } -// SetLastSeen updates when a node was last seen, used for connectivity monitoring. -func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) { - return s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetLastSeen(tx, nodeID, lastSeen) - }) -} - // AssignNodeToUser transfers a node to a different user. -func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) { - n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { +func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) { + // Validate that both node and user exist + _, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found: %d", nodeID) + } + + user, err := s.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("user not found: %w", err) + } + + // CRITICAL: Update NodeStore BEFORE database to ensure consistency. + // The NodeStore update is blocking and will be the source of truth for the batcher. + // The database update MUST make the EXACT same change. + s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { + n.User = *user + n.UserID = uint(userID) + }) + + err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { return hsdb.AssignNodeToUser(tx, nodeID, userID) }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err) + return types.NodeView{}, change.EmptySet, err + } + + // Get the updated node from NodeStore to ensure consistency + // TODO(kradalby): Validate if this NodeStore read makes sense after database update + n, found := s.GetNodeByID(nodeID) + if !found { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) } if !c.IsFull() { @@ -603,13 +839,59 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*typ // BackfillNodeIPs assigns IP addresses to nodes that don't have them. 
func (s *State) BackfillNodeIPs() ([]string, error) { - return s.db.BackfillNodeIPs(s.ipAlloc) + changes, err := s.db.BackfillNodeIPs(s.ipAlloc) + if err != nil { + return nil, err + } + + // Refresh NodeStore after IP changes to ensure consistency + if len(changes) > 0 { + nodes, err := s.db.ListNodes() + if err != nil { + return changes, fmt.Errorf("failed to refresh NodeStore after IP backfill: %w", err) + } + + for _, node := range nodes { + // Preserve online status when refreshing from database + existingNode, exists := s.nodeStore.GetNode(node.ID) + if exists && existingNode.Valid() { + node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + } + // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. + // We should avoid PutNode here. + s.nodeStore.PutNode(*node) + } + } + + return changes, nil } // ExpireExpiredNodes finds and processes expired nodes since the last check. // Returns next check time, state update with expired nodes, and whether any were found. func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) { - return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck) + // Why capture start time: We need to ensure we don't miss nodes that expire + // while this function is running by using a consistent timestamp for the next check + started := time.Now() + + var updates []change.ChangeSet + + for _, node := range s.nodeStore.ListNodes().All() { + if !node.Valid() { + continue + } + + // Why check After(lastCheck): We only want to notify about nodes that + // expired since the last check to avoid duplicate notifications + if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { + updates = append(updates, change.KeyExpiry(node.ID())) + } + } + + if len(updates) > 0 { + return started, updates, true + } + + return started, nil, false } // SSHPolicy returns the SSH access policy for a node. @@ -633,13 +915,35 @@ func (s *State) SetPolicy(pol []byte) (bool, error) { } // AutoApproveRoutes checks if a node's routes should be auto-approved. -func (s *State) AutoApproveRoutes(node *types.Node) bool { - return policy.AutoApproveRoutes(s.polMan, node) -} +// AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them. +func (s *State) AutoApproveRoutes(nv types.NodeView) bool { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) + if changed { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.announced", util.PrefixesToString(nv.AnnouncedRoutes())). + Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())). + Strs("routes.approved.new", util.PrefixesToString(approved)). + Msg("Single node auto-approval detected route changes") -// PolicyDebugString returns a debug representation of the current policy. -func (s *State) PolicyDebugString() string { - return s.polMan.DebugString() + // Persist the auto-approved routes to database and NodeStore via SetApprovedRoutes + // This ensures consistency between database and NodeStore + _, _, err := s.SetApprovedRoutes(nv.ID(), approved) + if err != nil { + log.Error(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Err(err). 
+ Msg("Failed to persist auto-approved routes") + + return false + } + + log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved") + } + + return changed } // GetPolicy retrieves the current policy from the database. @@ -744,36 +1048,238 @@ func (s *State) HandleNodeFromAuthPath( userID types.UserID, expiry *time.Time, registrationMethod string, -) (*types.Node, change.ChangeSet, error) { - ipv4, ipv6, err := s.ipAlloc.Next() - if err != nil { - return nil, change.EmptySet, err +) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Get the registration entry from cache + regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + if !ok { + return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache } - return s.db.HandleNodeFromAuthPath( - registrationID, - userID, - expiry, - util.RegisterMethodOIDC, - ipv4, ipv6, - ) + // Get the user + user, err := s.db.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) + } + + // Check if node already exists by node key + existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey) + if exists && existingNodeView.Valid() { + // Node exists - this is a refresh/re-registration + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Username()). + Str("registrationMethod", registrationMethod). + Str("node.name", existingNodeView.Hostname()). + Uint64("node.id", existingNodeView.ID().Uint64()). + Msg("Refreshing existing node registration") + + // Update NodeStore first with the new expiry + s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { + if expiry != nil { + node.Expiry = expiry + } + // Mark as offline since node is reconnecting + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + + // Save to database + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry) + if err != nil { + return nil, err + } + // Return the node to satisfy the Write signature + return hsdb.GetNodeByID(tx, existingNodeView.ID()) + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err) + } + + // Get updated node from NodeStore + updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) + + return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil + } + + // New node registration + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Username()). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", expiry)). 
+ Msg("Registering new node from auth callback") + + // Check if node exists with same machine key + var existingMachineNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() { + existingMachineNode = nv.AsStruct() + } + + // Prepare the node for registration + nodeToRegister := regEntry.Node + nodeToRegister.UserID = uint(userID) + nodeToRegister.User = *user + nodeToRegister.RegisterMethod = registrationMethod + if expiry != nil { + nodeToRegister.Expiry = expiry + } + + // Handle IP allocation + var ipv4, ipv6 *netip.Addr + if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { + // Reuse existing IPs and properties + nodeToRegister.ID = existingMachineNode.ID + nodeToRegister.GivenName = existingMachineNode.GivenName + nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes + ipv4 = existingMachineNode.IPv4 + ipv6 = existingMachineNode.IPv6 + } else { + // Allocate new IPs + ipv4, ipv6, err = s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) + } + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + var savedNode *types.Node + if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { + // Update existing node - NodeStore first, then database + s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { + node.NodeKey = nodeToRegister.NodeKey + node.DiscoKey = nodeToRegister.DiscoKey + node.Hostname = nodeToRegister.Hostname + node.Hostinfo = nodeToRegister.Hostinfo + node.Endpoints = nodeToRegister.Endpoints + node.RegisterMethod = nodeToRegister.RegisterMethod + if expiry != nil { + node.Expiry = expiry + } + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + }) + + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) + } + + // Delete from registration cache + s.registrationCache.Delete(registrationID) + + // Signal to waiting clients + select { + case regEntry.Registered <- savedNode: + default: + } + close(regEntry.Registered) + + // Update policy manager + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err) + } + + if !nodesChange.Empty() { + return savedNode.View(), nodesChange, nil + } + + return savedNode.View(), change.NodeAdded(savedNode.ID), nil } // HandleNodeFromPreAuthKey handles node 
registration using a pre-authentication key. func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) (*types.Node, change.ChangeSet, bool, error) { +) (types.NodeView, change.ChangeSet, error) { + s.mu.Lock() + defer s.mu.Unlock() + pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } err = pak.Validate() if err != nil { - return nil, change.EmptySet, false, err + return types.NodeView{}, change.EmptySet, err } + // Check if this is a logout request for an ephemeral node + if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { + // Find the node to delete + var nodeToDelete types.NodeView + for _, nv := range s.nodeStore.ListNodes().All() { + if nv.Valid() && nv.MachineKey() == machineKey { + nodeToDelete = nv + break + } + } + if nodeToDelete.Valid() { + c, err := s.DeleteNode(nodeToDelete) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) + } + + return types.NodeView{}, c, nil + } + + return types.NodeView{}, change.EmptySet, nil + } + + log.Debug(). + Caller(). + Str("node.name", regReq.Hostinfo.Hostname). + Str("machine.key", machineKey.ShortString()). + Str("node.key", regReq.NodeKey.ShortString()). + Str("user.name", pak.User.Username()). + Msg("Registering node with pre-auth key") + + // Check if node already exists with same machine key + var existingNode *types.Node + if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { + existingNode = nv.AsStruct() + } + + // Prepare the node for registration nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, UserID: pak.User.ID, @@ -783,75 +1289,133 @@ func (s *State) HandleNodeFromPreAuthKey( Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), RegisterMethod: util.RegisterMethodAuthKey, - - // TODO(kradalby): This should not be set on the node, - // they should be looked up through the key, which is - // attached to the node. 
- ForcedTags: pak.Proto().GetAclTags(),
- AuthKey: pak,
- AuthKeyID: &pak.ID,
+ ForcedTags: pak.Proto().GetAclTags(),
+ AuthKey: pak,
+ AuthKeyID: &pak.ID,
 }
 if !regReq.Expiry.IsZero() {
 nodeToRegister.Expiry = &regReq.Expiry
 }
- ipv4, ipv6, err := s.ipAlloc.Next()
- if err != nil {
- return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
+ // Handle IP allocation and existing node properties
+ var ipv4, ipv6 *netip.Addr
+ if existingNode != nil && existingNode.UserID == pak.User.ID {
+ // Reuse existing node properties
+ nodeToRegister.ID = existingNode.ID
+ nodeToRegister.GivenName = existingNode.GivenName
+ nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes
+ ipv4 = existingNode.IPv4
+ ipv6 = existingNode.IPv6
+ } else {
+ // Allocate new IPs
+ ipv4, ipv6, err = s.ipAlloc.Next()
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
+ }
 }
- node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
- node, err := hsdb.RegisterNode(tx,
- nodeToRegister,
- ipv4, ipv6,
- )
- if err != nil {
- return nil, fmt.Errorf("registering node: %w", err)
- }
+ nodeToRegister.IPv4 = ipv4
+ nodeToRegister.IPv6 = ipv6
- if !pak.Reusable {
- err = hsdb.UsePreAuthKey(tx, pak)
- if err != nil {
- return nil, fmt.Errorf("using pre auth key: %w", err)
+ // Ensure unique given name if not set
+ if nodeToRegister.GivenName == "" {
+ givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname)
+ if err != nil {
+ return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err)
+ }
+ nodeToRegister.GivenName = givenName
+ }
+
+ var savedNode *types.Node
+ if existingNode != nil && existingNode.UserID == pak.User.ID {
+ // Update existing node - NodeStore first, then database
+ s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) {
+ node.NodeKey = nodeToRegister.NodeKey
+ node.Hostname = nodeToRegister.Hostname
+ node.Hostinfo = nodeToRegister.Hostinfo
+ node.Endpoints = nodeToRegister.Endpoints
+ node.RegisterMethod = nodeToRegister.RegisterMethod
+ node.ForcedTags = nodeToRegister.ForcedTags
+ node.AuthKey = nodeToRegister.AuthKey
+ node.AuthKeyID = nodeToRegister.AuthKeyID
+ if nodeToRegister.Expiry != nil {
+ node.Expiry = nodeToRegister.Expiry
 }
- }
+ node.IsOnline = ptr.To(false)
+ node.LastSeen = ptr.To(time.Now())
+ })
- return node, nil
- })
- if err != nil {
- return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
- }
+ log.Trace().
+ Caller().
+ Str("node.name", nodeToRegister.Hostname).
+ Uint64("node.id", existingNode.ID.Uint64()).
+ Str("machine.key", machineKey.ShortString()).
+ Str("node.key", regReq.NodeKey.ShortString()).
+ Str("user.name", pak.User.Username()).
+ Msg("Node re-authorized") - // Check if this is a logout request for an ephemeral node - if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { - // This is a logout request for an ephemeral node, delete it immediately - c, err := s.DeleteNode(node) + // Save to database + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err) + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) } - return nil, c, false, nil + } else { + // New node - database first to get ID, then NodeStore + savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + } + + // Add to NodeStore after database creates the ID + s.nodeStore.PutNode(*savedNode) } - // Check if policy manager needs updating - // This is necessary because we just created a new node. - // We need to ensure that the policy manager is aware of this new node. - // Also update users to ensure all users are known when evaluating policies. - usersChanged, err := s.updatePolicyManagerUsers() + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err) + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) } - nodesChanged, err := s.updatePolicyManagerNodes() + nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err) + return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) } - policyChanged := usersChanged || nodesChanged + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(savedNode.ID) + } - c := change.NodeAdded(node.ID) - - return node, c, policyChanged, nil + return savedNode.View(), c, nil } // AllocateNextIPs allocates the next available IPv4 and IPv6 addresses. @@ -865,22 +1429,26 @@ func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for users. // updatePolicyManagerUsers refreshes the policy manager with current user data. 
-func (s *State) updatePolicyManagerUsers() (bool, error) { +func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { users, err := s.ListAllUsers() if err != nil { - return false, fmt.Errorf("listing users for policy update: %w", err) + return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) } log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users") changed, err := s.polMan.SetUsers(users) if err != nil { - return false, fmt.Errorf("updating policy manager users: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err) } - log.Debug().Bool("changed", changed).Msg("Policy manager users updated") + log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished") - return changed, nil + if changed { + return change.PolicyChange(), nil + } + + return change.EmptySet, nil } // updatePolicyManagerNodes updates the policy manager with current nodes. @@ -889,18 +1457,19 @@ func (s *State) updatePolicyManagerUsers() (bool, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for nodes. // updatePolicyManagerNodes refreshes the policy manager with current node data. -func (s *State) updatePolicyManagerNodes() (bool, error) { - nodes, err := s.ListNodes() +func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { + nodes := s.ListNodes() + + changed, err := s.polMan.SetNodes(nodes) if err != nil { - return false, fmt.Errorf("listing nodes for policy update: %w", err) + return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err) } - changed, err := s.polMan.SetNodes(nodes.ViewSlice()) - if err != nil { - return false, fmt.Errorf("updating policy manager nodes: %w", err) + if changed { + return change.PolicyChange(), nil } - return changed, nil + return change.EmptySet, nil } // PingDB checks if the database connection is healthy. @@ -914,147 +1483,235 @@ func (s *State) PingDB(ctx context.Context) error { // TODO(kradalby): This is kind of messy, maybe this is another +1 // for an event bus. See example comments here. // autoApproveNodes automatically approves nodes based on policy rules. -func (s *State) autoApproveNodes() error { - err := s.db.Write(func(tx *gorm.DB) error { - nodes, err := hsdb.ListNodes(tx) - if err != nil { - return err - } +func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { + nodes := s.ListNodes() - for _, node := range nodes { - // TODO(kradalby): This change should probably be sent to the rest of the system. - changed := policy.AutoApproveRoutes(s.polMan, node) + // Approve routes concurrently, this should make it likely + // that the writes end in the same batch in the nodestore write. + var errg errgroup.Group + var cs []change.ChangeSet + var mu sync.Mutex + for _, nv := range nodes.All() { + errg.Go(func() error { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) if changed { - err = tx.Save(node).Error + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())). + Strs("routes.approved.new", util.PrefixesToString(approved)). 
+ Msg("Routes auto-approved by policy") + + _, c, err := s.SetApprovedRoutes(nv.ID(), approved) if err != nil { return err } - // TODO(kradalby): This should probably be done outside of the transaction, - // and the result of this should be propagated to the system. - s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) + mu.Lock() + cs = append(cs, c) + mu.Unlock() } - } - return nil - }) - if err != nil { - return fmt.Errorf("auto approving routes for nodes: %w", err) + return nil + }) } - return nil + err := errg.Wait() + if err != nil { + return nil, err + } + + return cs, nil } -// TODO(kradalby): This should just take the node ID? -func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) { - // TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, - // which means we could shortcut the whole change thing if there are no other important updates. - peerChange := node.PeerChangeFromMapRequest(req) +// UpdateNodeFromMapRequest processes a MapRequest and updates the node. +// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, +// which means we could shortcut the whole change thing if there are no other important updates. +// When a field is added to this function, remember to also add it to: +// - node.PeerChangeFromMapRequest +// - node.ApplyPeerChange +// - logTracePeerChange in poll.go. +func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) { + var routeChange bool + var hostinfoChanged bool + var needsRouteApproval bool + // We need to ensure we update the node as it is in the NodeStore at + // the time of the request. + s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + peerChange := currentNode.PeerChangeFromMapRequest(req) + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) - node.ApplyPeerChange(&peerChange) + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(peerChange) && !hostinfoChanged { + return + } - sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo) + // Calculate route approval before NodeStore update to avoid calling View() inside callback + var autoApprovedRoutes []netip.Prefix + hasNewRoutes := req.Hostinfo != nil && len(req.Hostinfo.RoutableIPs) > 0 + needsRouteApproval = hostinfoChanged && (routesChanged(currentNode.View(), req.Hostinfo) || (hasNewRoutes && len(currentNode.ApprovedRoutes) == 0)) + if needsRouteApproval { + autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( + s.polMan, + currentNode.View(), + // We need to preserve currently approved routes to ensure + // routes outside of the policy approver is persisted. + currentNode.ApprovedRoutes, + // However, the node has updated its routable IPs, so we + // need to approve them using that as a context. + req.Hostinfo.RoutableIPs, + ) + } - // The node might not set NetInfo if it has not changed and if - // the full HostInfo object is overwritten, the information is lost. - // If there is no NetInfo, keep the previous one. - // From 1.66 the client only sends it if changed: - // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 - // TODO(kradalby): evaluate if we need better comparing of hostinfo - // before we take the changes. 
- if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil { - req.Hostinfo.NetInfo = node.Hostinfo.NetInfo - } - node.Hostinfo = req.Hostinfo + // Log when routes change but approval doesn't + if hostinfoChanged && req.Hostinfo != nil && routesChanged(currentNode.View(), req.Hostinfo) && !routeChange { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). + Strs("newAnnouncedRoutes", util.PrefixesToString(req.Hostinfo.RoutableIPs)). + Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Bool("routeChange", routeChange). + Msg("announced routes changed but approved routes did not") + } - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(peerChange) && !sendUpdate { - // mapResponseEndpointUpdates.WithLabelValues("noop").Inc() - return change.EmptySet, nil + currentNode.ApplyPeerChange(&peerChange) + + if hostinfoChanged { + // The node might not set NetInfo if it has not changed and if + // the full HostInfo object is overwritten, the information is lost. + // If there is no NetInfo, keep the previous one. + // From 1.66 the client only sends it if changed: + // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 + // TODO(kradalby): evaluate if we need better comparing of hostinfo + // before we take the changes. + // Preserve NetInfo only if the existing node actually has valid NetInfo + // This prevents copying nil NetInfo which would lose DERP relay assignments + if req.Hostinfo != nil && req.Hostinfo.NetInfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). + Msg("preserving NetInfo from previous Hostinfo in MapRequest") + req.Hostinfo.NetInfo = currentNode.Hostinfo.NetInfo + } else if req.Hostinfo == nil && currentNode.Hostinfo != nil && currentNode.Hostinfo.NetInfo != nil { + // When MapRequest has no Hostinfo but we have existing NetInfo, create a minimal + // Hostinfo to preserve the NetInfo to maintain DERP connectivity + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Int("preferredDERP", currentNode.Hostinfo.NetInfo.PreferredDERP). + Msg("creating minimal Hostinfo to preserve NetInfo in MapRequest") + req.Hostinfo = &tailcfg.Hostinfo{ + NetInfo: currentNode.Hostinfo.NetInfo, + } + } + currentNode.Hostinfo = req.Hostinfo + currentNode.ApplyHostnameFromHostInfo(req.Hostinfo) + + if routeChange { + // Apply pre-calculated route approval + // Always apply the route approval result to ensure consistency, + // regardless of whether the policy evaluation detected changes. + // This fixes the bug where routes weren't properly cleared when + // auto-approvers were removed from the policy. + log.Info(). + Uint64("node.id", id.Uint64()). + Strs("oldApprovedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Strs("newApprovedRoutes", util.PrefixesToString(autoApprovedRoutes)). + Bool("routeChanged", routeChange). + Msg("applying route approval results") + currentNode.ApprovedRoutes = autoApprovedRoutes + } + } + }) + + nodeRouteChange := change.EmptySet + + // Handle route changes after NodeStore update + // We need to update node routes if either: + // 1. The approved routes changed (routeChange is true), OR + // 2. 
The announced routes changed (even if approved routes stayed the same) + // This is because SubnetRoutes is the intersection of announced AND approved routes. + needsRouteUpdate := false + routesChangedButNotApproved := hostinfoChanged && req.Hostinfo != nil && needsRouteApproval && !routeChange + if routeChange { + needsRouteUpdate = true + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because approved routes changed") + } else if routesChangedButNotApproved { + needsRouteUpdate = true + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because announced routes changed but approved routes did not") } - c := change.EmptySet + if needsRouteUpdate { + // Get the updated node to access its subnet routes + updatedNode, exists := s.GetNodeByID(id) + if !exists { + return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id) + } - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change to - // the routable IPs of the host and update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if routesChanged { - // Auto approve any routes that have been defined in policy as - // auto approved. Check if this actually changed the node. - _ = s.AutoApproveRoutes(node) - - // Update the routes of the given node in the route manager to - // see if an update needs to be sent. - c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...) + // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() + // which returns only the intersection of announced AND approved routes. + // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())). + Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())). + Strs("subnetRoutes", util.PrefixesToString(updatedNode.SubnetRoutes())). + Msg("updating node routes for distribution") + nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) } - // Check if there has been a change to Hostname and update them - // in the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the hostname change. - node.ApplyHostnameFromHostInfo(req.Hostinfo) - - _, policyChange, err := s.SaveNode(node) + _, policyChange, err := s.persistNodeToDB(id) if err != nil { - return change.EmptySet, err + return change.EmptySet, fmt.Errorf("saving to database: %w", err) } if policyChange.IsFull() { - c = policyChange + return policyChange, nil + } + if !nodeRouteChange.Empty() { + return nodeRouteChange, nil } - if c.Empty() { - c = change.NodeAdded(node.ID) - } - - return c, nil + return change.NodeAdded(id), nil } -// hostInfoChanged reports if hostInfo has changed in two ways, -// - first bool reports if an update needs to be sent to nodes -// - second reports if there has been changes to routes -// the caller can then use this info to save and update nodes -// and routes as needed. 
-func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { - if old.Equal(new) { - return false, false +func hostinfoEqual(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + if !oldNode.Valid() && new == nil { + return true + } + if !oldNode.Valid() || new == nil { + return false + } + old := oldNode.AsStruct().Hostinfo + + return old.Equal(new) +} + +func routesChanged(oldNode types.NodeView, new *tailcfg.Hostinfo) bool { + var oldRoutes []netip.Prefix + if oldNode.Valid() && oldNode.AsStruct().Hostinfo != nil { + oldRoutes = oldNode.AsStruct().Hostinfo.RoutableIPs } - if old == nil && new != nil { - return true, true - } - - // Routes - oldRoutes := make([]netip.Prefix, 0) - if old != nil { - oldRoutes = old.RoutableIPs - } newRoutes := new.RoutableIPs + if newRoutes == nil { + newRoutes = []netip.Prefix{} + } tsaddr.SortPrefixes(oldRoutes) tsaddr.SortPrefixes(newRoutes) - if !xslices.Equal(oldRoutes, newRoutes) { - return true, true - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to each other. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. - if len(old.Services) != len(new.Services) { - return true, false - } - - return false, false + return !slices.Equal(oldRoutes, newRoutes) } func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 3301cb35..e38a98f6 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool { case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate: return true } + return false } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 81a2a86a..959572a2 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -13,6 +13,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/tsaddr" @@ -355,6 +356,7 @@ func (node *Node) Proto() *v1.Node { GivenName: node.GivenName, User: node.User.Proto(), ForcedTags: node.ForcedTags, + Online: node.IsOnline != nil && *node.IsOnline, // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has // to be populated manually with PrimaryRoute, to ensure it includes the @@ -419,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix { } // SubnetRoutes returns the list of routes that the node announces and are approved. +// +// IMPORTANT: This method is used for internal data structures and should NOT be used +// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually +// with PrimaryRoutes to ensure it includes only routes actively served by the node. +// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto. func (node *Node) SubnetRoutes() []netip.Prefix { var routes []netip.Prefix @@ -511,11 +518,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } if node.Hostname != hostInfo.Hostname { + log.Trace(). + Str("node.id", node.ID.String()). + Str("old_hostname", node.Hostname). + Str("new_hostname", hostInfo.Hostname). + Str("old_given_name", node.GivenName). 
+ Bool("given_name_changed", node.GivenNameHasBeenChanged()). + Msg("Updating hostname from hostinfo") + if node.GivenNameHasBeenChanged() { node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) } node.Hostname = hostInfo.Hostname + + log.Trace(). + Str("node.id", node.ID.String()). + Str("new_hostname", node.Hostname). + Str("new_given_name", node.GivenName). + Msg("Hostname updated") } } @@ -759,6 +780,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix { return v.ж.ExitRoutes() } +// RequestTags returns the ACL tags that the node is requesting. +func (v NodeView) RequestTags() []string { + if !v.Valid() || !v.Hostinfo().Valid() { + return []string{} + } + return v.Hostinfo().RequestTags().AsSlice() +} + +// Proto converts the NodeView to a protobuf representation. +func (v NodeView) Proto() *v1.Node { + if !v.Valid() { + return nil + } + return v.ж.Proto() +} + // HasIP reports if a node has a given IP address. func (v NodeView) HasIP(i netip.Addr) bool { if !v.Valid() { diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index d7bc7897..97bb3da4 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Parse each hop line - hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?") + hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) for i := 1; i < len(lines); i++ { matches := hopRegex.FindStringSubmatch(lines[i]) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index d118b643..394d219b 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.NoError(ct, err) assert.Equal(ct, "NeedsLogin", status.BackendState) } + assertTailscaleNodesLogout(t, allClients) }, shortAccessTTL+10*time.Second, 5*time.Second) } diff --git a/integration/general_test.go b/integration/general_test.go index 4bf36567..9da61958 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -547,6 +547,8 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second var nodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { err := executeAndUnmarshal( @@ -642,27 +644,34 @@ func TestUpdateHostnameFromClient(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "node", - "list", - "--output", - "json", - }, - &nodes, - ) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second + assert.Eventually(t, func() bool { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) - assertNoErr(t, err) - assert.Len(t, nodes, 3) + if err != nil || len(nodes) != 3 { + return false + } - for _, node := range nodes { - hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] - givenName := fmt.Sprintf("%d-givenname", node.GetId()) - assert.Equal(t, hostname+"NEW", node.GetName()) - assert.Equal(t, givenName, node.GetGivenName()) - } + for _, node := range nodes { + hostname := 
hostnames[strconv.FormatUint(node.GetId(), 10)] + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName { + return false + } + } + return true + }, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix") } func TestExpireNode(t *testing.T) { diff --git a/integration/route_test.go b/integration/route_test.go index 7243d3f2..bb13a47f 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, node.GetSubnetRoutes(), 1) } - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - // Verify that the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - assert.NotNil(t, peerStatus.PrimaryRoutes) - - assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3) - requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + assert.NotNil(c, peerStatus.PrimaryRoutes) + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + } } - } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") _, err = headscale.ApproveRoutes( 1, @@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) { ) require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - - for _, node := range nodes { - if node.GetId() == 1 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetSubnetRoutes()) - } else if node.GetId() == 2 { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 - assert.Empty(t, node.GetApprovedRoutes()) - assert.Empty(t, node.GetSubnetRoutes()) - } else { - assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 - assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + for _, node := range nodes { + if node.GetId() == 1 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetSubnetRoutes()) + } else if node.GetId() == 2 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetApprovedRoutes()) + assert.Empty(c, node.GetSubnetRoutes()) + } else { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + } } - } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the clients can see the new routes for _, client := range allClients { @@ -283,15 +284,17 @@ func 
TestHASubnetRouterFailover(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - time.Sleep(3 * time.Second) + // Wait for route configuration changes after advertising routes + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err := headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 0, 0) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved") // Verify that no routes has been sent to the client, // they are not yet enabled. @@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) { ) require.NoError(t, err) - time.Sleep(3 * time.Second) + // Wait for route approval on first subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) - nodes, err = headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, nodes, 6) - - requireNodeRouteCount(t, nodes[0], 1, 1, 1) - requireNodeRouteCount(t, nodes[1], 1, 0, 0) - requireNodeRouteCount(t, nodes[2], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route") // Verify that the client has routes from the primary machine and can access // the webservice. 
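The route_test.go hunks above and below all follow the same shape: a fixed time.Sleep is replaced with testify's assert.EventuallyWithT, which re-runs a block of assertions against a CollectT until they all pass or a deadline expires. A minimal, self-contained sketch of that polling pattern follows; the waitForRouteState helper and its timings are illustrative assumptions, not part of this diff.

package integration

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// waitForRouteState retries the given assertions until they all pass or the
// timeout expires, instead of sleeping for a fixed duration. Assertions
// inside the callback must report against the *assert.CollectT, not *testing.T.
func waitForRouteState(t *testing.T, check func(c *assert.CollectT)) {
	t.Helper()
	// NodeStore batching is on the order of 500ms, so poll frequently and
	// allow generous time for state to converge.
	assert.EventuallyWithT(t, check, 10*time.Second, 200*time.Millisecond,
		"expected state should converge before the timeout")
}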
@@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
     )
     require.NoError(t, err)
-    time.Sleep(3 * time.Second)
+    // Wait for route approval on second subnet router
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 6)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 6)
-
-    requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-    requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-    requireNodeRouteCount(t, nodes[2], 1, 0, 0)
+        requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
+    }, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
     // Verify that the client has routes from the primary machine
     srs1 = subRouter1.MustStatus()
@@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
     )
     require.NoError(t, err)
-    time.Sleep(3 * time.Second)
+    // Wait for route approval on third subnet router
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 6)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 6)
-
-    requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-    requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-    requireNodeRouteCount(t, nodes[2], 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
+    }, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
     // Verify that the client has routes from the primary machine
     srs1 = subRouter1.MustStatus()
@@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
     require.NoError(t, err)
     assert.Len(t, result, 13)
-    tr, err = client.Traceroute(webip)
-    require.NoError(t, err)
-    assertTracerouteViaIP(t, tr, subRouter1.MustIPv4())
+    // Wait for traceroute to work correctly through the expected router
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        tr, err := client.Traceroute(webip)
+        assert.NoError(c, err)
+
+        // Get the expected router IP - use a more robust approach to handle temporary disconnections
+        ips, err := subRouter1.IPs()
+        assert.NoError(c, err)
+        assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
+
+        var expectedIP netip.Addr
+        for _, ip := range ips {
+            if ip.Is4() {
+                expectedIP = ip
+                break
+            }
+        }
+        assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
+
+        assertTracerouteViaIPWithCollect(c, tr, expectedIP)
+    }, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
     // Take down the current primary
     t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
@@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
     err = subRouter1.Down()
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for router status changes after r1 goes down
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        srs2 = subRouter2.MustStatus()
+        clientStatus = client.MustStatus()
-    srs2 = subRouter2.MustStatus()
-    clientStatus = client.MustStatus()
+        srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+        srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+        srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-    srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-    srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-    srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-    assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
-    assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
-    assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
+        assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
+        assert.True(c, srs2PeerStatus.Online, "r2 should be online")
+        assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+    }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
     assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
     require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
@@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
     err = subRouter2.Down()
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for router status changes after r2 goes down
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        clientStatus, err = client.Status()
+        assert.NoError(c, err)
-    // TODO(kradalby): Check client status
-    // Both are expected to be down
+        srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+        srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+        srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-    // Verify that the route is not presented from either router
-    clientStatus, err = client.Status()
-    require.NoError(t, err)
-
-    srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-    srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-    srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-    assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
-    assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
-    assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
+        assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
+        assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
+        assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+    }, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
     assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
     assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
     err = subRouter1.Up()
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for router status changes after r1 comes back up
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        clientStatus, err = client.Status()
+        assert.NoError(c, err)
-    // Verify that the route is announced from subnet router 1
-    clientStatus, err = client.Status()
-    require.NoError(t, err)
+        srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+        srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+        srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-    srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-    srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-    srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-    assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
-    assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
-    assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
+        assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
+        assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
+        assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
+    }, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
     assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
     assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
     err = subRouter2.Up()
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for nodestore batch processing to complete and online status to be updated
+    // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        clientStatus, err = client.Status()
+        assert.NoError(c, err)
-    // Verify that the route is announced from subnet router 1
-    clientStatus, err = client.Status()
-    require.NoError(t, err)
+        srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+        srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+        srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-    srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
-    srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
-    srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
-
-    assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
-    assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
-    assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
+        assert.True(c, srs1PeerStatus.Online, "r1 should be online")
+        assert.True(c, srs2PeerStatus.Online, "r2 should be online")
+        assert.True(c, srs3PeerStatus.Online, "r3 should be online")
+    }, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
     assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
     assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
     t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
     _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
-    time.Sleep(5 * time.Second)
+    // Wait for nodestore batch processing and route state changes to complete
+    // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 6)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 6)
-
-    requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
-    requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
-    requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+        // After disabling route on r3, r1 should become primary with 1 subnet route
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+    }, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
     // Verify that the route is announced from subnet router 1
     clientStatus, err = client.Status()
@@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
     t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
     _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
-    time.Sleep(5 * time.Second)
+    // Wait for nodestore batch processing and route state changes to complete
+    // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 6)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 6)
-
-    requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
-    requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
-    requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+        // After disabling route on r1, r2 should become primary with 1 subnet route
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+    }, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
     // Verify that the route is announced from subnet router 1
     clientStatus, err = client.Status()
@@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
         util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
     )
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes after re-enabling r1
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 6)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 6)
-
-    requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
-    requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
-    requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
+    }, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
    // Verify that the route is announced from subnet router 1
    clientStatus, err = client.Status()
@@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
     )
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes to propagate to nodes
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 2)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    require.Len(t, nodes, 2)
-
-    requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-    requireNodeRouteCount(t, nodes[1], 0, 0, 0)
+        requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
     // Verify that the client has routes from the primary machine
     srs1, _ := subRouter1.Status()
@@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
     requireNodeRouteCount(t, nodes[0], 2, 2, 2)
     requireNodeRouteCount(t, nodes[1], 2, 2, 2)
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes to propagate to clients
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // Verify that the clients can see the new routes
+        for _, client := range allClients {
+            status, err := client.Status()
+            assert.NoError(c, err)
-    // Verify that the clients can see the new routes
-    for _, client := range allClients {
-        status, err := client.Status()
-        assertNoErr(t, err)
+            for _, peerKey := range status.Peers() {
+                peerStatus := status.Peer[peerKey]
-        for _, peerKey := range status.Peers() {
-            peerStatus := status.Peer[peerKey]
-
-            require.NotNil(t, peerStatus.AllowedIPs)
-            assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4)
-            assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
-            assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
+                assert.NotNil(c, peerStatus.AllowedIPs)
+                assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
+                assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
+                assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
+            }
         }
-    }
+    }, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
 }

 // TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
@@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
     )
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes to propagate to nodes and clients
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 2)
+        requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 2)
-    requireNodeRouteCount(t, nodes[0], 1, 1, 1)
+        // Verify that the routes have been sent to the client
+        status, err = user2c.Status()
+        assert.NoError(c, err)
-    // Verify that the routes have been sent to the client.
-    status, err = user2c.Status()
-    require.NoError(t, err)
+        for _, peerKey := range status.Peers() {
+            peerStatus := status.Peer[peerKey]
-    for _, peerKey := range status.Peers() {
-        peerStatus := status.Peer[peerKey]
-
-        assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
-        requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
-    }
+            assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
+            requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
+        }
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
     usernet1, err := scenario.Network("usernet1")
     require.NoError(t, err)
@@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
     _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes to propagate to nodes and clients
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 2)
+        requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    assert.Len(t, nodes, 2)
-    requireNodeRouteCount(t, nodes[0], 2, 2, 2)
+        // Verify that the routes have been sent to the client
+        status, err = user2c.Status()
+        assert.NoError(c, err)
-    // Verify that the routes have been sent to the client.
-    status, err = user2c.Status()
-    require.NoError(t, err)
+        for _, peerKey := range status.Peers() {
+            peerStatus := status.Peer[peerKey]
-    for _, peerKey := range status.Peers() {
-        peerStatus := status.Peer[peerKey]
-
-        requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
-    }
+            requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
+        }
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
     // Tell user2c to use user1c as an exit node.
     command = []string{
@@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     require.NoErrorf(t, err, "failed to create scenario: %s", err)
     defer scenario.ShutdownAssertNoPanics(t)
+    var nodes []*v1.Node
     opts := []hsic.Option{
         hsic.WithTestName("autoapprovemulti"),
         hsic.WithEmbeddedDERPServerOnly(),
@@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
         require.NoErrorf(t, err, "failed to advertise route: %s", err)
     }
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err := headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err := headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Verify that the routes have been sent to the client.
     status, err := client.Status()
@@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     err = headscale.SetPolicy(tt.pol)
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Verify that the routes have been sent to the client.
     status, err = client.Status()
@@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     )
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Verify that the routes have been sent to the client.
     status, err = client.Status()
@@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     err = headscale.SetPolicy(tt.pol)
     require.NoError(t, err)
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Verify that the routes have been sent to the client.
     status, err = client.Status()
@@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     _, _, err = routerSubRoute.Execute(command)
     require.NoErrorf(t, err, "failed to advertise route: %s", err)
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     requireNodeRouteCount(t, nodes[1], 1, 1, 1)
     // Verify that the routes have been sent to the client.
@@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     _, _, err = routerSubRoute.Execute(command)
     require.NoErrorf(t, err, "failed to advertise route: %s", err)
-    time.Sleep(5 * time.Second)
-
-    // These route should auto approve, so the node is expected to have a route
-    // for all counts.
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // These route should auto approve, so the node is expected to have a route
+        // for all counts.
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     requireNodeRouteCount(t, nodes[1], 1, 1, 0)
     requireNodeRouteCount(t, nodes[2], 0, 0, 0)
@@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
     _, _, err = routerExitNode.Execute(command)
     require.NoErrorf(t, err, "failed to advertise route: %s", err)
-    time.Sleep(5 * time.Second)
-
-    nodes, err = headscale.ListNodes()
-    require.NoError(t, err)
-    requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
-    requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-    requireNodeRouteCount(t, nodes[2], 2, 2, 2)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        nodes, err = headscale.ListNodes()
+        assert.NoError(c, err)
+        requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+        requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+        requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Verify that the routes have been sent to the client.
     status, err = client.Status()
@@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
     require.Equal(t, tr.Route[0].IP, ip)
 }

+// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
+func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
+    assert.NotNil(c, tr)
+    assert.True(c, tr.Success)
+    assert.NoError(c, tr.Err)
+    assert.NotEmpty(c, tr.Route)
+    assert.Equal(c, tr.Route[0].IP, ip)
+}
+
 // requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
 func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
     t.Helper()
@@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
     }
 }

+func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
+    if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
+        assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
+        return
+    }
+
+    if len(expected) == 0 {
+        expected = []netip.Prefix{}
+    }
+
+    got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
+        if tsaddr.IsExitRoute(p) {
+            return true
+        }
+        return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
+    })
+
+    if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
+        assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
+    }
+}
+
 func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
     t.Helper()
     require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
@@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
     require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
 }

+func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
+    assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
+    assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
+    assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
+}
+
 // TestSubnetRouteACLFiltering tests that a node can only access subnet routes
 // that are explicitly allowed in the ACL.
 func TestSubnetRouteACLFiltering(t *testing.T) {
@@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
     )
     require.NoError(t, err)
-    // Give some time for the routes to propagate
-    time.Sleep(5 * time.Second)
+    // Wait for route state changes to propagate
+    assert.EventuallyWithT(t, func(c *assert.CollectT) {
+        // List nodes and verify the router has 3 available routes
+        nodes, err = headscale.NodesByUser()
+        assert.NoError(c, err)
+        assert.Len(c, nodes, 2)
-    // List nodes and verify the router has 3 available routes
-    nodes, err = headscale.NodesByUser()
-    require.NoError(t, err)
-    require.Len(t, nodes, 2)
+        // Find the router node
+        routerNode = nodes[routerUser][0]
-    // Find the router node
-    routerNode = nodes[routerUser][0]
-
-    // Check that the router has 3 routes now approved and available
-    requireNodeRouteCount(t, routerNode, 3, 3, 3)
+        // Check that the router has 3 routes now approved and available
+        requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
+    }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
     // Now check the client node status
     nodeStatus, err := nodeClient.Status()
diff --git a/integration/scenario.go b/integration/scenario.go
index 817d927b..8ce54b89 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -14,7 +14,6 @@ import (
     "net/netip"
     "net/url"
     "os"
-    "sort"
     "strconv"
     "strings"
     "sync"
@@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
         return nil, fmt.Errorf("no network named: %s", name)
     }

-    for _, ipam := range net.Network.IPAM.Config {
-        pref, err := netip.ParsePrefix(ipam.Subnet)
-        if err != nil {
-            return nil, err
-        }
-
-        return &pref, nil
+    if len(net.Network.IPAM.Config) == 0 {
+        return nil, fmt.Errorf("no IPAM config found in network: %s", name)
     }

-    return nil, fmt.Errorf("no prefix found in network: %s", name)
+    pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
+    if err != nil {
+        return nil, err
+    }
+
+    return &pref, nil
 }

 func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
@@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
         return err
     }

-    sort.Strings(s.spec.Users)
     for _, user := range s.spec.Users {
         u, err := s.CreateUser(user)
         if err != nil {
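For reference, every hunk above replaces a fixed time.Sleep with testify's assert.EventuallyWithT, which re-runs a closure against an *assert.CollectT until all collected assertions pass or the timeout expires. The following minimal, self-contained sketch (not part of the patch) shows that shape in isolation; the listNodes helper and the test name are hypothetical stand-ins for the headscale.ListNodes calls used in the integration tests, and the 10s/500ms values simply mirror the timeouts chosen in the diff.

package example

import (
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
)

// listNodes is a hypothetical stand-in for any call whose result converges
// over time (headscale.ListNodes in the tests above).
func listNodes() ([]string, error) {
    return []string{"r1", "r2"}, nil
}

// TestEventuallyPattern mirrors the pattern applied throughout this diff:
// the checks run inside the closure against a *assert.CollectT, so a
// failing tick is collected and retried on the next tick instead of
// aborting the test immediately, and the test only fails if the closure
// never passes within the overall timeout.
func TestEventuallyPattern(t *testing.T) {
    assert.EventuallyWithT(t, func(c *assert.CollectT) {
        nodes, err := listNodes()
        assert.NoError(c, err) // collected, not fatal
        assert.Len(c, nodes, 2) // re-evaluated on every tick
    }, 10*time.Second, 500*time.Millisecond, "node list should converge")
}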