From 5e92ddad4363815c0b29bd5933ec274db55bbb29 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 28 Feb 2022 22:42:30 +0000 Subject: [PATCH] Remove redundant caches This commit removes the two extra caches (oidc, requested time) and uses the new central registration cache instead. The requested time is unified into the main machine object and the oidc key is just added to the same cache, as a string with the state as a key instead of machine key. --- api.go | 25 ++++++++++--------------- app.go | 33 ++++++++++++--------------------- app_test.go | 5 ----- grpcv1.go | 17 +---------------- machine.go | 2 -- oidc.go | 29 ++++------------------------- 6 files changed, 27 insertions(+), 84 deletions(-) diff --git a/api.go b/api.go index f831cacf..9b8fafd2 100644 --- a/api.go +++ b/api.go @@ -140,17 +140,25 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // We create the machine and then keep it around until a callback // happens newMachine := Machine{ - Expiry: &time.Time{}, MachineKey: machineKeyStr, Name: req.Hostinfo.Hostname, NodeKey: NodePublicKeyStripPrefix(req.NodeKey), LastSeen: &now, } + if !req.Expiry.IsZero() { + log.Trace(). + Caller(). + Str("machine", req.Hostinfo.Hostname). + Time("expiry", req.Expiry). + Msg("Non-zero expiry time requested") + newMachine.Expiry = &req.Expiry + } + h.registrationCache.Set( machineKeyStr, newMachine, - requestedExpiryCacheExpiration, + registerCacheExpiration, ) h.handleMachineRegistrationNew(ctx, machineKey, req) @@ -490,19 +498,6 @@ func (h *Headscale) handleMachineRegistrationNew( strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) } - if !registerRequest.Expiry.IsZero() { - log.Trace(). - Caller(). - Str("machine", registerRequest.Hostinfo.Hostname). - Time("expiry", registerRequest.Expiry). - Msg("Non-zero expiry time requested, adding to cache") - h.requestedExpiryCache.Set( - machineKey.String(), - registerRequest.Expiry, - requestedExpiryCacheExpiration, - ) - } - respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 63ca10e4..f7a6e636 100644 --- a/app.go +++ b/app.go @@ -55,8 +55,8 @@ const ( HTTPReadTimeout = 30 * time.Second privateKeyFileMode = 0o600 - requestedExpiryCacheExpiration = time.Minute * 5 - requestedExpiryCacheCleanupInterval = time.Minute * 10 + registerCacheExpiration = time.Minute * 15 + registerCacheCleanup = time.Minute * 20 errUnsupportedDatabase = Error("unsupported DB") errUnsupportedLetsEncryptChallengeType = Error( @@ -148,11 +148,8 @@ type Headscale struct { lastStateChange sync.Map - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - oidcStateCache *cache.Cache - - requestedExpiryCache *cache.Cache + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config registrationCache *cache.Cache @@ -204,25 +201,19 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, errUnsupportedDatabase } - requestedExpiryCache := cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, - ) - registrationCache := cache.New( // TODO(kradalby): Add unified cache expiry config options - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, + registerCacheExpiration, + registerCacheCleanup, ) app := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - aclRules: tailcfg.FilterAllowAll, // default allowall - requestedExpiryCache: requestedExpiryCache, - registrationCache: registrationCache, + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + aclRules: tailcfg.FilterAllowAll, // default allowall + registrationCache: registrationCache, } err = app.initDB() diff --git a/app_test.go b/app_test.go index 53c703a6..96036a1d 100644 --- a/app_test.go +++ b/app_test.go @@ -5,7 +5,6 @@ import ( "os" "testing" - "github.com/patrickmn/go-cache" "gopkg.in/check.v1" "inet.af/netaddr" ) @@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) { cfg: cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", - requestedExpiryCache: cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, - ), } err = app.initDB() if err != nil { diff --git a/grpcv1.go b/grpcv1.go index 75732b1d..1c022905 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -160,25 +160,10 @@ func (api headscaleV1APIServer) RegisterMachine( Str("machine_key", request.GetKey()). Msg("Registering machine") - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - // This means that if a user is to slow with register a machine, it will possibly not - // have the correct expiry. - requestedTime := time.Time{} - if requestedTimeIf, found := api.h.requestedExpiryCache.Get(request.GetKey()); found { - log.Trace(). - Caller(). - Str("machine", request.Key). - Msg("Expiry time found in cache, assigning to node") - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime - } - } - machine, err := api.h.RegisterMachineFromAuthCallback( request.GetKey(), request.GetNamespace(), RegisterMethodCLI, - &requestedTime, ) if err != nil { return nil, err @@ -418,7 +403,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( api.h.registrationCache.Set( request.GetKey(), newMachine, - requestedExpiryCacheExpiration, + registerCacheExpiration, ) return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil diff --git a/machine.go b/machine.go index ac1a1c84..03f727a2 100644 --- a/machine.go +++ b/machine.go @@ -683,7 +683,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback( machineKeyStr string, namespaceName string, registrationMethod string, - expiry *time.Time, ) (*Machine, error) { if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { @@ -697,7 +696,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback( registrationMachine.NamespaceID = namespace.ID registrationMachine.RegisterMethod = registrationMethod - registrationMachine.Expiry = expiry machine, err := h.RegisterMachine( registrationMachine, diff --git a/oidc.go b/oidc.go index e745236b..fe69a76f 100644 --- a/oidc.go +++ b/oidc.go @@ -10,20 +10,16 @@ import ( "html/template" "net/http" "strings" - "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" - "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" "tailscale.com/types/key" ) const ( - oidcStateCacheExpiration = time.Minute * 5 - oidcStateCacheCleanupInterval = time.Minute * 10 - randomByteSize = 16 + randomByteSize = 16 ) type IDTokenClaims struct { @@ -60,14 +56,6 @@ func (h *Headscale) initOIDC() error { } } - // init the state cache if it hasn't been already - if h.oidcStateCache == nil { - h.oidcStateCache = cache.New( - oidcStateCacheExpiration, - oidcStateCacheCleanupInterval, - ) - } - return nil } @@ -100,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) + h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) authURL := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -196,7 +184,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { } // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) + machineKeyIf, machineKeyFound := h.registrationCache.Get(state) if !machineKeyFound { log.Error(). @@ -228,14 +216,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime - } - } - // retrieve machine information machine, err := h.GetMachineByMachineKey(machineKey) if err != nil { @@ -254,7 +234,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("machine", machine.Name). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, requestedTime) + h.RefreshMachine(machine, *machine.Expiry) var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ @@ -329,7 +309,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machineKeyStr, namespace.Name, RegisterMethodOIDC, - &requestedTime, ) if err != nil { log.Error().