1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

mapper: send change instead of full update (#2775)

This commit is contained in:
Kristoffer Dalby 2025-09-17 14:23:21 +02:00 committed by GitHub
parent 4de56c40d8
commit ed3a9c8d6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 96 additions and 69 deletions

View File

@ -13,11 +13,9 @@ import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
) )
@ -494,41 +492,6 @@ func EnsureUniqueGivenName(
return givenName, nil return givenName, nil
} }
// ExpireExpiredNodes checks for nodes that have expired since the last check
// and returns a time to be used for the next check, a list of ChangeSets
// describing the expired nodes, and a boolean indicating if any nodes were found.
func ExpireExpiredNodes(tx *gorm.DB,
	lastCheck time.Time,
) (time.Time, []change.ChangeSet, bool) {
	// use the time of the start of the function to ensure we
	// dont miss some nodes by returning it _after_ we have
	// checked everything.
	started := time.Now()

	expired := make([]*tailcfg.PeerChange, 0)

	var updates []change.ChangeSet

	nodes, err := ListNodes(tx)
	if err != nil {
		// On a listing failure, report "nothing found" with the zero time so the
		// caller does not advance its lastCheck watermark past unprocessed nodes.
		return time.Unix(0, 0), nil, false
	}
	for _, node := range nodes {
		// Only report nodes that expired since the last check, to avoid
		// emitting duplicate notifications for the same expiry.
		// NOTE(review): node.Expiry is dereferenced here; this assumes
		// IsExpired() guarantees a non-nil Expiry — confirm.
		if node.IsExpired() && node.Expiry.After(lastCheck) {
			expired = append(expired, &tailcfg.PeerChange{
				NodeID:    tailcfg.NodeID(node.ID),
				KeyExpiry: node.Expiry,
			})

			// NOTE(review): `expired` is only consulted for the length check
			// below; `updates` carries the change sets actually returned.
			updates = append(updates, change.KeyExpiry(node.ID))
		}
	}

	if len(expired) > 0 {
		return started, updates, true
	}

	return started, nil, false
}
// EphemeralGarbageCollector is a garbage collector that will delete nodes after // EphemeralGarbageCollector is a garbage collector that will delete nodes after
// a certain amount of time. // a certain amount of time.
// It is used to delete ephemeral nodes that have disconnected and should be // It is used to delete ephemeral nodes that have disconnected and should be

View File

@ -88,16 +88,9 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router. // TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
mapResp, err = mapper.fullMapResponse(nodeID, version) mapResp, err = mapper.fullMapResponse(nodeID, version)
} else { } else {
// CRITICAL FIX: Read actual online status from NodeStore when available, // Trust the change type for online/offline status to avoid race conditions
// fall back to deriving from change type for unit tests or when NodeStore is empty // between NodeStore updates and change processing
var onlineStatus bool onlineStatus := c.Change == change.NodeCameOnline
if node, found := mapper.state.GetNodeByID(c.NodeID); found && node.IsOnline().Valid() {
// Use actual NodeStore status when available (production case)
onlineStatus = node.IsOnline().Get()
} else {
// Fall back to deriving from change type (unit test case or initial setup)
onlineStatus = c.Change == change.NodeCameOnline
}
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{ {
@ -108,11 +101,33 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
} }
case change.NodeNewOrUpdate: case change.NodeNewOrUpdate:
mapResp, err = mapper.fullMapResponse(nodeID, version) // If the node is the one being updated, we send a self update that preserves peer information
// to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes)
// without losing its view of peer status during rapid reconnection cycles
if c.IsSelfUpdate(nodeID) {
mapResp, err = mapper.selfMapResponse(nodeID, version)
} else {
mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID)
}
case change.NodeRemove: case change.NodeRemove:
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
case change.NodeKeyExpiry:
// If the node is the one whose key is expiring, we send a "full" self update
// as nodes will ignore patch updates about themselves (?).
if c.IsSelfUpdate(nodeID) {
mapResp, err = mapper.selfMapResponse(nodeID, version)
// mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
KeyExpiry: c.NodeExpiry,
},
})
}
default: default:
// The following will always hit this: // The following will always hit this:
// change.Full, change.Policy // change.Full, change.Policy

View File

@ -1028,7 +1028,9 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// Add multiple changes rapidly to test batching // Add multiple changes rapidly to test batching
batcher.AddWork(change.DERPSet) batcher.AddWork(change.DERPSet)
batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID)) // Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPSet) batcher.AddWork(change.DERPSet)
batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) batcher.AddWork(change.NodeAdded(testNodes[1].n.ID))
batcher.AddWork(change.DERPSet) batcher.AddWork(change.DERPSet)
@ -1278,7 +1280,9 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Add node-specific work occasionally // Add node-specific work occasionally
if i%10 == 0 { if i%10 == 0 {
batcher.AddWork(change.KeyExpiry(testNode.n.ID)) // Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry))
} }
// Rapid removal creates race between worker and removal // Rapid removal creates race between worker and removal
@ -1493,7 +1497,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
if i%7 == 0 && len(allNodes) > 0 { if i%7 == 0 && len(allNodes) > 0 {
// Node-specific changes using real nodes // Node-specific changes using real nodes
node := allNodes[i%len(allNodes)] node := allNodes[i%len(allNodes)]
batcher.AddWork(change.KeyExpiry(node.n.ID)) // Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry))
} }
// Small delay to allow some batching // Small delay to allow some batching

View File

