
changes instead of stateupdate

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby <kristoffer@tailscale.com>, 2025-05-26 23:06:22 +02:00
parent 21d49ea5c0, commit 4bb42ef6a4
11 changed files with 426 additions and 373 deletions

View File

@@ -30,7 +30,6 @@ import (
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/dns"
"github.com/juanfont/headscale/hscontrol/mapper"
- "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
@@ -95,8 +94,7 @@ type Headscale struct {
extraRecordMan *dns.ExtraRecordsMan
primaryRoutes *routes.PrimaryRoutes
mapBatcher *mapper.Batcher
- nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
@@ -169,12 +167,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
+ &app,
cfg.ServerURL,
&cfg.OIDC,
- app.db,
- app.nodeNotifier,
- app.ipAlloc,
- app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@@ -287,7 +282,15 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
- h.nodeNotifier.NotifyAll(update)
+ // TODO(kradalby): Not sure how I feel about this one, feel like we
+ // can be more clever? But at the same time, if they are passed straight
+ // through later, it's fine?
+ for _, node := range update.ChangePatches {
+ h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: types.NodeID(node.NodeID),
+ ExpiryChanged: true,
+ }})
+ }
}
case <-derpTickerChan:
@@ -298,10 +301,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
h.DERPMap.Regions[region.RegionID] = &region
}
- h.nodeNotifier.NotifyAll(types.StateUpdate{
- Type: types.StateDERPUpdated,
- DERPMap: h.DERPMap,
- })
+ h.Change(types.Change{DERPChanged: true})
case records, ok := <-extraRecordsUpdate:
if !ok {
@@ -309,18 +309,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
- // TODO(kradalby): We can probably do better than sending a full update here,
- // but for now this will ensure that all of the nodes get the new records.
- h.nodeNotifier.NotifyAll(types.UpdateFull())
+ h.Change(types.Change{ExtraRecordsChanged: true})
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
- req interface{},
+ req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
- ) (interface{}, error) {
+ ) (any, error) {
// Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client
@@ -498,55 +496,55 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
- // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
- // Maybe we should attempt a new in memory state and not go via the DB?
- // Maybe this should be implemented as an event bus?
- // A bool is returned indicating if a full update was sent to all nodes
- func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
- users, err := db.ListUsers()
- if err != nil {
- return err
- }
- changed, err := polMan.SetUsers(users)
- if err != nil {
- return err
- }
- if changed {
- notif.NotifyAll(types.UpdateFull())
- }
- return nil
- }
- // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
- // Maybe we should attempt a new in memory state and not go via the DB?
- // Maybe this should be implemented as an event bus?
- // A bool is returned indicating if a full update was sent to all nodes
- func nodesChangedHook(
- db *db.HSDatabase,
- polMan policy.PolicyManager,
- notif *notifier.Notifier,
- ) (bool, error) {
- nodes, err := db.ListNodes()
- if err != nil {
- return false, err
- }
- filterChanged, err := polMan.SetNodes(nodes)
- if err != nil {
- return false, err
- }
- if filterChanged {
- notif.NotifyAll(types.UpdateFull())
- return true, nil
- }
- return false, nil
- }
+ // // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+ // // Maybe we should attempt a new in memory state and not go via the DB?
+ // // Maybe this should be implemented as an event bus?
+ // // A bool is returned indicating if a full update was sent to all nodes
+ // func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
+ // users, err := db.ListUsers()
+ // if err != nil {
+ // return err
+ // }
+ // changed, err := polMan.SetUsers(users)
+ // if err != nil {
+ // return err
+ // }
+ // if changed {
+ // notif.NotifyAll(types.UpdateFull())
+ // }
+ // return nil
+ // }
+ // // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+ // // Maybe we should attempt a new in memory state and not go via the DB?
+ // // Maybe this should be implemented as an event bus?
+ // // A bool is returned indicating if a full update was sent to all nodes
+ // func nodesChangedHook(
+ // db *db.HSDatabase,
+ // polMan policy.PolicyManager,
+ // notif *notifier.Notifier,
+ // ) (bool, error) {
+ // nodes, err := db.ListNodes()
+ // if err != nil {
+ // return false, err
+ // }
+ // filterChanged, err := polMan.SetNodes(nodes)
+ // if err != nil {
+ // return false, err
+ // }
+ // if filterChanged {
+ // notif.NotifyAll(types.UpdateFull())
+ // return true, nil
+ // }
+ // return false, nil
+ // }
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
@@ -577,7 +575,6 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapBatcher = mapper.NewBatcherAndMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
- h.nodeNotifier = notifier.NewNotifier(h.cfg, h.mapBatcher)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
@@ -869,7 +866,7 @@ func (h *Headscale) Serve() error {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
- h.nodeNotifier.NotifyAll(types.UpdateFull())
+ h.Change(types.Change{PolicyChanged: true})
}
default:
info := func(msg string) { log.Info().Msg(msg) }
@@ -895,7 +892,6 @@ func (h *Headscale) Serve() error {
}
info("closing node notifier")
- h.nodeNotifier.Close()
info("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
@@ -1198,3 +1194,7 @@ func (h *Headscale) autoApproveNodes() error {
return nil
}
+ func (h *Headscale) Change(c types.Change) {
+ h.mapBatcher.AddWork(&c)
+ }
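For orientation (not part of the diff): the intent of the new entry point is that every handler which mutates state describes what happened as a types.Change and hands it to the batcher, instead of talking to the notifier directly. A minimal sketch, with the persistence step deliberately elided:

// Illustrative caller of Headscale.Change; the expiry write itself is
// elided because the exact db helper varies per handler.
func expireNodeAndNotify(h *Headscale, id types.NodeID) {
    // ... persist the new expiry via h.db here ...
    h.Change(types.Change{NodeChange: types.NodeChange{
        ID:            id,
        ExpiryChanged: true,
    }})
}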

View File

@@ -90,7 +90,13 @@ func (h *Headscale) handleExistingNode(
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
- h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
+ h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ RemovedNode: true,
+ // TODO(kradalby): Remove when specifics are implemented
+ FullChange: true,
+ }})
}
expired = true
@@ -101,7 +107,13 @@ func (h *Headscale) handleExistingNode(
return nil, fmt.Errorf("setting node expiry: %w", err)
}
- h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, requestExpiry), node.ID)
+ h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ ExpiryChanged: true,
+ // TODO(kradalby): Remove when specifics are implemented
+ FullChange: true,
+ }})
}
return nodeToRegisterResponse(node), nil
@@ -238,11 +250,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err
}
- updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
- if err != nil {
- return nil, fmt.Errorf("nodes changed hook: %w", err)
- }
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@@ -254,14 +261,17 @@ func (h *Headscale) handleRegisterWithAuthKey(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
- routesChanged := policy.AutoApproveRoutes(h.polMan, node)
+ // TODO(kradalby): This needs to be run as part of the batcher maybe?
+ // now since we don't update the node/pol here anymore
+ _ = policy.AutoApproveRoutes(h.polMan, node)
if err := h.db.DB.Save(node).Error; err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
- if !updateSent || routesChanged {
- h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
- }
+ h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ NewNode: true,
+ }})
return &tailcfg.RegisterResponse{
MachineAuthorized: true,

View File

@@ -352,8 +352,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
- ) (*types.Node, bool, error) {
- var newNode bool
+ ) (*types.Node, types.Change, error) {
+ var change types.Change
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
@@ -405,7 +405,10 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
}
close(reg.Registered)
- newNode = true
+ change.NodeChange = types.NodeChange{
+ ID: node.ID,
+ NewNode: true,
+ }
return node, err
} else {
// If the node is already registered, this is a refresh.
@@ -413,6 +416,11 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
if err != nil {
return nil, err
}
+ change.NodeChange = types.NodeChange{
+ ID: node.ID,
+ ExpiryChanged: true,
+ }
return node, nil
}
}
@@ -420,7 +428,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, ErrNodeNotFoundRegistrationCache
})
- return node, change, err
+ return node, change, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
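Aside (not in the diff): with the bool replaced by a types.Change, callers no longer decide between "new node" and "refresh" themselves; the decision is made where the registration cache is inspected, and callers just forward it. Roughly the pattern the gRPC and OIDC call sites below follow (argument list elided):

// Sketch of the intended call pattern for the new return value (arguments elided).
node, change, err := hsdb.HandleNodeFromAuthPath( /* registration id, user, expiry, method, IPs */ )
if err != nil {
    return err
}
_ = node
h.Change(change) // either NewNode or ExpiryChanged, decided inside HandleNodeFromAuthPath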

View File

@@ -15,10 +15,6 @@ import (
func (h *Headscale) debugHTTPServer() *http.Server {
debugMux := http.NewServeMux()
debug := tsweb.Debugger(debugMux)
- debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- w.Write([]byte(h.nodeNotifier.String()))
- }))
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config, err := json.MarshalIndent(h.cfg, "", " ")
if err != nil {

View File

@@ -58,10 +58,10 @@ func (api headscaleV1APIServer) CreateUser(
return nil, err
}
- err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
- if err != nil {
- return nil, fmt.Errorf("updating resources using user: %w", err)
- }
+ api.h.Change(types.Change{UserChange: types.UserChange{
+ ID: types.UserID(user.ID),
+ NewUser: true,
+ }})
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@@ -102,10 +102,10 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
- err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
- if err != nil {
- return nil, fmt.Errorf("updating resources using user: %w", err)
- }
+ api.h.Change(types.Change{UserChange: types.UserChange{
+ ID: types.UserID(user.ID),
+ RemovedUser: true,
+ }})
return &v1.DeleteUserResponse{}, nil
}
@@ -253,7 +253,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err)
}
- node, _, err := api.h.db.HandleNodeFromAuthPath(
+ node, change, err := api.h.db.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,
@@ -264,11 +264,6 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
- updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
- if err != nil {
- return nil, fmt.Errorf("updating resources using node: %w", err)
- }
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@@ -280,14 +275,14 @@ func (api headscaleV1APIServer) RegisterNode(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
- routesChanged := policy.AutoApproveRoutes(api.h.polMan, node)
+ // TODO(kradalby): This needs to be run as part of the batcher maybe?
+ // now since we don't update the node/pol here anymore
+ _ = policy.AutoApproveRoutes(api.h.polMan, node)
if err := api.h.db.DB.Save(node).Error; err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
- if !updateSent || routesChanged {
- api.h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
- }
+ api.h.Change(change)
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@@ -305,7 +300,7 @@ func (api headscaleV1APIServer) GetNode(
// Populate the online field based on
// currently connected nodes.
- resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
+ resp.Online = api.h.mapBatcher.IsLikelyConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil
}
@@ -335,7 +330,10 @@ func (api headscaleV1APIServer) SetTags(
}, status.Error(codes.InvalidArgument, err.Error())
}
- api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ TagsChanged: true,
+ }})
log.Trace().
Str("node", node.Hostname).
@@ -379,11 +377,11 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
return nil, status.Error(codes.InvalidArgument, err.Error())
}
- if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
- api.h.nodeNotifier.NotifyAll(types.UpdateFull())
- } else {
- api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
- }
+ api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ RoutesChanged: true,
+ }})
proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.primaryRoutes.PrimaryRoutes(node.ID))
@@ -418,7 +416,10 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err
}
- api.h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ RemovedNode: true,
+ }})
return &v1.DeleteNodeResponse{}, nil
}
@@ -442,11 +443,11 @@ func (api headscaleV1APIServer) ExpireNode(
return nil, err
}
- api.h.nodeNotifier.NotifyByNodeID(
- types.UpdateSelf(node.ID),
- node.ID)
- api.h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, now), node.ID)
+ // TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ ExpiryChanged: true,
+ }})
log.Trace().
Str("node", node.Hostname).
@@ -476,7 +477,12 @@ func (api headscaleV1APIServer) RenameNode(
return nil, err
}
- api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ // TODO(kradalby): Not sure if this is what we want to send, probably
+ // we can do a delta change here.
+ NewNode: true,
+ }})
log.Trace().
Str("node", node.Hostname).
@@ -495,7 +501,7 @@ func (api headscaleV1APIServer) ListNodes(
// probably be done once.
// TODO(kradalby): This should be done in one tx.
- isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
+ isLikelyConnected := api.h.mapBatcher.LikelyConnectedMap()
if request.GetUser() != "" {
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
@@ -572,10 +578,13 @@ func (api headscaleV1APIServer) MoveNode(
return nil, err
}
- api.h.nodeNotifier.NotifyByNodeID(
- types.UpdateSelf(node.ID),
- node.ID)
- api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+ // TODO(kradalby): ensure that both the selfupdate and peer updates are sent
+ api.h.Change(types.Change{NodeChange: types.NodeChange{
+ ID: node.ID,
+ // TODO(kradalby): Not sure if this is what we want to send, probably
+ // we can do a delta change here.
+ NewNode: true,
+ }})
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
}
@@ -758,7 +767,7 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
- api.h.nodeNotifier.NotifyAll(types.UpdateFull())
+ api.h.Change(types.Change{PolicyChanged: true})
}
response := &v1.SetPolicyResponse{

View File

@@ -28,17 +28,6 @@ func init() {
}
}
- type ChangeWork struct {
- NodeID *types.NodeID
- Update types.StateUpdate
- }
- type nodeConn struct {
- c chan<- []byte
- compress string
- version tailcfg.CapabilityVersion
- }
type Batcher struct {
mu deadlock.RWMutex
@@ -58,7 +47,7 @@ type Batcher struct {
// this should serve for the experiment.
cancelCh chan struct{}
- workCh chan *ChangeWork
+ workCh chan *types.Change
}
func NewBatcherAndMapper(
@@ -80,7 +69,7 @@ func NewBatcher(mapper *mapper) *Batcher {
mapper: mapper,
cancelCh: make(chan struct{}),
// TODO: No limit for now, this needs to be changed
- workCh: make(chan *ChangeWork, (1<<16)-1),
+ workCh: make(chan *types.Change, (1<<16)-1),
nodes: make(map[types.NodeID]nodeConn),
connected: make(map[types.NodeID]*time.Time),
@@ -110,37 +99,46 @@ func (b *Batcher) AddNode(id types.NodeID, c chan<- []byte, compress string, ver
}
b.nodes[id] = nodeConn{
+ id: id,
c: c,
compress: compress,
version: version,
+ // TODO(kradalby): Not sure about this one yet.
+ mapper: b.mapper,
}
b.connected[id] = nil // nil means connected
- b.AddWork(&ChangeWork{
- NodeID: &id,
- Update: types.UpdateFull(),
- })
+ // TODO(kradalby): Handle:
+ // - Updating peers with online status
+ // - Updating routes in routemanager and peers
+ b.AddWork(&types.Change{NodeChange: types.NodeChange{
+ ID: id,
+ Online: true,
+ }})
}
- func (b *Batcher) RemoveNode(id types.NodeID, c chan<- []byte) bool {
+ func (b *Batcher) RemoveNode(id types.NodeID, c chan<- []byte) {
b.mu.Lock()
defer b.mu.Unlock()
if curr, ok := b.nodes[id]; ok {
if curr.c != c {
- return false
+ return
}
}
delete(b.nodes, id)
b.connected[id] = ptr.To(time.Now())
- return true
+ // TODO(kradalby): Handle:
+ // - Updating peers with lastseen status, and only if not replaced
+ // - Updating routes in routemanager and peers
}
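Purely as a sketch of where the RemoveNode TODO seems to be heading (not implemented in this commit): once the channel is confirmed gone and not replaced by a newer connection, the batcher could enqueue the offline notification itself, taking over what the removed updateNodeOnlineStatus helper further down in this diff used to do.

// Hypothetical follow-up inside RemoveNode, after delete(b.nodes, id):
// tell peers the node went offline using the new Offline flag.
b.AddWork(&types.Change{NodeChange: types.NodeChange{
    ID:      id,
    Offline: true,
}})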
- func (b *Batcher) AddWork(work *ChangeWork) {
- log.Trace().Msgf("adding work: %v", work.Update)
- b.workCh <- work
+ func (b *Batcher) AddWork(change *types.Change) {
+ log.Trace().Msgf("adding work: %v", change)
+ b.workCh <- change
}
func (b *Batcher) IsConnected(id types.NodeID) bool {
@@ -194,12 +192,12 @@ func (b *Batcher) doWork() {
case <-b.cancelCh:
return
case work := <-b.workCh:
- b.processWork(work)
+ b.processChange(work)
}
}
}
- // processWork is the current bottleneck where all the updates get picked up
+ // processChange is the current bottleneck where all the updates get picked up
// one by one and processed. This will have to change, it needs to go as fast as
// possible and just pass it on to the nodes. Currently it wont block because the
// work channel is super large, but it might not be able to keep up.
@@ -207,90 +205,125 @@ func (b *Batcher) doWork() {
// mean a lot of goroutines, hanging around.
// Another is just a worker pool that picks up work and processes it,
// and passes it on to the nodes. That might be complicated with order?
- func (b *Batcher) processWork(work *ChangeWork) {
+ func (b *Batcher) processChange(c *types.Change) {
b.mu.RLock()
defer b.mu.RUnlock()
- log.Trace().Msgf("processing work: %v", work)
+ log.Trace().Msgf("processing work: %v", c)
- if work.NodeID != nil {
- id := *work.NodeID
- node, ok := b.nodes[id]
- if !ok {
- log.Trace().Msgf("node %d not found in batcher, skipping work: %v", id, work.Update)
- return
- }
- resp, err := b.resp(id, &node, work)
- if err != nil {
- log.Debug().Msgf("creating mapResp for %d: %s", id, err)
- }
- node.c <- resp
- return
- }
for id, node := range b.nodes {
- resp, err := b.resp(id, &node, work)
- if err != nil {
- log.Debug().Msgf("creating mapResp for %d: %s", id, err)
- }
- node.c <- resp
+ err := node.change(c)
+ log.Error().Err(err).Uint64("node.id", id.Uint64()).Msgf("processing work for node %d", id)
}
}
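The comment above floats a worker pool as one way to keep processChange from becoming a bottleneck. A minimal sketch of that idea (not in this commit; pool size and error handling are placeholders), with the caveat the author already notes — per-node ordering is not preserved across workers:

// Hypothetical worker pool draining workCh; ordering across workers is NOT
// preserved, which is exactly the complication mentioned above.
func (b *Batcher) startWorkers(n int) {
    for i := 0; i < n; i++ {
        go func() {
            for {
                select {
                case <-b.cancelCh:
                    return
                case c := <-b.workCh:
                    b.processChange(c)
                }
            }
        }()
    }
}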
type nodeConn struct {
id types.NodeID
c chan<- []byte
compress string
version tailcfg.CapabilityVersion
mapper *mapper
}
type changeUpdate int
const (
_ changeUpdate = iota
ignoreUpdate
partialUpdate
fullUpdate
)
func determineChange(c *types.Change) changeUpdate {
if c == nil {
return ignoreUpdate
}
if c.FullUpdate() {
return fullUpdate
}
return fullUpdate
}
func (nc *nodeConn) change(c *types.Change) error {
switch determineChange(c) {
case partialUpdate:
return nc.partialUpdate(c)
case fullUpdate:
return nc.fullUpdate()
default:
log.Trace().Msgf("ignoring change: %v", c)
return nil
}
}
func (nc *nodeConn) partialUpdate(c *types.Change) error {
return nil
}
func (nc *nodeConn) fullUpdate() error {
data, err := nc.mapper.fullMapResponse(nc.id, nc.version, nc.compress)
if err != nil {
return err
}
nc.c <- data
return nil
}
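determineChange currently always falls back to fullUpdate, and partialUpdate is a stub. One way the ignoreUpdate/partialUpdate branches might eventually be filled in, judging from the flags on types.NodeChange — a speculative sketch, not the commit's code:

// Speculative refinement: route cheap, self-contained changes to a partial
// (patch-style) response and everything else to a full map response.
func determineChangeSketch(c *types.Change) changeUpdate {
    switch {
    case c == nil:
        return ignoreUpdate
    case c.FullUpdate():
        return fullUpdate
    case c.NodeChange.ExpiryChanged || c.NodeChange.Online || c.NodeChange.Offline:
        // Could be expressed as a peer-change patch to the other nodes.
        return partialUpdate
    default:
        return fullUpdate
    }
}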
// resp is the logic that used to reside in the poller, but is now moved
// to process before sending to the node. The idea is that we do not want to
// be blocked on the send channel to the individual node, but rather
// process all the work and then send the responses to the nodes.
// TODO(kradalby): This is a temporary solution, as we explore this
// approach, we will likely need to refactor this further.
- func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte, error) {
- var data []byte
- var err error
- // TODO(kradalby): This should not be necessary, mapper only
- // use compress and version, and this can either be moved out
- // or passed directly. The mapreq isnt needed.
- req := tailcfg.MapRequest{
- Compress: nc.compress,
- Version: nc.version,
- }
- // TODO(kradalby): We dont want to use the db here. We should
- // just have the node available, or at least quickly accessible
- // from the new fancy mem state we want.
- node, err := b.mapper.db.GetNodeByID(id)
- if err != nil {
- return nil, err
- }
- switch work.Update.Type {
- case types.StateFullUpdate:
- data, err = b.mapper.fullMapResponse(req, node)
- case types.StatePeerChanged:
- changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
- for _, nodeID := range work.Update.ChangeNodes {
- changed[nodeID] = true
- }
- data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
- case types.StatePeerChangedPatch:
- data, err = b.mapper.peerChangedPatchResponse(req, node, work.Update.ChangePatches)
- case types.StatePeerRemoved:
- changed := make(map[types.NodeID]bool, len(work.Update.Removed))
- for _, nodeID := range work.Update.Removed {
- changed[nodeID] = false
- }
- data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
- case types.StateSelfUpdate:
- data, err = b.mapper.peerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
- case types.StateDERPUpdated:
- data, err = b.mapper.derpMapResponse(req, node, work.Update.DERPMap)
- }
- return data, err
- }
+ // func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte, error) {
+ // var data []byte
+ // var err error
+ // // TODO(kradalby): This should not be necessary, mapper only
+ // // use compress and version, and this can either be moved out
+ // // or passed directly. The mapreq isnt needed.
+ // req := tailcfg.MapRequest{
+ // Compress: nc.compress,
+ // Version: nc.version,
+ // }
+ // // TODO(kradalby): We dont want to use the db here. We should
+ // // just have the node available, or at least quickly accessible
+ // // from the new fancy mem state we want.
+ // node, err := b.mapper.db.GetNodeByID(id)
+ // if err != nil {
+ // return nil, err
+ // }
+ // switch work.Update.Type {
+ // case types.StateFullUpdate:
+ // data, err = b.mapper.fullMapResponse(req, node)
+ // case types.StatePeerChanged:
+ // changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
+ // for _, nodeID := range work.Update.ChangeNodes {
+ // changed[nodeID] = true
+ // }
+ // data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
+ // case types.StatePeerChangedPatch:
+ // data, err = b.mapper.peerChangedPatchResponse(req, node, work.Update.ChangePatches)
+ // case types.StatePeerRemoved:
+ // changed := make(map[types.NodeID]bool, len(work.Update.Removed))
+ // for _, nodeID := range work.Update.Removed {
+ // changed[nodeID] = false
+ // }
+ // data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
+ // case types.StateSelfUpdate:
+ // data, err = b.mapper.peerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
+ // case types.StateDERPUpdated:
+ // data, err = b.mapper.derpMapResponse(req, node, work.Update.DERPMap)
+ // }
+ // return data, err
+ // }

View File

@@ -151,16 +151,22 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
- mapRequest tailcfg.MapRequest,
- node *types.Node,
+ nodeID types.NodeID,
+ capVer tailcfg.CapabilityVersion,
+ compress string,
messages ...string,
) ([]byte, error) {
- peers, err := m.listPeers(node.ID)
+ node, err := m.db.GetNodeByID(nodeID)
if err != nil {
return nil, err
}
- resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
+ peers, err := m.listPeers(nodeID)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := m.baseWithConfigMapResponse(node, capVer)
if err != nil {
return nil, err
}
@@ -171,7 +177,7 @@ func (m *mapper) fullMapResponse(
m.polMan,
m.primary,
node,
- mapRequest.Version,
+ capVer,
peers,
m.cfg,
)
@@ -179,7 +185,7 @@ func (m *mapper) fullMapResponse(
return nil, err
}
- return marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
+ return marshalMapResponse(resp, node, compress, messages...)
}
func (m *mapper) derpMapResponse(
@@ -192,7 +198,7 @@ func (m *mapper) derpMapResponse(
resp := m.baseMapResponse()
resp.DERPMap = derpMap
- return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
+ return marshalMapResponse(&resp, node, mapRequest.Compress)
}
func (m *mapper) peerChangedResponse(
@@ -269,7 +275,7 @@ func (m *mapper) peerChangedResponse(
}
resp.Node = tailnode
- return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
+ return marshalMapResponse(&resp, node, mapRequest.Compress, messages...)
}
// peerChangedPatchResponse creates a patch MapResponse with
@@ -282,11 +288,10 @@ func (m *mapper) peerChangedPatchResponse(
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
- return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
+ return marshalMapResponse(&resp, node, mapRequest.Compress)
}
func marshalMapResponse(
- mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
node *types.Node,
compression string,
@@ -300,7 +305,6 @@ func marshalMapResponse(
if debugDumpMapResponsePath != "" {
data := map[string]any{
"Messages": messages,
- "MapRequest": mapRequest,
"MapResponse": resp,
}

View File

@@ -97,36 +97,36 @@ func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.mbatcher.LikelyConnectedMap()
}
- func (n *Notifier) NotifyAll(update types.StateUpdate) {
- n.NotifyWithIgnore(update)
- }
- func (n *Notifier) NotifyWithIgnore(
- update types.StateUpdate,
- ignoreNodeIDs ...types.NodeID,
- ) {
- if n.closed {
- return
- }
- n.b.addOrPassthrough(update)
- }
- func (n *Notifier) NotifyByNodeID(
- update types.StateUpdate,
- nodeID types.NodeID,
- ) {
- n.mbatcher.AddWork(&mapper.ChangeWork{
- NodeID: &nodeID,
- Update: update,
- })
- }
- func (n *Notifier) sendAll(update types.StateUpdate) {
- n.mbatcher.AddWork(&mapper.ChangeWork{
- Update: update,
- })
- }
+ // func (n *Notifier) NotifyAll(update types.StateUpdate) {
+ // n.NotifyWithIgnore(update)
+ // }
+ // func (n *Notifier) NotifyWithIgnore(
+ // update types.StateUpdate,
+ // ignoreNodeIDs ...types.NodeID,
+ // ) {
+ // if n.closed {
+ // return
+ // }
+ // n.b.addOrPassthrough(update)
+ // }
+ // func (n *Notifier) NotifyByNodeID(
+ // update types.StateUpdate,
+ // nodeID types.NodeID,
+ // ) {
+ // n.mbatcher.AddWork(&mapper.ChangeWork{
+ // NodeID: &nodeID,
+ // Update: update,
+ // })
+ // }
+ // func (n *Notifier) sendAll(update types.StateUpdate) {
+ // n.mbatcher.AddWork(&mapper.ChangeWork{
+ // Update: update,
+ // })
+ // }
func (n *Notifier) String() string {
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
@@ -197,7 +197,7 @@ func (b *batcher) addOrPassthrough(update types.StateUpdate) {
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
default:
- b.n.sendAll(update)
+ // b.n.sendAll(update)
}
}
@@ -225,15 +225,15 @@ func (b *batcher) flush() {
slices.Sort(changedNodes)
if b.changedNodeIDs.Slice().Len() > 0 {
- update := types.UpdatePeerChanged(changedNodes...)
- b.n.sendAll(update)
+ // update := types.UpdatePeerChanged(changedNodes...)
+ // b.n.sendAll(update)
}
if len(patches) > 0 {
- patchUpdate := types.UpdatePeerPatch(patches...)
- b.n.sendAll(patchUpdate)
+ // patchUpdate := types.UpdatePeerPatch(patches...)
+ // b.n.sendAll(patchUpdate)
}
b.changedNodeIDs = set.Slice[types.NodeID]{}

View File

@@ -16,7 +16,6 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
- "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -54,13 +53,10 @@ type RegistrationInfo struct {
}
type AuthProviderOIDC struct {
+ h *Headscale
serverURL string
cfg *types.OIDCConfig
- db *db.HSDatabase
registrationCache *zcache.Cache[string, RegistrationInfo]
- notifier *notifier.Notifier
- ipAlloc *db.IPAllocator
- polMan policy.PolicyManager
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@@ -68,12 +64,9 @@ type AuthProviderOIDC struct {
func NewAuthProviderOIDC(
ctx context.Context,
+ h *Headscale,
serverURL string,
cfg *types.OIDCConfig,
- db *db.HSDatabase,
- notif *notifier.Notifier,
- ipAlloc *db.IPAllocator,
- polMan policy.PolicyManager,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@@ -99,13 +92,10 @@ func NewAuthProviderOIDC(
)
return &AuthProviderOIDC{
+ h: h,
serverURL: serverURL,
cfg: cfg,
- db: db,
registrationCache: registrationCache,
- notifier: notif,
- ipAlloc: ipAlloc,
- polMan: polMan,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@@ -475,26 +465,29 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
) (*types.User, error) {
var user *types.User
var err error
- user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
+ user, err = a.h.db.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
+ // TODO(kradalby): This might cause us to not have an ID below which
+ // is a problem.
if user == nil {
user = &types.User{}
}
user.FromClaim(claims)
- err = a.db.DB.Save(user).Error
+ err = a.h.db.DB.Save(user).Error
if err != nil {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
- err = usersChangedHook(a.db, a.polMan, a.notifier)
- if err != nil {
- return nil, fmt.Errorf("updating resources using user: %w", err)
- }
+ a.h.Change(types.Change{UserChange: types.UserChange{
+ ID: types.UserID(user.ID),
+ // TODO(kradalby): Not sure about this one yet.
+ NewUser: true,
+ }})
return user, nil
}
@@ -504,12 +497,12 @@ func (a *AuthProviderOIDC) handleRegistration(
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
- ipv4, ipv6, err := a.ipAlloc.Next()
+ ipv4, ipv6, err := a.h.ipAlloc.Next()
if err != nil {
return false, err
}
- node, newNode, err := a.db.HandleNodeFromAuthPath(
+ node, change, err := a.h.db.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
@@ -520,14 +513,6 @@ func (a *AuthProviderOIDC) handleRegistration(
return false, fmt.Errorf("could not register node: %w", err)
}
- // Send an update to all nodes if this is a new node that they need to know
- // about.
- // If this is a refresh, just send new expiry updates.
- updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier)
- if err != nil {
- return false, fmt.Errorf("updating resources using node: %w", err)
- }
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@@ -539,21 +524,16 @@ func (a *AuthProviderOIDC) handleRegistration(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
- routesChanged := policy.AutoApproveRoutes(a.polMan, node)
- if err := a.db.DB.Save(node).Error; err != nil {
+ // TODO(kradalby): This needs to be run as part of the batcher maybe?
+ // now since we don't update the node/pol here anymore
+ _ = policy.AutoApproveRoutes(a.h.polMan, node)
+ if err := a.h.db.DB.Save(node).Error; err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
}
- if !updateSent || routesChanged {
- a.notifier.NotifyByNodeID(
- types.UpdateSelf(node.ID),
- node.ID,
- )
- a.notifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
- }
- return newNode, nil
+ a.h.Change(change)
+ return change.NodeChange.NewNode, nil
}
// TODO(kradalby):

View File

@@ -163,25 +163,7 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh)
m.cancelChMu.Unlock()
- // only update node status if the node channel was removed.
- // in principal, it will be removed, but the client rapidly
- // reconnects, the channel might be of another connection.
- // In that case, it is not closed and the node is still online.
- if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
- // TODO(kradalby): All of this handling should be moved out of here
- // to the mapBatcher(?), where there is more state (with the goal of removing it from here).
- // Failover the node's routes if any.
- m.h.updateNodeOnlineStatus(false, m.node)
- // When a node disconnects, and it causes the primary route map to change,
- // send a full update to all nodes.
- // TODO(kradalby): This can likely be made more effective, but likely most
- // nodes has access to the same routes, so it might not be a big deal.
- if m.h.primaryRoutes.SetRoutes(m.node.ID) {
- m.h.nodeNotifier.NotifyAll(types.UpdateFull())
- }
- }
+ m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
@@ -191,12 +173,6 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
- // TODO(kradalby): All of this handling should be moved out of here
- // to the mapBatcher(?), where there is more state (with the goal of removing it from here).
- if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
- m.h.nodeNotifier.NotifyAll(types.UpdateFull())
- }
// TODO(kradalby): I think this didnt really work and can be reverted back to a normal write thing.
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
@@ -212,10 +188,6 @@ func (m *mapSession) serveLongPoll() {
m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.req.Compress, m.req.Version)
- // TODO(kradalby): All of this handling should be moved out of here
- // to the mapBatcher(?), where there is more state (with the goal of removing it from here).
- go m.h.updateNodeOnlineStatus(true, m.node)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
// Loop through updates and continuously send them to the
@@ -298,38 +270,11 @@ var keepAliveZstd = (func() []byte {
return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression)
})()
- // updateNodeOnlineStatus records the last seen status of a node and notifies peers
- // about change in their online/offline status.
- // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
- func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
- change := &tailcfg.PeerChange{
- NodeID: tailcfg.NodeID(node.ID),
- Online: &online,
- }
- if !online {
- now := time.Now()
- // lastSeen is only relevant if the node is disconnected.
- node.LastSeen = &now
- change.LastSeen = &now
- }
- if node.LastSeen != nil {
- h.db.SetLastSeen(node.ID, *node.LastSeen)
- }
- h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerPatch(change), node.ID)
- }
func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update")
change := m.node.PeerChangeFromMapRequest(m.req)
- online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID)
- change.Online = &online
m.node.ApplyPeerChange(&change)
sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo)
@@ -355,6 +300,11 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
+ c := types.Change{NodeChange: types.NodeChange{
+ ID: m.node.ID,
+ HostinfoChanged: true,
+ }}
// Check if the Hostinfo of the node has changed.
// If it has changed, check if there has been a change to
// the routable IPs of the host and update them in
@@ -364,9 +314,13 @@ func (m *mapSession) handleEndpointUpdate() {
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
- // TODO(kradalby): I am not sure if we need this?
- nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
+ c.NodeChange.RoutesChanged = true
+ // TODO(kradalby): I am not sure if we will ultimately move this
+ // to a more central part as part of this effort.
+ // Do we want to make a thing where when you "save" a node, all the
+ // changes are calculated based on that and updating the right subsystems?
+ //
// Approve any route that has been defined in policy as
// auto approved. Any change here is not important as any
// actual state change will be detected when the route manager
@@ -375,19 +329,7 @@ func (m *mapSession) handleEndpointUpdate() {
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
- if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
- m.h.nodeNotifier.NotifyAll(types.UpdateFull())
- } else {
- m.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(m.node.ID), m.node.ID)
- // TODO(kradalby): I am not sure if we need this?
- // Send an update to the node itself with to ensure it
- // has an updated packetfilter allowing the new route
- // if it is defined in the ACL.
- m.h.nodeNotifier.NotifyByNodeID(
- types.UpdateSelf(m.node.ID),
- m.node.ID)
- }
+ m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...)
}
// Check if there has been a change to Hostname and update them
@@ -404,10 +346,7 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
- m.h.nodeNotifier.NotifyWithIgnore(
- types.UpdatePeerChanged(m.node.ID),
- m.node.ID,
- )
+ m.h.Change(c)
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()

View File

@@ -1,21 +1,95 @@
package types
type Change struct {
NodeChange NodeChange
UserChange UserChange
+ DERPChanged bool
+ PolicyChanged bool
+ // TODO(kradalby): We can probably do better than sending a full update here,
+ // but for now this will ensure that all of the nodes get the new records.
+ ExtraRecordsChanged bool
}
+ func (c *Change) FullUpdate() bool {
+ if !c.NodeChange.FullUpdate() && !c.UserChange.FullUpdate() {
+ return false
+ }
+ return true
+ }
- type NodeChangeWhat string
- const (
- NodeChangeCameOnline NodeChangeWhat = "node-online"
- )
+ // type NodeChangeWhat string
+ // const (
+ // NodeChangeOnline NodeChangeWhat = "node-online"
+ // NodeChangeOffline NodeChangeWhat = "node-offline"
+ // NodeChangeAdded NodeChangeWhat = "node-added"
+ // NodeChangeRemoved NodeChangeWhat = "node-removed"
+ // )
type NodeChange struct {
ID NodeID
- What NodeChangeWhat
+ // What NodeChangeWhat
+ // TODO(kradalby): FullChange is a bit of a
+ FullChange bool
+ ExpiryChanged bool
+ RoutesChanged bool
+ Online bool
+ Offline bool
+ // TODO: This could maybe be more granular
+ HostinfoChanged bool
+ // Registration and auth related changes
+ NewNode bool
+ RemovedNode bool
+ KeyChanged bool
+ TagsChanged bool
}
+ func (c *NodeChange) RegistrationChanged() bool {
+ return c.NewNode || c.KeyChanged || c.TagsChanged
+ }
+ func (c *NodeChange) OnlyKeyChange() bool {
+ return c.ID != 0 && !c.NewNode && !c.TagsChanged && c.KeyChanged
+ }
+ func (c *NodeChange) FullUpdate() bool {
+ if c.ID != 0 {
+ if c.RegistrationChanged() {
+ return true
+ }
+ if c.FullChange {
+ return true
+ }
+ return false
+ }
+ return false
+ }
type UserChange struct {
ID UserID
+ NewUser bool
+ RemovedUser bool
}
+ func (c *UserChange) FullUpdate() bool {
+ if c.ID != 0 {
+ return true
+ }
+ if c.NewUser || c.RemovedUser {
+ return true
+ }
+ return false
+ }
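To make the flag plumbing concrete, a few expectations derived from the FullUpdate methods above (illustrative, not part of the diff; note that the batcher's determineChange currently sends a full map response regardless):

c1 := Change{NodeChange: NodeChange{ID: 1, ExpiryChanged: true}}
_ = c1.FullUpdate() // false: an expiry tweak on an existing node is not a full update

c2 := Change{NodeChange: NodeChange{ID: 1, NewNode: true}}
_ = c2.FullUpdate() // true: registration-related changes (NewNode/KeyChanged/TagsChanged) force a full update

c3 := Change{UserChange: UserChange{ID: 2, RemovedUser: true}}
_ = c3.FullUpdate() // true: any user change carrying an ID is treated as a full update

c4 := Change{DERPChanged: true}
_ = c4.FullUpdate() // false: DERPChanged/PolicyChanged/ExtraRecordsChanged are not consulted here (yet)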