mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-22 00:11:47 +01:00
471 lines
12 KiB
Go
471 lines
12 KiB
Go
package notifier
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/puzpuzpuz/xsync/v3"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/sasha-s/go-deadlock"
|
|
"tailscale.com/envknob"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/util/set"
|
|
)
|
|
|
|
var (
|
|
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
|
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
|
)
|
|
|
|
func init() {
|
|
deadlock.Opts.Disable = !debugDeadlock
|
|
if debugDeadlock {
|
|
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
|
deadlock.Opts.PrintAllCurrentGoroutines = true
|
|
}
|
|
}
|
|
|
|
type Notifier struct {
|
|
l deadlock.Mutex
|
|
nodes map[types.NodeID]chan<- types.StateUpdate
|
|
connected *xsync.MapOf[types.NodeID, bool]
|
|
b *batcher
|
|
cfg *types.Config
|
|
closed bool
|
|
}
|
|
|
|
func NewNotifier(cfg *types.Config) *Notifier {
|
|
n := &Notifier{
|
|
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
|
connected: xsync.NewMapOf[types.NodeID, bool](),
|
|
cfg: cfg,
|
|
closed: false,
|
|
}
|
|
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
|
n.b = b
|
|
|
|
go b.doWork()
|
|
return n
|
|
}
|
|
|
|
// Close stops the batcher and closes all channels.
|
|
func (n *Notifier) Close() {
|
|
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
|
|
|
|
n.closed = true
|
|
n.b.close()
|
|
|
|
for _, c := range n.nodes {
|
|
close(c)
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
|
log.Trace().
|
|
Uint64("node.id", nID.Uint64()).
|
|
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
|
|
}
|
|
|
|
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
|
start := time.Now()
|
|
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
|
|
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
|
|
|
if n.closed {
|
|
return
|
|
}
|
|
|
|
// If a channel exists, it means the node has opened a new
|
|
// connection. Close the old channel and replace it.
|
|
if curr, ok := n.nodes[nodeID]; ok {
|
|
n.tracef(nodeID, "channel present, closing and replacing")
|
|
close(curr)
|
|
}
|
|
|
|
n.nodes[nodeID] = c
|
|
n.connected.Store(nodeID, true)
|
|
|
|
n.tracef(nodeID, "added new channel")
|
|
notifierNodeUpdateChans.Inc()
|
|
}
|
|
|
|
// RemoveNode removes a node and a given channel from the notifier.
|
|
// It checks that the channel is the same as currently being updated
|
|
// and ignores the removal if it is not.
|
|
// RemoveNode reports if the node/chan was removed.
|
|
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
|
|
start := time.Now()
|
|
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
|
|
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
|
|
|
if n.closed {
|
|
return true
|
|
}
|
|
|
|
if len(n.nodes) == 0 {
|
|
return true
|
|
}
|
|
|
|
// If the channel exist, but it does not belong
|
|
// to the caller, ignore.
|
|
if curr, ok := n.nodes[nodeID]; ok {
|
|
if curr != c {
|
|
n.tracef(nodeID, "channel has been replaced, not removing")
|
|
return false
|
|
}
|
|
}
|
|
|
|
delete(n.nodes, nodeID)
|
|
n.connected.Store(nodeID, false)
|
|
|
|
n.tracef(nodeID, "removed channel")
|
|
notifierNodeUpdateChans.Dec()
|
|
|
|
return true
|
|
}
|
|
|
|
// IsConnected reports if a node is connected to headscale and has a
|
|
// poll session open.
|
|
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
|
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
|
|
|
|
if val, ok := n.connected.Load(nodeID); ok {
|
|
return val
|
|
}
|
|
return false
|
|
}
|
|
|
|
// IsLikelyConnected reports if a node is connected to headscale and has a
|
|
// poll session open, but doesnt lock, so might be wrong.
|
|
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
|
if val, ok := n.connected.Load(nodeID); ok {
|
|
return val
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
|
return n.connected
|
|
}
|
|
|
|
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
|
n.NotifyWithIgnore(ctx, update)
|
|
}
|
|
|
|
func (n *Notifier) NotifyWithIgnore(
|
|
ctx context.Context,
|
|
update types.StateUpdate,
|
|
ignoreNodeIDs ...types.NodeID,
|
|
) {
|
|
if n.closed {
|
|
return
|
|
}
|
|
|
|
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
|
n.b.addOrPassthrough(update)
|
|
}
|
|
|
|
func (n *Notifier) NotifyByNodeID(
|
|
ctx context.Context,
|
|
update types.StateUpdate,
|
|
nodeID types.NodeID,
|
|
) {
|
|
start := time.Now()
|
|
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
|
|
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
|
|
|
if n.closed {
|
|
return
|
|
}
|
|
|
|
if c, ok := n.nodes[nodeID]; ok {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Error().
|
|
Err(ctx.Err()).
|
|
Uint64("node.id", nodeID.Uint64()).
|
|
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
|
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
|
Msgf("update not sent, context cancelled")
|
|
if debugHighCardinalityMetrics {
|
|
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
|
} else {
|
|
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
|
}
|
|
|
|
return
|
|
case c <- update:
|
|
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
|
|
if debugHighCardinalityMetrics {
|
|
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
|
} else {
|
|
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) sendAll(update types.StateUpdate) {
|
|
start := time.Now()
|
|
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
|
|
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
|
|
|
if n.closed {
|
|
return
|
|
}
|
|
|
|
for id, c := range n.nodes {
|
|
// Whenever an update is sent to all nodes, there is a chance that the node
|
|
// has disconnected and the goroutine that was supposed to consume the update
|
|
// has shut down the channel and is waiting for the lock held here in RemoveNode.
|
|
// This means that there is potential for a deadlock which would stop all updates
|
|
// going out to clients. This timeout prevents that from happening by moving on to the
|
|
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
|
|
// call will succeed and the update will go to the correct nodes on the next call.
|
|
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
|
|
defer cancel()
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Error().
|
|
Err(ctx.Err()).
|
|
Uint64("node.id", id.Uint64()).
|
|
Msgf("update not sent, context cancelled")
|
|
if debugHighCardinalityMetrics {
|
|
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
|
|
} else {
|
|
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
|
|
}
|
|
|
|
return
|
|
case c <- update:
|
|
if debugHighCardinalityMetrics {
|
|
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
|
|
} else {
|
|
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) String() string {
|
|
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
|
|
|
|
var b strings.Builder
|
|
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
|
|
|
|
var keys []types.NodeID
|
|
n.connected.Range(func(key types.NodeID, value bool) bool {
|
|
keys = append(keys, key)
|
|
return true
|
|
})
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
return keys[i] < keys[j]
|
|
})
|
|
|
|
for _, key := range keys {
|
|
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
|
|
}
|
|
|
|
b.WriteString("\n")
|
|
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
|
|
|
|
for _, key := range keys {
|
|
val, _ := n.connected.Load(key)
|
|
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
|
|
}
|
|
|
|
return b.String()
|
|
}
|
|
|
|
type batcher struct {
|
|
tick *time.Ticker
|
|
|
|
mu sync.Mutex
|
|
|
|
cancelCh chan struct{}
|
|
|
|
changedNodeIDs set.Slice[types.NodeID]
|
|
nodesChanged bool
|
|
patches map[types.NodeID]tailcfg.PeerChange
|
|
patchesChanged bool
|
|
|
|
n *Notifier
|
|
}
|
|
|
|
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
|
return &batcher{
|
|
tick: time.NewTicker(batchTime),
|
|
cancelCh: make(chan struct{}),
|
|
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
|
n: n,
|
|
}
|
|
}
|
|
|
|
func (b *batcher) close() {
|
|
b.cancelCh <- struct{}{}
|
|
}
|
|
|
|
// addOrPassthrough adds the update to the batcher, if it is not a
|
|
// type that is currently batched, it will be sent immediately.
|
|
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
|
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()
|
|
|
|
switch update.Type {
|
|
case types.StatePeerChanged:
|
|
b.changedNodeIDs.Add(update.ChangeNodes...)
|
|
b.nodesChanged = true
|
|
notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))
|
|
|
|
case types.StatePeerChangedPatch:
|
|
for _, newPatch := range update.ChangePatches {
|
|
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
|
overwritePatch(&curr, newPatch)
|
|
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
|
} else {
|
|
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
|
}
|
|
}
|
|
b.patchesChanged = true
|
|
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
|
|
|
|
default:
|
|
b.n.sendAll(update)
|
|
}
|
|
}
|
|
|
|
// flush sends all the accumulated patches to all
|
|
// nodes in the notifier.
|
|
func (b *batcher) flush() {
|
|
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()
|
|
|
|
if b.nodesChanged || b.patchesChanged {
|
|
var patches []*tailcfg.PeerChange
|
|
// If a node is getting a full update from a change
|
|
// node update, then the patch can be dropped.
|
|
for nodeID, patch := range b.patches {
|
|
if b.changedNodeIDs.Contains(nodeID) {
|
|
delete(b.patches, nodeID)
|
|
} else {
|
|
patches = append(patches, &patch)
|
|
}
|
|
}
|
|
|
|
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
|
sort.Slice(changedNodes, func(i, j int) bool {
|
|
return changedNodes[i] < changedNodes[j]
|
|
})
|
|
|
|
if b.changedNodeIDs.Slice().Len() > 0 {
|
|
update := types.StateUpdate{
|
|
Type: types.StatePeerChanged,
|
|
ChangeNodes: changedNodes,
|
|
}
|
|
|
|
b.n.sendAll(update)
|
|
}
|
|
|
|
if len(patches) > 0 {
|
|
patchUpdate := types.StateUpdate{
|
|
Type: types.StatePeerChangedPatch,
|
|
ChangePatches: patches,
|
|
}
|
|
|
|
b.n.sendAll(patchUpdate)
|
|
}
|
|
|
|
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
|
notifierBatcherChanges.WithLabelValues().Set(0)
|
|
b.nodesChanged = false
|
|
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
|
notifierBatcherPatches.WithLabelValues().Set(0)
|
|
b.patchesChanged = false
|
|
}
|
|
}
|
|
|
|
func (b *batcher) doWork() {
|
|
for {
|
|
select {
|
|
case <-b.cancelCh:
|
|
return
|
|
case <-b.tick.C:
|
|
b.flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
// overwritePatch takes the current patch and a newer patch
|
|
// and override any field that has changed.
|
|
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
|
if newPatch.DERPRegion != 0 {
|
|
currPatch.DERPRegion = newPatch.DERPRegion
|
|
}
|
|
|
|
if newPatch.Cap != 0 {
|
|
currPatch.Cap = newPatch.Cap
|
|
}
|
|
|
|
if newPatch.CapMap != nil {
|
|
currPatch.CapMap = newPatch.CapMap
|
|
}
|
|
|
|
if newPatch.Endpoints != nil {
|
|
currPatch.Endpoints = newPatch.Endpoints
|
|
}
|
|
|
|
if newPatch.Key != nil {
|
|
currPatch.Key = newPatch.Key
|
|
}
|
|
|
|
if newPatch.KeySignature != nil {
|
|
currPatch.KeySignature = newPatch.KeySignature
|
|
}
|
|
|
|
if newPatch.DiscoKey != nil {
|
|
currPatch.DiscoKey = newPatch.DiscoKey
|
|
}
|
|
|
|
if newPatch.Online != nil {
|
|
currPatch.Online = newPatch.Online
|
|
}
|
|
|
|
if newPatch.LastSeen != nil {
|
|
currPatch.LastSeen = newPatch.LastSeen
|
|
}
|
|
|
|
if newPatch.KeyExpiry != nil {
|
|
currPatch.KeyExpiry = newPatch.KeyExpiry
|
|
}
|
|
}
|