@ -28,6 +28,7 @@ type debugType string
const ( const (
fullResponseDebug debugType = "full" fullResponseDebug debugType = "full"
selfResponseDebug debugType = "self"
patchResponseDebug debugType = "patch" patchResponseDebug debugType = "patch"
removeResponseDebug debugType = "remove" removeResponseDebug debugType = "remove"
changeResponseDebug debugType = "change" changeResponseDebug debugType = "change"
@ -68,24 +69,17 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
// WithSelfNode adds the requesting node to the response. // WithSelfNode adds the requesting node to the response.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID) nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok { if !ok {
b.addError(errors.New("node not found")) b.addError(errors.New("node not found"))
return b return b
} }
// Always use batcher's view of online status for self node
// The batcher respects grace periods for logout scenarios
node := nodeView.AsStruct()
// if b.mapper.batcher != nil {
// node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
// }
_, matchers := b.mapper.state.Filter() _, matchers := b.mapper.state.Filter()
tailnode, err := tailNode( tailnode, err := tailNode(
node.View(), b.capVer, b.mapper.state, nv, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix { func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers) return policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
}, },
b.mapper.cfg) b.mapper.cfg)
if err != nil { if err != nil {

View File

@ -158,6 +158,26 @@ func (m *mapper) fullMapResponse(
Build() Build()
} }
// selfMapResponse builds a MapResponse containing only the requesting
// node's own information. The Peers field is explicitly nil so the client
// does not interpret the response as a fresh (empty) peer list.
func (m *mapper) selfMapResponse(
	nodeID types.NodeID,
	capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
	builder := m.NewMapResponseBuilder(nodeID).
		WithDebugType(selfResponseDebug).
		WithCapabilityVersion(capVer).
		WithSelfNode()

	resp, err := builder.Build()
	if err != nil {
		return nil, err
	}

	// A nil Peers field means "no peer changes"; leaving a non-nil (even
	// empty) slice would make the node think it received a new peer list.
	resp.Peers = nil

	return resp, nil
}
func (m *mapper) derpMapResponse( func (m *mapper) derpMapResponse(
nodeID types.NodeID, nodeID types.NodeID,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
@ -190,7 +210,6 @@ func (m *mapper) peerChangeResponse(
return m.NewMapResponseBuilder(nodeID). return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug). WithDebugType(changeResponseDebug).
WithCapabilityVersion(capVer). WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers). WithUserProfiles(peers).
WithPeerChanges(peers). WithPeerChanges(peers).
Build() Build()

View File

@ -228,7 +228,23 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
defer pm.mu.Unlock() defer pm.mu.Unlock()
pm.users = users pm.users = users
return pm.updateLocked() // Clear SSH policy map when users change to force SSH policy recomputation
// This ensures that if SSH policy compilation previously failed due to missing users,
// it will be retried with the new user list
clear(pm.sshPolicyMap)
changed, err := pm.updateLocked()
if err != nil {
return false, err
}
// If SSH policies exist, force a policy change when users are updated
// This ensures nodes get updated SSH policies even if other policy hashes didn't change
if pm.pol != nil && pm.pol.SSHs != nil && len(pm.pol.SSHs) > 0 {
return true, nil
}
return changed, nil
} }
// SetNodes updates the nodes in the policy manager and updates the filter rules. // SetNodes updates the nodes in the policy manager and updates the filter rules.

View File

@ -650,7 +650,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
} }
if !c.IsFull() { if !c.IsFull() {
c = change.KeyExpiry(nodeID) c = change.KeyExpiry(nodeID, expiry)
} }
return n, c, nil return n, c, nil
@ -898,7 +898,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
// Why check After(lastCheck): We only want to notify about nodes that // Why check After(lastCheck): We only want to notify about nodes that
// expired since the last check to avoid duplicate notifications // expired since the last check to avoid duplicate notifications
if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) {
updates = append(updates, change.KeyExpiry(node.ID())) updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get()))
} }
} }
@ -1118,7 +1118,11 @@ func (s *State) HandleNodeFromAuthPath(
// Get updated node from NodeStore // Get updated node from NodeStore
updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID())
return updatedNode, change.KeyExpiry(existingNodeView.ID()), nil if expiry != nil {
return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil
}
return updatedNode, change.FullSet, nil
} }
// New node registration // New node registration

View File

@ -3,6 +3,7 @@ package change
import ( import (
"errors" "errors"
"time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
) )
@ -68,6 +69,9 @@ type ChangeSet struct {
// IsSubnetRouter indicates whether the node is a subnet router. // IsSubnetRouter indicates whether the node is a subnet router.
IsSubnetRouter bool IsSubnetRouter bool
// NodeExpiry is set if the change is NodeKeyExpiry.
NodeExpiry *time.Time
} }
func (c *ChangeSet) Validate() error { func (c *ChangeSet) Validate() error {
@ -126,6 +130,11 @@ func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
return ret return ret
} }
// IsSelfUpdate reports whether this ChangeSet targets the node identified
// by nodeID, i.e. the change concerns that node itself rather than a peer.
func (c ChangeSet) IsSelfUpdate(nodeID types.NodeID) bool {
	return nodeID == c.NodeID
}
func (c ChangeSet) AlsoSelf() bool { func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node, // If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes. // so we consider it as a change that should be sent to all nodes.
@ -179,10 +188,11 @@ func NodeOffline(id types.NodeID) ChangeSet {
} }
} }
func KeyExpiry(id types.NodeID) ChangeSet { func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet {
return ChangeSet{ return ChangeSet{
Change: NodeKeyExpiry, Change: NodeKeyExpiry,
NodeID: id, NodeID: id,
NodeExpiry: &expiry,
} }
} }