From 8d5b04f3d3ec6694c3f6a21c70ff966f174dc1f3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:53:04 -0500 Subject: [PATCH] hook up user and node changes to policy Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ hscontrol/auth.go | 7 +++++++ hscontrol/grpcv1.go | 15 +++++++++++++++ hscontrol/oidc.go | 14 ++++++++++++++ 4 files changed, 80 insertions(+) diff --git a/hscontrol/app.go b/hscontrol/app.go index 3489d18f..b4d36caa 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -165,6 +165,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -472,6 +473,48 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if profilingEnabled { @@ -770,6 +813,7 @@ func (h *Headscale) Serve() error { Msg("Received SIGHUP, reloading ACL and Config") // TODO(kradalby): Reload config on SIGHUP + // TODO(kradalby): Only update if we set a new policy if err := h.loadACLPolicy(); err != nil { log.Error().Err(err).Msg("failed to reload ACL policy") } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 67545031..2b23aad3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a221d519..51134e7e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -57,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -86,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -220,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84267b41..5028e244 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -53,6 +54,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -65,6 +67,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -96,6 +99,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode( return fmt.Errorf("could not register node: %w", err) } + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return fmt.Errorf("updating resources using node: %w", err) + } + return nil }