diff --git a/api.go b/api.go index f831cacf..9b8fafd2 100644 --- a/api.go +++ b/api.go @@ -140,17 +140,25 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { // We create the machine and then keep it around until a callback // happens newMachine := Machine{ - Expiry: &time.Time{}, MachineKey: machineKeyStr, Name: req.Hostinfo.Hostname, NodeKey: NodePublicKeyStripPrefix(req.NodeKey), LastSeen: &now, } + if !req.Expiry.IsZero() { + log.Trace(). + Caller(). + Str("machine", req.Hostinfo.Hostname). + Time("expiry", req.Expiry). + Msg("Non-zero expiry time requested") + newMachine.Expiry = &req.Expiry + } + h.registrationCache.Set( machineKeyStr, newMachine, - requestedExpiryCacheExpiration, + registerCacheExpiration, ) h.handleMachineRegistrationNew(ctx, machineKey, req) @@ -490,19 +498,6 @@ func (h *Headscale) handleMachineRegistrationNew( strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) } - if !registerRequest.Expiry.IsZero() { - log.Trace(). - Caller(). - Str("machine", registerRequest.Hostinfo.Hostname). - Time("expiry", registerRequest.Expiry). - Msg("Non-zero expiry time requested, adding to cache") - h.requestedExpiryCache.Set( - machineKey.String(), - registerRequest.Expiry, - requestedExpiryCacheExpiration, - ) - } - respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 63ca10e4..f7a6e636 100644 --- a/app.go +++ b/app.go @@ -55,8 +55,8 @@ const ( HTTPReadTimeout = 30 * time.Second privateKeyFileMode = 0o600 - requestedExpiryCacheExpiration = time.Minute * 5 - requestedExpiryCacheCleanupInterval = time.Minute * 10 + registerCacheExpiration = time.Minute * 15 + registerCacheCleanup = time.Minute * 20 errUnsupportedDatabase = Error("unsupported DB") errUnsupportedLetsEncryptChallengeType = Error( @@ -148,11 +148,8 @@ type Headscale struct { lastStateChange sync.Map - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - oidcStateCache *cache.Cache - - requestedExpiryCache *cache.Cache + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config registrationCache *cache.Cache @@ -204,25 +201,19 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, errUnsupportedDatabase } - requestedExpiryCache := cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, - ) - registrationCache := cache.New( // TODO(kradalby): Add unified cache expiry config options - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, + registerCacheExpiration, + registerCacheCleanup, ) app := Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - privateKey: privKey, - aclRules: tailcfg.FilterAllowAll, // default allowall - requestedExpiryCache: requestedExpiryCache, - registrationCache: registrationCache, + cfg: cfg, + dbType: cfg.DBtype, + dbString: dbString, + privateKey: privKey, + aclRules: tailcfg.FilterAllowAll, // default allowall + registrationCache: registrationCache, } err = app.initDB() diff --git a/app_test.go b/app_test.go index 53c703a6..96036a1d 100644 --- a/app_test.go +++ b/app_test.go @@ -5,7 +5,6 @@ import ( "os" "testing" - "github.com/patrickmn/go-cache" "gopkg.in/check.v1" "inet.af/netaddr" ) @@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) { cfg: cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", - requestedExpiryCache: cache.New( - requestedExpiryCacheExpiration, - requestedExpiryCacheCleanupInterval, - ), } err = app.initDB() if err != nil { diff --git a/grpcv1.go b/grpcv1.go index 75732b1d..1c022905 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -160,25 +160,10 @@ func (api headscaleV1APIServer) RegisterMachine( Str("machine_key", request.GetKey()). Msg("Registering machine") - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - // This means that if a user is to slow with register a machine, it will possibly not - // have the correct expiry. - requestedTime := time.Time{} - if requestedTimeIf, found := api.h.requestedExpiryCache.Get(request.GetKey()); found { - log.Trace(). - Caller(). - Str("machine", request.Key). - Msg("Expiry time found in cache, assigning to node") - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime - } - } - machine, err := api.h.RegisterMachineFromAuthCallback( request.GetKey(), request.GetNamespace(), RegisterMethodCLI, - &requestedTime, ) if err != nil { return nil, err @@ -418,7 +403,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( api.h.registrationCache.Set( request.GetKey(), newMachine, - requestedExpiryCacheExpiration, + registerCacheExpiration, ) return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil diff --git a/machine.go b/machine.go index ac1a1c84..03f727a2 100644 --- a/machine.go +++ b/machine.go @@ -683,7 +683,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback( machineKeyStr string, namespaceName string, registrationMethod string, - expiry *time.Time, ) (*Machine, error) { if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { @@ -697,7 +696,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback( registrationMachine.NamespaceID = namespace.ID registrationMachine.RegisterMethod = registrationMethod - registrationMachine.Expiry = expiry machine, err := h.RegisterMachine( registrationMachine, diff --git a/oidc.go b/oidc.go index e745236b..fe69a76f 100644 --- a/oidc.go +++ b/oidc.go @@ -10,20 +10,16 @@ import ( "html/template" "net/http" "strings" - "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" - "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" "tailscale.com/types/key" ) const ( - oidcStateCacheExpiration = time.Minute * 5 - oidcStateCacheCleanupInterval = time.Minute * 10 - randomByteSize = 16 + randomByteSize = 16 ) type IDTokenClaims struct { @@ -60,14 +56,6 @@ func (h *Headscale) initOIDC() error { } } - // init the state cache if it hasn't been already - if h.oidcStateCache == nil { - h.oidcStateCache = cache.New( - oidcStateCacheExpiration, - oidcStateCacheCleanupInterval, - ) - } - return nil } @@ -100,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) + h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) authURL := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -196,7 +184,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { } // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) + machineKeyIf, machineKeyFound := h.registrationCache.Get(state) if !machineKeyFound { log.Error(). @@ -228,14 +216,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - // TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set - requestedTime := time.Time{} - if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { - if reqTime, ok := requestedTimeIf.(time.Time); ok { - requestedTime = reqTime - } - } - // retrieve machine information machine, err := h.GetMachineByMachineKey(machineKey) if err != nil { @@ -254,7 +234,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("machine", machine.Name). Msg("machine already registered, reauthenticating") - h.RefreshMachine(machine, requestedTime) + h.RefreshMachine(machine, *machine.Expiry) var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ @@ -329,7 +309,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machineKeyStr, namespace.Name, RegisterMethodOIDC, - &requestedTime, ) if err != nil { log.Error().