1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-08 00:11:42 +01:00
juanfont.headscale/hscontrol/notifier/notifier.go
Kristoffer Dalby 83769ba715
Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the
database and replaces them with Transactions, turns out that SQL had
a way to deal with this all along.

This reduces the complexity we had with multiple locks that might stack
or recurse (database, nofitifer, mapper). All notifications and state
updates are now triggered _after_ a database change.


Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-02-08 17:28:19 +01:00

174 lines
3.9 KiB
Go

package notifier
import (
"context"
"fmt"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"tailscale.com/types/key"
)
type Notifier struct {
l sync.RWMutex
nodes map[string]chan<- types.StateUpdate
connected map[key.MachinePublic]bool
}
func NewNotifier() *Notifier {
return &Notifier{
nodes: make(map[string]chan<- types.StateUpdate),
connected: make(map[key.MachinePublic]bool),
}
}
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to add node")
n.l.Lock()
defer n.l.Unlock()
n.nodes[machineKey.String()] = c
n.connected[machineKey] = true
log.Trace().
Str("machine_key", machineKey.ShortString()).
Int("open_chans", len(n.nodes)).
Msg("Added new channel")
}
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
defer log.Trace().
Caller().
Str("key", machineKey.ShortString()).
Msg("releasing lock to remove node")
n.l.Lock()
defer n.l.Unlock()
if len(n.nodes) == 0 {
return
}
delete(n.nodes, machineKey.String())
n.connected[machineKey] = false
log.Trace().
Str("machine_key", machineKey.ShortString()).
Int("open_chans", len(n.nodes)).
Msg("Removed channel")
}
// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
n.l.RLock()
defer n.l.RUnlock()
return n.connected[machineKey]
}
// TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
return n.connected
}
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
}
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignore ...string,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
for key, c := range n.nodes {
if util.IsStringInSlice(ignore, key) {
continue
}
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", key).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) NotifyByMachineKey(
ctx context.Context,
update types.StateUpdate,
mKey key.MachinePublic,
) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Interface("type", update.Type).
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if c, ok := n.nodes[mKey.String()]; ok {
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Str("mkey", mKey.String()).
Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) String() string {
n.l.RLock()
defer n.l.RUnlock()
str := []string{"Notifier, in map:\n"}
for k, v := range n.nodes {
str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
}
return strings.Join(str, "")
}