1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-04 00:09:34 +01:00

Generalise registration for pre auth keys

This commit is contained in:
Kristoffer Dalby 2022-02-27 18:42:15 +01:00
parent c58ce6f60c
commit acb945841c

53
api.go
View File

@ -22,7 +22,7 @@ import (
const ( const (
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authKey" RegisterMethodAuthKey = RegisterMethodAuthKey
RegisterMethodOIDC = "oidc" RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli" RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error( ErrRegisterMethodCLIDoesNotSupportExpire = Error(
@ -545,7 +545,7 @@ func (h *Headscale) handleAuthKey(
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
Inc() Inc()
return return
@ -556,7 +556,7 @@ func (h *Headscale) handleAuthKey(
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
Inc() Inc()
return return
@ -575,38 +575,31 @@ func (h *Headscale) handleAuthKey(
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
h.ipAllocationMutex.Lock() nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
now := time.Now().UTC()
ips, err := h.getAvailableIPs() _, err = h.RegisterMachine(
machine.Name,
machine.Namespace.Name,
RegisterMethodAuthKey,
&registerRequest.Expiry,
pak,
&nodeKey,
&now,
)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Err(err).
Str("machine", machine.Name). Msg("could not register machine")
Msg("Failed to find an available IP address") machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). ctx.String(
Inc() http.StatusInternalServerError,
"could not register machine",
)
return return
} }
log.Info().
Str("func", "handleAuthKey").
Str("machine", machine.Name).
Str("ips", strings.Join(ips.ToStringSlice(), ",")).
Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name)
machine.Expiry = &registerRequest.Expiry
machine.AuthKeyID = uint(pak.ID)
machine.IPAddresses = ips
machine.NamespaceID = pak.NamespaceID
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
// we update it just in case
machine.Registered = true
machine.RegisterMethod = RegisterMethodAuthKey
h.db.Save(&machine)
h.ipAllocationMutex.Unlock()
} }
pak.Used = true pak.Used = true
@ -622,13 +615,13 @@ func (h *Headscale) handleAuthKey(
Str("machine", machine.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info(). log.Info().