mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-20 19:09:07 +01:00
83769ba715
This commit removes the locks used to guard data integrity for the database and replaces them with transactions; it turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, notifier, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
174 lines
3.9 KiB
Go
174 lines
3.9 KiB
Go
package notifier
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"github.com/rs/zerolog/log"
|
|
"tailscale.com/types/key"
|
|
)
|
|
|
|
type Notifier struct {
|
|
l sync.RWMutex
|
|
nodes map[string]chan<- types.StateUpdate
|
|
connected map[key.MachinePublic]bool
|
|
}
|
|
|
|
func NewNotifier() *Notifier {
|
|
return &Notifier{
|
|
nodes: make(map[string]chan<- types.StateUpdate),
|
|
connected: make(map[key.MachinePublic]bool),
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
|
|
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
|
|
defer log.Trace().
|
|
Caller().
|
|
Str("key", machineKey.ShortString()).
|
|
Msg("releasing lock to add node")
|
|
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
|
|
n.nodes[machineKey.String()] = c
|
|
n.connected[machineKey] = true
|
|
|
|
log.Trace().
|
|
Str("machine_key", machineKey.ShortString()).
|
|
Int("open_chans", len(n.nodes)).
|
|
Msg("Added new channel")
|
|
}
|
|
|
|
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
|
|
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
|
|
defer log.Trace().
|
|
Caller().
|
|
Str("key", machineKey.ShortString()).
|
|
Msg("releasing lock to remove node")
|
|
|
|
n.l.Lock()
|
|
defer n.l.Unlock()
|
|
|
|
if len(n.nodes) == 0 {
|
|
return
|
|
}
|
|
|
|
delete(n.nodes, machineKey.String())
|
|
n.connected[machineKey] = false
|
|
|
|
log.Trace().
|
|
Str("machine_key", machineKey.ShortString()).
|
|
Int("open_chans", len(n.nodes)).
|
|
Msg("Removed channel")
|
|
}
|
|
|
|
// IsConnected reports if a node is connected to headscale and has a
|
|
// poll session open.
|
|
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
|
|
n.l.RLock()
|
|
defer n.l.RUnlock()
|
|
|
|
return n.connected[machineKey]
|
|
}
|
|
|
|
// TODO(kradalby): This returns a pointer and can be dangerous.
|
|
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
|
|
return n.connected
|
|
}
|
|
|
|
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
|
n.NotifyWithIgnore(ctx, update)
|
|
}
|
|
|
|
func (n *Notifier) NotifyWithIgnore(
|
|
ctx context.Context,
|
|
update types.StateUpdate,
|
|
ignore ...string,
|
|
) {
|
|
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
|
defer log.Trace().
|
|
Caller().
|
|
Interface("type", update.Type).
|
|
Msg("releasing lock, finished notifying")
|
|
|
|
n.l.RLock()
|
|
defer n.l.RUnlock()
|
|
|
|
for key, c := range n.nodes {
|
|
if util.IsStringInSlice(ignore, key) {
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Error().
|
|
Err(ctx.Err()).
|
|
Str("mkey", key).
|
|
Any("origin", ctx.Value("origin")).
|
|
Any("hostname", ctx.Value("hostname")).
|
|
Msgf("update not sent, context cancelled")
|
|
|
|
return
|
|
case c <- update:
|
|
log.Trace().
|
|
Str("mkey", key).
|
|
Any("origin", ctx.Value("origin")).
|
|
Any("hostname", ctx.Value("hostname")).
|
|
Msgf("update successfully sent on chan")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) NotifyByMachineKey(
|
|
ctx context.Context,
|
|
update types.StateUpdate,
|
|
mKey key.MachinePublic,
|
|
) {
|
|
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
|
|
defer log.Trace().
|
|
Caller().
|
|
Interface("type", update.Type).
|
|
Msg("releasing lock, finished notifying")
|
|
|
|
n.l.RLock()
|
|
defer n.l.RUnlock()
|
|
|
|
if c, ok := n.nodes[mKey.String()]; ok {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Error().
|
|
Err(ctx.Err()).
|
|
Str("mkey", mKey.String()).
|
|
Any("origin", ctx.Value("origin")).
|
|
Any("hostname", ctx.Value("hostname")).
|
|
Msgf("update not sent, context cancelled")
|
|
|
|
return
|
|
case c <- update:
|
|
log.Trace().
|
|
Str("mkey", mKey.String()).
|
|
Any("origin", ctx.Value("origin")).
|
|
Any("hostname", ctx.Value("hostname")).
|
|
Msgf("update successfully sent on chan")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (n *Notifier) String() string {
|
|
n.l.RLock()
|
|
defer n.l.RUnlock()
|
|
|
|
str := []string{"Notifier, in map:\n"}
|
|
|
|
for k, v := range n.nodes {
|
|
str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
|
|
}
|
|
|
|
return strings.Join(str, "")
|
|
}
|