1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-08 00:11:42 +01:00

Add cache for requested expiry times

This commit adds a sentral cache to keep track of clients whom has
requested an expiry time, but were we need to keep hold of it until the
second request comes in.
This commit is contained in:
Kristoffer Dalby 2021-11-22 19:32:52 +00:00
parent e600ead3e9
commit 021c464148
4 changed files with 67 additions and 11 deletions

22
api.go
View File

@ -19,10 +19,13 @@ import (
) )
const ( const (
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authKey" RegisterMethodAuthKey = "authKey"
RegisterMethodOIDC = "oidc" RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli" RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire",
)
) )
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
@ -441,7 +444,16 @@ func (h *Headscale) handleMachineRegistrationNew(
} }
if !reqisterRequest.Expiry.IsZero() { if !reqisterRequest.Expiry.IsZero() {
machine.Expiry = &reqisterRequest.Expiry log.Trace().
Caller().
Str("machine", machine.Name).
Time("expiry", reqisterRequest.Expiry).
Msg("Non-zero expiry time requested, adding to cache")
h.requestedExpiryCache.Set(
idKey.HexString(),
reqisterRequest.Expiry,
requestedExpiryCacheExpiration,
)
} }
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).HexString() // save the NodeKey machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).HexString() // save the NodeKey

23
app.go
View File

@ -53,6 +53,9 @@ const (
updateInterval = 5000 updateInterval = 5000
HTTPReadTimeout = 30 * time.Second HTTPReadTimeout = 30 * time.Second
requestedExpiryCacheExpiration = time.Minute * 5
requestedExpiryCacheCleanupInterval = time.Minute * 10
errUnsupportedDatabase = Error("unsupported DB") errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error( errUnsupportedLetsEncryptChallengeType = Error(
"unknown value for Lets Encrypt challenge type", "unknown value for Lets Encrypt challenge type",
@ -139,6 +142,8 @@ type Headscale struct {
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
oidcStateCache *cache.Cache oidcStateCache *cache.Cache
requestedExpiryCache *cache.Cache
} }
// NewHeadscale returns the Headscale app. // NewHeadscale returns the Headscale app.
@ -171,13 +176,19 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errUnsupportedDatabase return nil, errUnsupportedDatabase
} }
requestedExpiryCache := cache.New(
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
)
app := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
privateKey: privKey, privateKey: privKey,
publicKey: &pubKey, publicKey: &pubKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
requestedExpiryCache: requestedExpiryCache,
} }
err = app.initDB() err = app.initDB()

View File

@ -616,6 +616,31 @@ func (h *Headscale) RegisterMachine(
return nil, errMachineNotFound return nil, errMachineNotFound
} }
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("Expiry time found in cache, assigning to node")
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, requestedTime)
return &machine, nil
}
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).

View File

@ -199,6 +199,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found {
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
// retrieve machine information // retrieve machine information
machine, err := h.GetMachineByMachineKey(machineKey) machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil { if err != nil {