diff --git a/hscontrol/app.go b/hscontrol/app.go
index c0d13096..80658e33 100644
--- a/hscontrol/app.go
+++ b/hscontrol/app.go
@@ -30,7 +30,6 @@ import (
 	derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
 	"github.com/juanfont/headscale/hscontrol/dns"
 	"github.com/juanfont/headscale/hscontrol/mapper"
-	"github.com/juanfont/headscale/hscontrol/notifier"
 	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/routes"
 	"github.com/juanfont/headscale/hscontrol/types"
@@ -95,8 +94,7 @@ type Headscale struct {
 	extraRecordMan *dns.ExtraRecordsMan
 	primaryRoutes  *routes.PrimaryRoutes
 
-	mapBatcher   *mapper.Batcher
-	nodeNotifier *notifier.Notifier
+	mapBatcher *mapper.Batcher
 
 	registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
 
@@ -169,12 +167,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		defer cancel()
 		oidcProvider, err := NewAuthProviderOIDC(
 			ctx,
+			&app,
 			cfg.ServerURL,
 			&cfg.OIDC,
-			app.db,
-			app.nodeNotifier,
-			app.ipAlloc,
-			app.polMan,
 		)
 		if err != nil {
 			if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@@ -287,7 +282,15 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			if changed {
 				log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
 
-				h.nodeNotifier.NotifyAll(update)
+				// TODO(kradalby): Not sure how I feel about this one, feel like we
+				// can be more clever? But at the same time, if they are passed straight
+				// through later, it's fine?
+				for _, node := range update.ChangePatches {
+					h.Change(types.Change{NodeChange: types.NodeChange{
+						ID:            types.NodeID(node.NodeID),
+						ExpiryChanged: true,
+					}})
+				}
 			}
 
 		case <-derpTickerChan:
@@ -298,10 +301,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 				h.DERPMap.Regions[region.RegionID] = &region
 			}
 
-			h.nodeNotifier.NotifyAll(types.StateUpdate{
-				Type:    types.StateDERPUpdated,
-				DERPMap: h.DERPMap,
-			})
+			h.Change(types.Change{DERPChanged: true})
 
 		case records, ok := <-extraRecordsUpdate:
 			if !ok {
@@ -309,18 +309,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			}
 			h.cfg.TailcfgDNSConfig.ExtraRecords = records
 
-			// TODO(kradalby): We can probably do better than sending a full update here,
-			// but for now this will ensure that all of the nodes get the new records.
-			h.nodeNotifier.NotifyAll(types.UpdateFull())
+			h.Change(types.Change{ExtraRecordsChanged: true})
 		}
 	}
 }
 
 func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
-	req interface{},
+	req any,
 	info *grpc.UnaryServerInfo,
 	handler grpc.UnaryHandler,
-) (interface{}, error) {
+) (any, error) {
 	// Check if the request is coming from the on-server client.
 	// This is not secure, but it is to maintain maintainability
 	// with the "legacy" database-based client
@@ -498,55 +496,55 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 	return router
 }
 
-// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
-// Maybe we should attempt a new in memory state and not go via the DB?
-// Maybe this should be implemented as an event bus?
-// A bool is returned indicating if a full update was sent to all nodes
-func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
-	users, err := db.ListUsers()
-	if err != nil {
-		return err
-	}
+// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+// // Maybe we should attempt a new in memory state and not go via the DB?
+// // Maybe this should be implemented as an event bus?
+// // A bool is returned indicating if a full update was sent to all nodes
+// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
+// 	users, err := db.ListUsers()
+// 	if err != nil {
+// 		return err
+// 	}
 
-	changed, err := polMan.SetUsers(users)
-	if err != nil {
-		return err
-	}
+// 	changed, err := polMan.SetUsers(users)
+// 	if err != nil {
+// 		return err
+// 	}
 
-	if changed {
-		notif.NotifyAll(types.UpdateFull())
-	}
+// 	if changed {
+// 		notif.NotifyAll(types.UpdateFull())
+// 	}
 
-	return nil
-}
+// 	return nil
+// }
 
-// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
-// Maybe we should attempt a new in memory state and not go via the DB?
-// Maybe this should be implemented as an event bus?
-// A bool is returned indicating if a full update was sent to all nodes
-func nodesChangedHook(
-	db *db.HSDatabase,
-	polMan policy.PolicyManager,
-	notif *notifier.Notifier,
-) (bool, error) {
-	nodes, err := db.ListNodes()
-	if err != nil {
-		return false, err
-	}
+// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+// // Maybe we should attempt a new in memory state and not go via the DB?
+// // Maybe this should be implemented as an event bus?
+// // A bool is returned indicating if a full update was sent to all nodes
+// func nodesChangedHook(
+// 	db *db.HSDatabase,
+// 	polMan policy.PolicyManager,
+// 	notif *notifier.Notifier,
+// ) (bool, error) {
+// 	nodes, err := db.ListNodes()
+// 	if err != nil {
+// 		return false, err
+// 	}
 
-	filterChanged, err := polMan.SetNodes(nodes)
-	if err != nil {
-		return false, err
-	}
+// 	filterChanged, err := polMan.SetNodes(nodes)
+// 	if err != nil {
+// 		return false, err
+// 	}
 
-	if filterChanged {
-		notif.NotifyAll(types.UpdateFull())
+// 	if filterChanged {
+// 		notif.NotifyAll(types.UpdateFull())
 
-		return true, nil
-	}
+// 		return true, nil
+// 	}
 
-	return false, nil
-}
+// 	return false, nil
+// }
 
 // Serve launches the HTTP and gRPC server service Headscale and the API.
 func (h *Headscale) Serve() error {
@@ -577,7 +575,6 @@ func (h *Headscale) Serve() error {
 	// Fetch an initial DERP Map before we start serving
 	h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
 	h.mapBatcher = mapper.NewBatcherAndMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
-	h.nodeNotifier = notifier.NewNotifier(h.cfg, h.mapBatcher)
 	h.mapBatcher.Start()
 	defer h.mapBatcher.Close()
 
@@ -869,7 +866,7 @@ func (h *Headscale) Serve() error {
 					log.Error().Err(err).Msg("failed to approve routes after new policy")
 				}
 
-				h.nodeNotifier.NotifyAll(types.UpdateFull())
+				h.Change(types.Change{PolicyChanged: true})
 			}
 		default:
 			info := func(msg string) { log.Info().Msg(msg) }
@@ -895,7 +892,6 @@ func (h *Headscale) Serve() error {
 			}
 
 			info("closing node notifier")
-			h.nodeNotifier.Close()
 
 			info("waiting for netmap stream to close")
 			h.pollNetMapStreamWG.Wait()
@@ -1198,3 +1194,7 @@ func (h *Headscale) autoApproveNodes() error {
 
 	return nil
 }
+
+func (h *Headscale) Change(c types.Change) {
+	h.mapBatcher.AddWork(&c)
+}
diff --git a/hscontrol/auth.go b/hscontrol/auth.go
index 1ec1fcf3..cca02344 100644
--- a/hscontrol/auth.go
+++ b/hscontrol/auth.go
@@ -90,7 +90,13 @@ func (h *Headscale) handleExistingNode(
 			return nil, fmt.Errorf("deleting ephemeral node: %w", err)
 		}
 
-		h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
+		h.Change(types.Change{NodeChange: types.NodeChange{
+			ID:          node.ID,
+			RemovedNode: true,
+
+			// TODO(kradalby): Remove when specifics are implemented
+			FullChange: true,
+		}})
 	}
 
 	expired = true
@@ -101,7 +107,13 @@ func (h *Headscale) handleExistingNode(
 			return nil, fmt.Errorf("setting node expiry: %w", err)
 		}
 
-		h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, requestExpiry), node.ID)
+		h.Change(types.Change{NodeChange: types.NodeChange{
+			ID:            node.ID,
+			ExpiryChanged: true,
+
+			// TODO(kradalby): Remove when specifics are implemented
+			FullChange: true,
+		}})
 	}
 
 	return nodeToRegisterResponse(node), nil
@@ -238,11 +250,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
 		return nil, err
 	}
 
-	updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
-	if err != nil {
-		return nil, fmt.Errorf("nodes changed hook: %w", err)
-	}
-
 	// This is a bit of a back and forth, but we have a bit of a chicken and egg
 	// dependency here.
 	// Because the way the policy manager works, we need to have the node
@@ -254,14 +261,17 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := policy.AutoApproveRoutes(h.polMan, node)
+	// TODO(kradalby): This might need to run as part of the batcher now
+	// that we don't update the node/pol here anymore.
+	_ = policy.AutoApproveRoutes(h.polMan, node)
 	if err := h.db.DB.Save(node).Error; err != nil {
 		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}
 
-	if !updateSent || routesChanged {
-		h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
-	}
+	h.Change(types.Change{NodeChange: types.NodeChange{
+		ID:      node.ID,
+		NewNode: true,
+	}})
 
 	return &tailcfg.RegisterResponse{
 		MachineAuthorized: true,
diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go
index c91687da..2fa7b8e7 100644
--- a/hscontrol/db/node.go
+++ b/hscontrol/db/node.go
@@ -352,8 +352,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 	registrationMethod string,
 	ipv4 *netip.Addr,
 	ipv6 *netip.Addr,
-) (*types.Node, bool, error) {
-	var newNode bool
+) (*types.Node, types.Change, error) {
+	var change types.Change
 	node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
 		if reg, ok := hsdb.regCache.Get(registrationID); ok {
 			if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
@@ -405,7 +405,10 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 				}
 
 				close(reg.Registered)
-				newNode = true
+				change.NodeChange = types.NodeChange{
+					ID:      node.ID,
+					NewNode: true,
+				}
 
 				return node, err
 			} else {
 				// If the node is already registered, this is a refresh.
@@ -413,6 +416,11 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 				if err != nil {
 					return nil, err
 				}
+
+				change.NodeChange = types.NodeChange{
+					ID:            node.ID,
+					ExpiryChanged: true,
+				}
 
 				return node, nil
 			}
 		}
@@ -420,7 +428,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 		return nil, ErrNodeNotFoundRegistrationCache
 	})
 
-	return node, newNode, err
+	return node, change, err
 }
 
 func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
diff --git a/hscontrol/debug.go b/hscontrol/debug.go
index ef28a955..28c40345 100644
--- a/hscontrol/debug.go
+++ b/hscontrol/debug.go
@@ -15,10 +15,6 @@
 func (h *Headscale) debugHTTPServer() *http.Server {
 	debugMux := http.NewServeMux()
 	debug := tsweb.Debugger(debugMux)
-	debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-		w.Write([]byte(h.nodeNotifier.String()))
-	}))
 	debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		config, err := json.MarshalIndent(h.cfg, "", "  ")
 		if err != nil {
diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go
index 69d9b161..007b1621 100644
--- a/hscontrol/grpcv1.go
+++ b/hscontrol/grpcv1.go
@@ -58,10 +58,10 @@ func (api headscaleV1APIServer) CreateUser(
 		return nil, err
 	}
 
-	err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
-	if err != nil {
-		return nil, fmt.Errorf("updating resources using user: %w", err)
-	}
+	api.h.Change(types.Change{UserChange: types.UserChange{
+		ID:      types.UserID(user.ID),
+		NewUser: true,
+	}})
 
 	return &v1.CreateUserResponse{User: user.Proto()}, nil
 }
@@ -102,10 +102,10 @@ func (api headscaleV1APIServer) DeleteUser(
 		return nil, err
 	}
 
-	err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
-	if err != nil {
-		return nil, fmt.Errorf("updating resources using user: %w", err)
-	}
+	api.h.Change(types.Change{UserChange: types.UserChange{
+		ID:          types.UserID(user.ID),
+		RemovedUser: true,
+	}})
 
 	return &v1.DeleteUserResponse{}, nil
 }
@@ -253,7 +253,7 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, fmt.Errorf("looking up user: %w", err)
 	}
 
-	node, _, err := api.h.db.HandleNodeFromAuthPath(
+	node, change, err := api.h.db.HandleNodeFromAuthPath(
 		registrationId,
 		types.UserID(user.ID),
 		nil,
@@ -264,11 +264,6 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, err
 	}
 
-	updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
-	if err != nil {
-		return nil, fmt.Errorf("updating resources using node: %w", err)
-	}
-
 	// This is a bit of a back and forth, but we have a bit of a chicken and egg
 	// dependency here.
 	// Because the way the policy manager works, we need to have the node
@@ -280,14 +275,14 @@ func (api headscaleV1APIServer) RegisterNode(
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := policy.AutoApproveRoutes(api.h.polMan, node)
+	// TODO(kradalby): This might need to run as part of the batcher now
+	// that we don't update the node/pol here anymore.
+	_ = policy.AutoApproveRoutes(api.h.polMan, node)
 	if err := api.h.db.DB.Save(node).Error; err != nil {
 		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}
 
-	if !updateSent || routesChanged {
-		api.h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
-	}
+	api.h.Change(change)
 
 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 }
@@ -305,7 +300,7 @@ func (api headscaleV1APIServer) GetNode(
 
 	// Populate the online field based on
 	// currently connected nodes.
-	resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
+	resp.Online = api.h.mapBatcher.IsLikelyConnected(node.ID)
 
 	return &v1.GetNodeResponse{Node: resp}, nil
 }
@@ -335,7 +330,10 @@ func (api headscaleV1APIServer) SetTags(
 		}, status.Error(codes.InvalidArgument, err.Error())
 	}
 
-	api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID:          node.ID,
+		TagsChanged: true,
+	}})
 
 	log.Trace().
 		Str("node", node.Hostname).
@@ -379,11 +377,11 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 		return nil, status.Error(codes.InvalidArgument, err.Error())
 	}
 
-	if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
-		api.h.nodeNotifier.NotifyAll(types.UpdateFull())
-	} else {
-		api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
-	}
+	api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID:            node.ID,
+		RoutesChanged: true,
+	}})
 
 	proto := node.Proto()
 	proto.SubnetRoutes = util.PrefixesToString(api.h.primaryRoutes.PrimaryRoutes(node.ID))
@@ -418,7 +416,10 @@ func (api headscaleV1APIServer) DeleteNode(
 		return nil, err
 	}
 
-	api.h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID:          node.ID,
+		RemovedNode: true,
+	}})
 
 	return &v1.DeleteNodeResponse{}, nil
 }
@@ -442,11 +443,11 @@ func (api headscaleV1APIServer) ExpireNode(
 		return nil, err
 	}
 
-	api.h.nodeNotifier.NotifyByNodeID(
-		types.UpdateSelf(node.ID),
-		node.ID)
-
-	api.h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, now), node.ID)
+	// TODO(kradalby): Ensure that both the self-update and peer updates are sent
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID:            node.ID,
+		ExpiryChanged: true,
+	}})
 
 	log.Trace().
 		Str("node", node.Hostname).
@@ -476,7 +477,12 @@ func (api headscaleV1APIServer) RenameNode(
 		return nil, err
 	}
 
-	api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID: node.ID,
+		// TODO(kradalby): Not sure if this is what we want to send, probably
+		// we can do a delta change here.
+		NewNode: true,
+	}})
 
 	log.Trace().
 		Str("node", node.Hostname).
@@ -495,7 +501,7 @@ func (api headscaleV1APIServer) ListNodes(
 	// probably be done once.
 	// TODO(kradalby): This should be done in one tx.
 
-	isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
+	isLikelyConnected := api.h.mapBatcher.LikelyConnectedMap()
 	if request.GetUser() != "" {
 		user, err := api.h.db.GetUserByName(request.GetUser())
 		if err != nil {
@@ -572,10 +578,13 @@ func (api headscaleV1APIServer) MoveNode(
 		return nil, err
 	}
 
-	api.h.nodeNotifier.NotifyByNodeID(
-		types.UpdateSelf(node.ID),
-		node.ID)
-	api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
+	// TODO(kradalby): Ensure that both the self-update and peer updates are sent
+	api.h.Change(types.Change{NodeChange: types.NodeChange{
+		ID: node.ID,
+		// TODO(kradalby): Not sure if this is what we want to send, probably
+		// we can do a delta change here.
+		NewNode: true,
+	}})
 
 	return &v1.MoveNodeResponse{Node: node.Proto()}, nil
 }
@@ -758,7 +767,7 @@ func (api headscaleV1APIServer) SetPolicy(
 			return nil, err
 		}
 
-		api.h.nodeNotifier.NotifyAll(types.UpdateFull())
+		api.h.Change(types.Change{PolicyChanged: true})
 	}
 
 	response := &v1.SetPolicyResponse{
diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go
index 0ce45069..a8254433 100644
--- a/hscontrol/mapper/batcher.go
+++ b/hscontrol/mapper/batcher.go
@@ -28,17 +28,6 @@ func init() {
 	}
 }
 
-type ChangeWork struct {
-	NodeID *types.NodeID
-	Update types.StateUpdate
-}
-
-type nodeConn struct {
-	c        chan<- []byte
-	compress string
-	version  tailcfg.CapabilityVersion
-}
-
 type Batcher struct {
 	mu deadlock.RWMutex
 
@@ -58,7 +47,7 @@ type Batcher struct {
 	// this should serve for the experiment.
 	cancelCh chan struct{}
 
-	workCh chan *ChangeWork
+	workCh chan *types.Change
 }
 
 func NewBatcherAndMapper(
@@ -80,7 +69,7 @@ func NewBatcher(mapper *mapper) *Batcher {
 		mapper:   mapper,
 		cancelCh: make(chan struct{}),
 		// TODO: No limit for now, this needs to be changed
-		workCh: make(chan *ChangeWork, (1<<16)-1),
+		workCh: make(chan *types.Change, (1<<16)-1),
 
 		nodes:     make(map[types.NodeID]nodeConn),
 		connected: make(map[types.NodeID]*time.Time),
@@ -110,37 +99,46 @@ func (b *Batcher) AddNode(id types.NodeID, c chan<- []byte, compress string, ver
 	}
 
 	b.nodes[id] = nodeConn{
+		id:       id,
 		c:        c,
 		compress: compress,
 		version:  version,
+
+		// TODO(kradalby): Not sure about this one yet.
+		mapper: b.mapper,
 	}
 	b.connected[id] = nil // nil means connected
 
-	b.AddWork(&ChangeWork{
-		NodeID: &id,
-		Update: types.UpdateFull(),
-	})
+	// TODO(kradalby): Handle:
+	// - Updating peers with online status
+	// - Updating routes in routemanager and peers
+	b.AddWork(&types.Change{NodeChange: types.NodeChange{
+		ID:     id,
+		Online: true,
+	}})
 }
 
-func (b *Batcher) RemoveNode(id types.NodeID, c chan<- []byte) bool {
+func (b *Batcher) RemoveNode(id types.NodeID, c chan<- []byte) {
 	b.mu.Lock()
 	defer b.mu.Unlock()
 
 	if curr, ok := b.nodes[id]; ok {
 		if curr.c != c {
-			return false
+			return
 		}
 	}
 
 	delete(b.nodes, id)
 	b.connected[id] = ptr.To(time.Now())
 
-	return true
+	// TODO(kradalby): Handle:
+	// - Updating peers with lastseen status, and only if not replaced
+	// - Updating routes in routemanager and peers
 }
 
-func (b *Batcher) AddWork(work *ChangeWork) {
-	log.Trace().Msgf("adding work: %v", work.Update)
-	b.workCh <- work
+func (b *Batcher) AddWork(change *types.Change) {
+	log.Trace().Msgf("adding work: %v", change)
+	b.workCh <- change
 }
 
 func (b *Batcher) IsConnected(id types.NodeID) bool {
@@ -194,12 +192,12 @@ func (b *Batcher) doWork() {
 		case <-b.cancelCh:
 			return
 		case work := <-b.workCh:
-			b.processWork(work)
+			b.processChange(work)
 		}
 	}
 }
 
-// processWork is the current bottleneck where all the updates get picked up
+// processChange is the current bottleneck where all the updates get picked up
 // one by one and processed. This will have to change, it needs to go as fast as
 // possible and just pass it on to the nodes. Currently it wont block because the
 // work channel is super large, but it might not be able to keep up.
@@ -207,90 +205,125 @@
 // mean a lot of goroutines, hanging around.
 // Another is just a worker pool that picks up work and processes it,
 // and passes it on to the nodes. That might be complicated with order?
-func (b *Batcher) processWork(work *ChangeWork) {
+func (b *Batcher) processChange(c *types.Change) {
 	b.mu.RLock()
 	defer b.mu.RUnlock()
 
-	log.Trace().Msgf("processing work: %v", work)
-
-	if work.NodeID != nil {
-		id := *work.NodeID
-		node, ok := b.nodes[id]
-		if !ok {
-			log.Trace().Msgf("node %d not found in batcher, skipping work: %v", id, work.Update)
-			return
-		}
-		resp, err := b.resp(id, &node, work)
-		if err != nil {
-			log.Debug().Msgf("creating mapResp for %d: %s", id, err)
-		}
-
-		node.c <- resp
-		return
-	}
+	log.Trace().Msgf("processing work: %v", c)
 
 	for id, node := range b.nodes {
-		resp, err := b.resp(id, &node, work)
-		if err != nil {
-			log.Debug().Msgf("creating mapResp for %d: %s", id, err)
-		}
-
-		node.c <- resp
+		err := node.change(c)
+		log.Error().Err(err).Uint64("node.id", id.Uint64()).Msgf("processing work for node %d", id)
 	}
 }
 
+type nodeConn struct {
+	id       types.NodeID
+	c        chan<- []byte
+	compress string
+	version  tailcfg.CapabilityVersion
+	mapper   *mapper
+}
+
+type changeUpdate int
+
+const (
+	_ changeUpdate = iota
+	ignoreUpdate
+	partialUpdate
+	fullUpdate
+)
+
+func determineChange(c *types.Change) changeUpdate {
+	if c == nil {
+		return ignoreUpdate
+	}
+
+	if c.FullUpdate() {
+		return fullUpdate
+	}
+
+	return fullUpdate
+}
+
+func (nc *nodeConn) change(c *types.Change) error {
+	switch determineChange(c) {
+	case partialUpdate:
+		return nc.partialUpdate(c)
+	case fullUpdate:
+		return nc.fullUpdate()
+	default:
+		log.Trace().Msgf("ignoring change: %v", c)
+		return nil
+	}
+}
+
+func (nc *nodeConn) partialUpdate(c *types.Change) error {
+	return nil
+}
+
+func (nc *nodeConn) fullUpdate() error {
+	data, err := nc.mapper.fullMapResponse(nc.id, nc.version, nc.compress)
+	if err != nil {
+		return err
+	}
+
+	nc.c <- data
+	return nil
+}
+
 // resp is the logic that used to reside in the poller, but is now moved
 // to process before sending to the node. The idea is that we do not want to
 // be blocked on the send channel to the individual node, but rather
 // process all the work and then send the responses to the nodes.
 // TODO(kradalby): This is a temporary solution, as we explore this
 // approach, we will likely need to refactor this further.
-func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte, error) {
-	var data []byte
-	var err error
+// func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte, error) {
+// 	var data []byte
+// 	var err error
 
-	// TODO(kradalby): This should not be necessary, mapper only
-	// use compress and version, and this can either be moved out
-	// or passed directly. The mapreq isnt needed.
-	req := tailcfg.MapRequest{
-		Compress: nc.compress,
-		Version:  nc.version,
-	}
+// 	// TODO(kradalby): This should not be necessary, mapper only
+// 	// use compress and version, and this can either be moved out
+// 	// or passed directly. The mapreq isnt needed.
+// 	req := tailcfg.MapRequest{
+// 		Compress: nc.compress,
+// 		Version:  nc.version,
+// 	}
 
-	// TODO(kradalby): We dont want to use the db here. We should
-	// just have the node available, or at least quickly accessible
-	// from the new fancy mem state we want.
-	node, err := b.mapper.db.GetNodeByID(id)
-	if err != nil {
-		return nil, err
-	}
+// 	// TODO(kradalby): We dont want to use the db here. We should
+// 	// just have the node available, or at least quickly accessible
+// 	// from the new fancy mem state we want.
+// 	node, err := b.mapper.db.GetNodeByID(id)
+// 	if err != nil {
+// 		return nil, err
+// 	}
 
-	switch work.Update.Type {
-	case types.StateFullUpdate:
-		data, err = b.mapper.fullMapResponse(req, node)
-	case types.StatePeerChanged:
-		changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
+// 	switch work.Update.Type {
+// 	case types.StateFullUpdate:
+// 		data, err = b.mapper.fullMapResponse(req, node)
+// 	case types.StatePeerChanged:
+// 		changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
 
-		for _, nodeID := range work.Update.ChangeNodes {
-			changed[nodeID] = true
-		}
+// 		for _, nodeID := range work.Update.ChangeNodes {
+// 			changed[nodeID] = true
+// 		}
 
-		data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
+// 		data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
 
-	case types.StatePeerChangedPatch:
-		data, err = b.mapper.peerChangedPatchResponse(req, node, work.Update.ChangePatches)
-	case types.StatePeerRemoved:
-		changed := make(map[types.NodeID]bool, len(work.Update.Removed))
+// 	case types.StatePeerChangedPatch:
+// 		data, err = b.mapper.peerChangedPatchResponse(req, node, work.Update.ChangePatches)
+// 	case types.StatePeerRemoved:
+// 		changed := make(map[types.NodeID]bool, len(work.Update.Removed))
 
-		for _, nodeID := range work.Update.Removed {
-			changed[nodeID] = false
-		}
-		data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
-	case types.StateSelfUpdate:
-		data, err = b.mapper.peerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
-	case types.StateDERPUpdated:
-		data, err = b.mapper.derpMapResponse(req, node, work.Update.DERPMap)
-	}
+// 		for _, nodeID := range work.Update.Removed {
+// 			changed[nodeID] = false
+// 		}
+// 		data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
+// 	case types.StateSelfUpdate:
+// 		data, err = b.mapper.peerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
+// 	case types.StateDERPUpdated:
+// 		data, err = b.mapper.derpMapResponse(req, node, work.Update.DERPMap)
+// 	}
 
-	return data, err
-}
+// 	return data, err
+// }
diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go
index 1930076c..87a4c2d4 100644
--- a/hscontrol/mapper/mapper.go
+++ b/hscontrol/mapper/mapper.go
@@ -151,16 +151,22 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
 
 // fullMapResponse returns a MapResponse for the given node.
 func (m *mapper) fullMapResponse(
-	mapRequest tailcfg.MapRequest,
-	node *types.Node,
+	nodeID types.NodeID,
+	capVer tailcfg.CapabilityVersion,
+	compress string,
 	messages ...string,
 ) ([]byte, error) {
-	peers, err := m.listPeers(node.ID)
+	node, err := m.db.GetNodeByID(nodeID)
 	if err != nil {
 		return nil, err
 	}
 
-	resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
+	peers, err := m.listPeers(nodeID)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := m.baseWithConfigMapResponse(node, capVer)
 	if err != nil {
 		return nil, err
 	}
@@ -171,7 +177,7 @@ func (m *mapper) fullMapResponse(
 		m.polMan,
 		m.primary,
 		node,
-		mapRequest.Version,
+		capVer,
 		peers,
 		m.cfg,
 	)
@@ -179,7 +185,7 @@ func (m *mapper) fullMapResponse(
 		return nil, err
 	}
 
-	return marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
+	return marshalMapResponse(resp, node, compress, messages...)
 }
 
 func (m *mapper) derpMapResponse(
@@ -192,7 +198,7 @@ func (m *mapper) derpMapResponse(
 	resp := m.baseMapResponse()
 	resp.DERPMap = derpMap
 
-	return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
+	return marshalMapResponse(&resp, node, mapRequest.Compress)
 }
 
 func (m *mapper) peerChangedResponse(
@@ -269,7 +275,7 @@ func (m *mapper) peerChangedResponse(
 	}
 	resp.Node = tailnode
 
-	return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
+	return marshalMapResponse(&resp, node, mapRequest.Compress, messages...)
 }
 
 // peerChangedPatchResponse creates a patch MapResponse with
@@ -282,11 +288,10 @@ func (m *mapper) peerChangedPatchResponse(
 	resp := m.baseMapResponse()
 	resp.PeersChangedPatch = changed
 
-	return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
+	return marshalMapResponse(&resp, node, mapRequest.Compress)
 }
 
 func marshalMapResponse(
-	mapRequest tailcfg.MapRequest,
 	resp *tailcfg.MapResponse,
 	node *types.Node,
 	compression string,
@@ -300,7 +305,6 @@ func marshalMapResponse(
 	if debugDumpMapResponsePath != "" {
 		data := map[string]any{
 			"Messages":    messages,
-			"MapRequest":  mapRequest,
 			"MapResponse": resp,
 		}
 
diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go
index a23d4b35..38705f2b 100644
--- a/hscontrol/notifier/notifier.go
+++ b/hscontrol/notifier/notifier.go
@@ -97,36 +97,36 @@ func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
 	return n.mbatcher.LikelyConnectedMap()
 }
 
-func (n *Notifier) NotifyAll(update types.StateUpdate) {
-	n.NotifyWithIgnore(update)
-}
+// func (n *Notifier) NotifyAll(update types.StateUpdate) {
+// 	n.NotifyWithIgnore(update)
+// }
 
-func (n *Notifier) NotifyWithIgnore(
-	update types.StateUpdate,
-	ignoreNodeIDs ...types.NodeID,
-) {
-	if n.closed {
-		return
-	}
+// func (n *Notifier) NotifyWithIgnore(
+// 	update types.StateUpdate,
+// 	ignoreNodeIDs ...types.NodeID,
+// ) {
+// 	if n.closed {
+// 		return
+// 	}
 
-	n.b.addOrPassthrough(update)
-}
+// 	n.b.addOrPassthrough(update)
+// }
 
-func (n *Notifier) NotifyByNodeID(
-	update types.StateUpdate,
-	nodeID types.NodeID,
-) {
-	n.mbatcher.AddWork(&mapper.ChangeWork{
-		NodeID: &nodeID,
-		Update: update,
-	})
-}
+// func (n *Notifier) NotifyByNodeID(
+// 	update types.StateUpdate,
+// 	nodeID types.NodeID,
+// ) {
+// 	n.mbatcher.AddWork(&mapper.ChangeWork{
+// 		NodeID: &nodeID,
+// 		Update: update,
+// 	})
+// }
 
-func (n *Notifier) sendAll(update types.StateUpdate) {
-	n.mbatcher.AddWork(&mapper.ChangeWork{
-		Update: update,
-	})
-}
+// func (n *Notifier) sendAll(update types.StateUpdate) {
+// 	n.mbatcher.AddWork(&mapper.ChangeWork{
+// 		Update: update,
+// 	})
+// }
 
 func (n *Notifier) String() string {
 	notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
@@ -197,7 +197,7 @@ func (b *batcher) addOrPassthrough(update types.StateUpdate) {
 		notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
 
 	default:
-		b.n.sendAll(update)
+		// b.n.sendAll(update)
 	}
 }
 
@@ -225,15 +225,15 @@ func (b *batcher) flush() {
 		slices.Sort(changedNodes)
 
 		if b.changedNodeIDs.Slice().Len() > 0 {
-			update := types.UpdatePeerChanged(changedNodes...)
+			// update := types.UpdatePeerChanged(changedNodes...)
 
-			b.n.sendAll(update)
+			// b.n.sendAll(update)
 		}
 
 		if len(patches) > 0 {
-			patchUpdate := types.UpdatePeerPatch(patches...)
+			// patchUpdate := types.UpdatePeerPatch(patches...)
 
-			b.n.sendAll(patchUpdate)
+			// b.n.sendAll(patchUpdate)
 		}
 
 		b.changedNodeIDs = set.Slice[types.NodeID]{}
diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go
index 44ca3ed6..ead40e9e 100644
--- a/hscontrol/oidc.go
+++ b/hscontrol/oidc.go
@@ -16,7 +16,6 @@ import (
 	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/gorilla/mux"
 	"github.com/juanfont/headscale/hscontrol/db"
-	"github.com/juanfont/headscale/hscontrol/notifier"
 	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
@@ -54,13 +53,10 @@ type RegistrationInfo struct {
 }
 
 type AuthProviderOIDC struct {
+	h                 *Headscale
 	serverURL         string
 	cfg               *types.OIDCConfig
-	db                *db.HSDatabase
 	registrationCache *zcache.Cache[string, RegistrationInfo]
-	notifier          *notifier.Notifier
-	ipAlloc           *db.IPAllocator
-	polMan            policy.PolicyManager
 
 	oidcProvider *oidc.Provider
 	oauth2Config *oauth2.Config
@@ -68,12 +64,9 @@ type AuthProviderOIDC struct {
 
 func NewAuthProviderOIDC(
 	ctx context.Context,
+	h *Headscale,
 	serverURL string,
 	cfg *types.OIDCConfig,
-	db *db.HSDatabase,
-	notif *notifier.Notifier,
-	ipAlloc *db.IPAllocator,
-	polMan policy.PolicyManager,
 ) (*AuthProviderOIDC, error) {
 	var err error
 	// grab oidc config if it hasn't been already
@@ -99,13 +92,10 @@ func NewAuthProviderOIDC(
 	)
 
 	return &AuthProviderOIDC{
+		h:                 h,
 		serverURL:         serverURL,
 		cfg:               cfg,
-		db:                db,
 		registrationCache: registrationCache,
-		notifier:          notif,
-		ipAlloc:           ipAlloc,
-		polMan:            polMan,
 
 		oidcProvider: oidcProvider,
 		oauth2Config: oauth2Config,
@@ -475,26 +465,29 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 ) (*types.User, error) {
 	var user *types.User
 	var err error
-	user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
+	user, err = a.h.db.GetUserByOIDCIdentifier(claims.Identifier())
 	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
 		return nil, fmt.Errorf("creating or updating user: %w", err)
 	}
 
 	// if the user is still not found, create a new empty user.
+	// TODO(kradalby): This might cause us to not have an ID below, which
+	// is a problem.
 	if user == nil {
 		user = &types.User{}
 	}
 
 	user.FromClaim(claims)
-	err = a.db.DB.Save(user).Error
+	err = a.h.db.DB.Save(user).Error
 	if err != nil {
 		return nil, fmt.Errorf("creating or updating user: %w", err)
 	}
 
-	err = usersChangedHook(a.db, a.polMan, a.notifier)
-	if err != nil {
-		return nil, fmt.Errorf("updating resources using user: %w", err)
-	}
+	a.h.Change(types.Change{UserChange: types.UserChange{
+		ID: types.UserID(user.ID),
+		// TODO(kradalby): Not sure about this one yet.
+		NewUser: true,
+	}})
 
 	return user, nil
 }
@@ -504,12 +497,12 @@ func (a *AuthProviderOIDC) handleRegistration(
 	registrationID types.RegistrationID,
 	expiry time.Time,
 ) (bool, error) {
-	ipv4, ipv6, err := a.ipAlloc.Next()
+	ipv4, ipv6, err := a.h.ipAlloc.Next()
 	if err != nil {
 		return false, err
 	}
 
-	node, newNode, err := a.db.HandleNodeFromAuthPath(
+	node, change, err := a.h.db.HandleNodeFromAuthPath(
 		registrationID,
 		types.UserID(user.ID),
 		&expiry,
@@ -520,14 +513,6 @@ func (a *AuthProviderOIDC) handleRegistration(
 		return false, fmt.Errorf("could not register node: %w", err)
 	}
 
-	// Send an update to all nodes if this is a new node that they need to know
-	// about.
-	// If this is a refresh, just send new expiry updates.
-	updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier)
-	if err != nil {
-		return false, fmt.Errorf("updating resources using node: %w", err)
-	}
-
 	// This is a bit of a back and forth, but we have a bit of a chicken and egg
 	// dependency here.
 	// Because the way the policy manager works, we need to have the node
@@ -539,21 +524,16 @@ func (a *AuthProviderOIDC) handleRegistration(
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := policy.AutoApproveRoutes(a.polMan, node)
-	if err := a.db.DB.Save(node).Error; err != nil {
+	// TODO(kradalby): This might need to run as part of the batcher now
+	// that we don't update the node/pol here anymore.
+	_ = policy.AutoApproveRoutes(a.h.polMan, node)
+	if err := a.h.db.DB.Save(node).Error; err != nil {
 		return false, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}
 
-	if !updateSent || routesChanged {
-		a.notifier.NotifyByNodeID(
-			types.UpdateSelf(node.ID),
-			node.ID,
-		)
+	a.h.Change(change)
 
-		a.notifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
-	}
-
-	return newNode, nil
+	return change.NodeChange.NewNode, nil
 }
 
 // TODO(kradalby):
diff --git a/hscontrol/poll.go b/hscontrol/poll.go
index 55a8c2ee..1d95f116 100644
--- a/hscontrol/poll.go
+++ b/hscontrol/poll.go
@@ -163,25 +163,7 @@ func (m *mapSession) serveLongPoll() {
 		close(m.cancelCh)
 		m.cancelChMu.Unlock()
 
-		// only update node status if the node channel was removed.
-		// in principal, it will be removed, but the client rapidly
-		// reconnects, the channel might be of another connection.
-		// In that case, it is not closed and the node is still online.
-		if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
-			// TODO(kradalby): All of this handling should be moved out of here
-			// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
-
-			// Failover the node's routes if any.
-			m.h.updateNodeOnlineStatus(false, m.node)
-
-			// When a node disconnects, and it causes the primary route map to change,
-			// send a full update to all nodes.
-			// TODO(kradalby): This can likely be made more effective, but likely most
-			// nodes has access to the same routes, so it might not be a big deal.
-			if m.h.primaryRoutes.SetRoutes(m.node.ID) {
-				m.h.nodeNotifier.NotifyAll(types.UpdateFull())
-			}
-		}
+		m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
 
 		m.afterServeLongPoll()
 		m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
@@ -191,12 +173,6 @@ func (m *mapSession) serveLongPoll() {
 	m.h.pollNetMapStreamWG.Add(1)
 	defer m.h.pollNetMapStreamWG.Done()
 
-	// TODO(kradalby): All of this handling should be moved out of here
-	// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
-	if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
-		m.h.nodeNotifier.NotifyAll(types.UpdateFull())
-	}
-
 	// TODO(kradalby): I think this didnt really work and can be reverted back to a normal write thing.
 	// Upgrade the writer to a ResponseController
 	rc := http.NewResponseController(m.w)
@@ -212,10 +188,6 @@ func (m *mapSession) serveLongPoll() {
 
 	m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.req.Compress, m.req.Version)
 
-	// TODO(kradalby): All of this handling should be moved out of here
-	// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
-	go m.h.updateNodeOnlineStatus(true, m.node)
-
 	m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
 
 	// Loop through updates and continuously send them to the
@@ -298,38 +270,11 @@ var keepAliveZstd = (func() []byte {
 	return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression)
 })()
 
-// updateNodeOnlineStatus records the last seen status of a node and notifies peers
-// about change in their online/offline status.
-// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
-func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
-	change := &tailcfg.PeerChange{
-		NodeID: tailcfg.NodeID(node.ID),
-		Online: &online,
-	}
-
-	if !online {
-		now := time.Now()
-
-		// lastSeen is only relevant if the node is disconnected.
-		node.LastSeen = &now
-		change.LastSeen = &now
-	}
-
-	if node.LastSeen != nil {
-		h.db.SetLastSeen(node.ID, *node.LastSeen)
-	}
-
-	h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerPatch(change), node.ID)
-}
-
 func (m *mapSession) handleEndpointUpdate() {
 	m.tracef("received endpoint update")
 
 	change := m.node.PeerChangeFromMapRequest(m.req)
-	online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID)
-	change.Online = &online
-
 	m.node.ApplyPeerChange(&change)
 
 	sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo)
@@ -355,6 +300,11 @@ func (m *mapSession) handleEndpointUpdate() {
 		return
 	}
 
+	c := types.Change{NodeChange: types.NodeChange{
+		ID:              m.node.ID,
+		HostinfoChanged: true,
+	}}
+
 	// Check if the Hostinfo of the node has changed.
 	// If it has changed, check if there has been a change to
 	// the routable IPs of the host and update them in
@@ -364,9 +314,13 @@ func (m *mapSession) handleEndpointUpdate() {
 
 	// If the hostinfo has changed, but not the routes, just update
 	// hostinfo and let the function continue.
 	if routesChanged {
-		// TODO(kradalby): I am not sure if we need this?
-		nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
+		c.NodeChange.RoutesChanged = true
+		// TODO(kradalby): I am not sure if we will ultimately move this
+		// to a more central part as part of this effort.
+		// Do we want to make a thing where when you "save" a node, all the
+		// changes are calculated based on that and updating the right subsystems?
+		//
 		// Approve any route that has been defined in policy as
 		// auto approved. Any change here is not important as any
 		// actual state change will be detected when the route manager
@@ -375,19 +329,7 @@ func (m *mapSession) handleEndpointUpdate() {

 		// Update the routes of the given node in the route manager to
 		// see if an update needs to be sent.
-		if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
-			m.h.nodeNotifier.NotifyAll(types.UpdateFull())
-		} else {
-			m.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(m.node.ID), m.node.ID)
-
-			// TODO(kradalby): I am not sure if we need this?
-			// Send an update to the node itself with to ensure it
-			// has an updated packetfilter allowing the new route
-			// if it is defined in the ACL.
-			m.h.nodeNotifier.NotifyByNodeID(
-				types.UpdateSelf(m.node.ID),
-				m.node.ID)
-		}
+		m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...)
 	}
 
 	// Check if there has been a change to Hostname and update them
@@ -404,10 +346,7 @@ func (m *mapSession) handleEndpointUpdate() {
 		return
 	}
 
-	m.h.nodeNotifier.NotifyWithIgnore(
-		types.UpdatePeerChanged(m.node.ID),
-		m.node.ID,
-	)
+	m.h.Change(c)
 
 	m.w.WriteHeader(http.StatusOK)
 	mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
diff --git a/hscontrol/types/change.go b/hscontrol/types/change.go
index b39d2ef2..8ce90554 100644
--- a/hscontrol/types/change.go
+++ b/hscontrol/types/change.go
@@ -1,21 +1,95 @@
 package types
 
 type Change struct {
-	NodeChange NodeChange
-	UserChange UserChange
+	NodeChange    NodeChange
+	UserChange    UserChange
+	DERPChanged   bool
+	PolicyChanged bool
+
+	// TODO(kradalby): We can probably do better than sending a full update here,
+	// but for now this will ensure that all of the nodes get the new records.
+	ExtraRecordsChanged bool
 }
 
-type NodeChangeWhat string
+func (c *Change) FullUpdate() bool {
+	if !c.NodeChange.FullUpdate() && !c.UserChange.FullUpdate() {
+		return false
+	}
 
-const (
-	NodeChangeCameOnline NodeChangeWhat = "node-online"
-)
+	return true
+}
+
+// type NodeChangeWhat string
+
+// const (
+// 	NodeChangeOnline  NodeChangeWhat = "node-online"
+// 	NodeChangeOffline NodeChangeWhat = "node-offline"
+// 	NodeChangeAdded   NodeChangeWhat = "node-added"
+// 	NodeChangeRemoved NodeChangeWhat = "node-removed"
+// )
 
 type NodeChange struct {
-	ID   NodeID
-	What NodeChangeWhat
+	ID NodeID
+	// What NodeChangeWhat
+
+	// TODO(kradalby): FullChange is a bit of a
+	FullChange bool
+
+	ExpiryChanged bool
+	RoutesChanged bool
+
+	Online  bool
+	Offline bool
+
+	// TODO: This could maybe be more granular
+	HostinfoChanged bool
+
+	// Registration and auth related changes
+	NewNode     bool
+	RemovedNode bool
+	KeyChanged  bool
+	TagsChanged bool
+}
+
+func (c *NodeChange) RegistrationChanged() bool {
+	return c.NewNode || c.KeyChanged || c.TagsChanged
+}
+
+func (c *NodeChange) OnlyKeyChange() bool {
+	return c.ID != 0 && !c.NewNode && !c.TagsChanged && c.KeyChanged
+}
+
+func (c *NodeChange) FullUpdate() bool {
+	if c.ID != 0 {
+		if c.RegistrationChanged() {
+			return true
+		}
+
+		if c.FullChange {
+			return true
+		}
+
+		return false
+	}
+
+	return false
 }
 
 type UserChange struct {
 	ID UserID
+
+	NewUser     bool
+	RemovedUser bool
+}
+
+func (c *UserChange) FullUpdate() bool {
+	if c.ID != 0 {
+		return true
+	}
+
+	if c.NewUser || c.RemovedUser {
+		return true
+	}
+
+	return false
 }
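
A minimal, self-contained sketch of the flow this patch introduces: every mutation is funneled into a single types.Change handed to Headscale.Change, which forwards it to the batcher via mapBatcher.AddWork, and the batcher asks Change.FullUpdate() whether a node needs a full map response. The snippet below copies trimmed versions of the types added in hscontrol/types/change.go so it compiles on its own; it is illustrative only and not part of the patch.

package main

import "fmt"

// Trimmed copies of the types introduced in hscontrol/types/change.go,
// reproduced here only so this example compiles on its own.
type (
	NodeID uint64
	UserID uint64
)

type NodeChange struct {
	ID         NodeID
	FullChange bool

	ExpiryChanged bool
	RoutesChanged bool

	NewNode     bool
	KeyChanged  bool
	TagsChanged bool
}

func (c *NodeChange) RegistrationChanged() bool {
	return c.NewNode || c.KeyChanged || c.TagsChanged
}

func (c *NodeChange) FullUpdate() bool {
	// Same outcome as the longer if/else chain in the patch.
	return c.ID != 0 && (c.RegistrationChanged() || c.FullChange)
}

type UserChange struct {
	ID                   UserID
	NewUser, RemovedUser bool
}

func (c *UserChange) FullUpdate() bool {
	return c.ID != 0 || c.NewUser || c.RemovedUser
}

type Change struct {
	NodeChange          NodeChange
	UserChange          UserChange
	DERPChanged         bool
	PolicyChanged       bool
	ExtraRecordsChanged bool
}

func (c *Change) FullUpdate() bool {
	return c.NodeChange.FullUpdate() || c.UserChange.FullUpdate()
}

func main() {
	// A freshly registered node: RegistrationChanged, so a full update.
	reg := Change{NodeChange: NodeChange{ID: 1, NewNode: true}}
	// An expiry-only change: not classified as a full update by FullUpdate,
	// even though determineChange in batcher.go still falls back to one.
	exp := Change{NodeChange: NodeChange{ID: 1, ExpiryChanged: true}}

	fmt.Println(reg.FullUpdate(), exp.FullUpdate()) // true false
}

Running it prints "true false": a registration is classified as a full update, while an expiry-only change is not, although determineChange in batcher.go currently sends a full map response for every change regardless.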