1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-04 00:09:34 +01:00

Merge pull request #366 from kradalby/registration-simplification

This commit is contained in:
Kristoffer Dalby 2022-03-02 08:02:26 +00:00 committed by GitHub
commit eeded85d9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 334 additions and 474 deletions

View File

@ -29,6 +29,7 @@ linters:
- wrapcheck - wrapcheck
- dupl - dupl
- makezero - makezero
- maintidx
# We might want to enable this, but it might be a lot of work # We might want to enable this, but it might be a lot of work
- cyclop - cyclop

View File

@ -1,33 +1,37 @@
# CHANGELOG # CHANGELOG
**0.15.0 (2022-xx-xx):** ## 0.15.0 (2022-xx-xx)
**BREAKING**: **Note:** Take a backup of your database before upgrading.
### BREAKING
- Boundaries between Namespaces has been removed and all nodes can communicate by default [#357](https://github.com/juanfont/headscale/pull/357) - Boundaries between Namespaces has been removed and all nodes can communicate by default [#357](https://github.com/juanfont/headscale/pull/357)
- To limit access between nodes, use [ACLs](./docs/acls.md). - To limit access between nodes, use [ACLs](./docs/acls.md).
**Features**: ### Features
- Add support for writing ACL files with YAML [#359](https://github.com/juanfont/headscale/pull/359) - Add support for writing ACL files with YAML [#359](https://github.com/juanfont/headscale/pull/359)
- Users can now use emails in ACL's groups [#372](https://github.com/juanfont/headscale/issues/372) - Users can now use emails in ACL's groups [#372](https://github.com/juanfont/headscale/issues/372)
**Changes**: ### Changes
- Fix a bug were the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346) - Fix a bug were the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346)
- Simplify the code behind registration of machines [#366](https://github.com/juanfont/headscale/pull/366)
- Nodes are now only written to database if they are registrated successfully
- Fix a limitation in the ACLs that prevented users to write rules with `*` as source [#374](https://github.com/juanfont/headscale/issues/374) - Fix a limitation in the ACLs that prevented users to write rules with `*` as source [#374](https://github.com/juanfont/headscale/issues/374)
**0.14.0 (2022-02-24):** ## 0.14.0 (2022-02-24)
**UPCOMING BREAKING**: **UPCOMING ### BREAKING
From the **next** version (`0.15.0`), all machines will be able to communicate regardless of From the **next\*\* version (`0.15.0`), all machines will be able to communicate regardless of
if they are in the same namespace. This means that the behaviour currently limited to ACLs if they are in the same namespace. This means that the behaviour currently limited to ACLs
will become default. From version `0.15.0`, all limitation of communications must be done will become default. From version `0.15.0`, all limitation of communications must be done
with ACLs. with ACLs.
This is a part of aligning `headscale`'s behaviour with Tailscale's upstream behaviour. This is a part of aligning `headscale`'s behaviour with Tailscale's upstream behaviour.
**BREAKING**: ### BREAKING
- ACLs have been rewritten to align with the bevaviour Tailscale Control Panel provides. **NOTE:** This is only active if you use ACLs - ACLs have been rewritten to align with the bevaviour Tailscale Control Panel provides. **NOTE:** This is only active if you use ACLs
- Namespaces are now treated as Users - Namespaces are now treated as Users
@ -35,17 +39,17 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh
- Tags should now work correctly and adding a host to Headscale should now reload the rules. - Tags should now work correctly and adding a host to Headscale should now reload the rules.
- The documentation have a [fictional example](docs/acls.md) that should cover some use cases of the ACLs features - The documentation have a [fictional example](docs/acls.md) that should cover some use cases of the ACLs features
**Features**: ### Features
- Add support for configurable mTLS [docs](docs/tls.md#configuring-mutual-tls-authentication-mtls) [#297](https://github.com/juanfont/headscale/pull/297) - Add support for configurable mTLS [docs](docs/tls.md#configuring-mutual-tls-authentication-mtls) [#297](https://github.com/juanfont/headscale/pull/297)
**Changes**: ### Changes
- Remove dependency on CGO (switch from CGO SQLite to pure Go) [#346](https://github.com/juanfont/headscale/pull/346) - Remove dependency on CGO (switch from CGO SQLite to pure Go) [#346](https://github.com/juanfont/headscale/pull/346)
**0.13.0 (2022-02-18):** **0.13.0 (2022-02-18):**
**Features**: ### Features
- Add IPv6 support to the prefix assigned to namespaces - Add IPv6 support to the prefix assigned to namespaces
- Add API Key support - Add API Key support
@ -56,7 +60,7 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh
- `oidc.domain_map` option has been removed - `oidc.domain_map` option has been removed
- `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml)) - `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml))
**Changes**: ### Changes
- `ip_prefix` is now superseded by `ip_prefixes` in the configuration [#208](https://github.com/juanfont/headscale/pull/208) - `ip_prefix` is now superseded by `ip_prefixes` in the configuration [#208](https://github.com/juanfont/headscale/pull/208)
- Upgrade `tailscale` (1.20.4) and other dependencies to latest [#314](https://github.com/juanfont/headscale/pull/314) - Upgrade `tailscale` (1.20.4) and other dependencies to latest [#314](https://github.com/juanfont/headscale/pull/314)
@ -65,35 +69,35 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh
**0.12.4 (2022-01-29):** **0.12.4 (2022-01-29):**
**Changes**: ### Changes
- Make gRPC Unix Socket permissions configurable [#292](https://github.com/juanfont/headscale/pull/292) - Make gRPC Unix Socket permissions configurable [#292](https://github.com/juanfont/headscale/pull/292)
- Trim whitespace before reading Private Key from file [#289](https://github.com/juanfont/headscale/pull/289) - Trim whitespace before reading Private Key from file [#289](https://github.com/juanfont/headscale/pull/289)
- Add new command to generate a private key for `headscale` [#290](https://github.com/juanfont/headscale/pull/290) - Add new command to generate a private key for `headscale` [#290](https://github.com/juanfont/headscale/pull/290)
- Fixed issue where hosts deleted from control server may be written back to the database, as long as they are connected to the control server [#278](https://github.com/juanfont/headscale/pull/278) - Fixed issue where hosts deleted from control server may be written back to the database, as long as they are connected to the control server [#278](https://github.com/juanfont/headscale/pull/278)
**0.12.3 (2022-01-13):** ## 0.12.3 (2022-01-13)
**Changes**: ### Changes
- Added Alpine container [#270](https://github.com/juanfont/headscale/pull/270) - Added Alpine container [#270](https://github.com/juanfont/headscale/pull/270)
- Minor updates in dependencies [#271](https://github.com/juanfont/headscale/pull/271) - Minor updates in dependencies [#271](https://github.com/juanfont/headscale/pull/271)
**0.12.2 (2022-01-11):** ## 0.12.2 (2022-01-11)
Happy New Year! Happy New Year!
**Changes**: ### Changes
- Fix Docker release [#258](https://github.com/juanfont/headscale/pull/258) - Fix Docker release [#258](https://github.com/juanfont/headscale/pull/258)
- Rewrite main docs [#262](https://github.com/juanfont/headscale/pull/262) - Rewrite main docs [#262](https://github.com/juanfont/headscale/pull/262)
- Improve Docker docs [#263](https://github.com/juanfont/headscale/pull/263) - Improve Docker docs [#263](https://github.com/juanfont/headscale/pull/263)
**0.12.1 (2021-12-24):** ## 0.12.1 (2021-12-24)
(We are skipping 0.12.0 to correct a mishap done weeks ago with the version tagging) (We are skipping 0.12.0 to correct a mishap done weeks ago with the version tagging)
**BREAKING**: ### BREAKING
- Upgrade to Tailscale 1.18 [#229](https://github.com/juanfont/headscale/pull/229) - Upgrade to Tailscale 1.18 [#229](https://github.com/juanfont/headscale/pull/229)
- This change requires a new format for private key, private keys are now generated automatically: - This change requires a new format for private key, private keys are now generated automatically:
@ -101,19 +105,19 @@ Happy New Year!
2. Restart `headscale`, a new key will be generated. 2. Restart `headscale`, a new key will be generated.
3. Restart all Tailscale clients to fetch the new key 3. Restart all Tailscale clients to fetch the new key
**Changes**: ### Changes
- Unify configuration example [#197](https://github.com/juanfont/headscale/pull/197) - Unify configuration example [#197](https://github.com/juanfont/headscale/pull/197)
- Add stricter linting and formatting [#223](https://github.com/juanfont/headscale/pull/223) - Add stricter linting and formatting [#223](https://github.com/juanfont/headscale/pull/223)
**Features**: ### Features
- Add gRPC and HTTP API (HTTP API is currently disabled) [#204](https://github.com/juanfont/headscale/pull/204) - Add gRPC and HTTP API (HTTP API is currently disabled) [#204](https://github.com/juanfont/headscale/pull/204)
- Use gRPC between the CLI and the server [#206](https://github.com/juanfont/headscale/pull/206), [#212](https://github.com/juanfont/headscale/pull/212) - Use gRPC between the CLI and the server [#206](https://github.com/juanfont/headscale/pull/206), [#212](https://github.com/juanfont/headscale/pull/212)
- Beta OpenID Connect support [#126](https://github.com/juanfont/headscale/pull/126), [#227](https://github.com/juanfont/headscale/pull/227) - Beta OpenID Connect support [#126](https://github.com/juanfont/headscale/pull/126), [#227](https://github.com/juanfont/headscale/pull/227)
**0.11.0 (2021-10-25):** ## 0.11.0 (2021-10-25)
**BREAKING**: ### BREAKING
- Make headscale fetch DERP map from URL and file [#196](https://github.com/juanfont/headscale/pull/196) - Make headscale fetch DERP map from URL and file [#196](https://github.com/juanfont/headscale/pull/196)

View File

@ -119,7 +119,6 @@ func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) {
Name: "testmachine", Name: "testmachine",
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostInfo), HostInfo: datatypes.JSON(hostInfo),
@ -163,7 +162,6 @@ func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) {
Name: "testmachine", Name: "testmachine",
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostInfo), HostInfo: datatypes.JSON(hostInfo),
@ -207,7 +205,6 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) {
Name: "testmachine", Name: "testmachine",
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostInfo), HostInfo: datatypes.JSON(hostInfo),
@ -250,7 +247,6 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
Name: "webserver", Name: "webserver",
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostInfo), HostInfo: datatypes.JSON(hostInfo),
@ -267,7 +263,6 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
Name: "user", Name: "user",
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostInfo), HostInfo: datatypes.JSON(hostInfo),
@ -361,7 +356,6 @@ func (s *Suite) TestPortNamespace(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: ips, IPAddresses: ips,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
@ -404,7 +398,6 @@ func (s *Suite) TestPortGroup(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: ips, IPAddresses: ips,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),

171
api.go
View File

@ -22,7 +22,7 @@ import (
const ( const (
reservedResponseHeaderSize = 4 reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authKey" RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc" RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli" RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error( ErrRegisterMethodCLIDoesNotSupportExpire = Error(
@ -125,25 +125,50 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
machine, err := h.GetMachineByMachineKey(machineKey) machine, err := h.GetMachineByMachineKey(machineKey)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{}, machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
MachineKey: MachinePublicKeyStripPrefix(machineKey),
Name: req.Hostinfo.Hostname, // If the machine has AuthKey set, handle registration via PreAuthKeys
} if req.Auth.AuthKey != "" {
if err := h.db.Create(&newMachine).Error; err != nil { h.handleAuthKey(ctx, machineKey, req)
log.Error().
Caller().
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return return
} }
machine = &newMachine
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: machineKeyStr,
Name: req.Hostinfo.Hostname,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
} }
if machine.Registered { if !req.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", req.Hostinfo.Hostname).
Time("expiry", req.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry
}
h.registrationCache.Set(
machineKeyStr,
newMachine,
registerCacheExpiration,
)
h.handleMachineRegistrationNew(ctx, machineKey, req)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration // If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either: // request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
@ -180,15 +205,6 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
return return
} }
// If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, req, *machine)
return
}
h.handleMachineRegistrationNew(ctx, machineKey, req, *machine)
} }
func (h *Headscale) getMapResponse( func (h *Headscale) getMapResponse(
@ -402,7 +418,7 @@ func (h *Headscale) handleMachineExpired(
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(ctx, machineKey, registerRequest, machine) h.handleAuthKey(ctx, machineKey, registerRequest)
return return
} }
@ -465,13 +481,12 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL // The machine registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
@ -484,24 +499,6 @@ func (h *Headscale) handleMachineRegistrationNew(
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
} }
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", machine.Name).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested, adding to cache")
h.requestedExpiryCache.Set(
machineKey.String(),
registerRequest.Expiry,
requestedExpiryCacheExpiration,
)
}
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
// save the NodeKey
h.db.Save(&machine)
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -520,19 +517,21 @@ func (h *Headscale) handleAuthKey(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine,
) { ) {
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -541,76 +540,66 @@ func (h *Headscale) handleAuthKey(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
ctx.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
return return
} }
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
return return
} }
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, registerRequest.Expiry)
} else {
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
h.ipAllocationMutex.Lock() nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
now := time.Now().UTC()
ips, err := h.getAvailableIPs() machineToRegister := Machine{
Name: registerRequest.Hostinfo.Hostname,
NamespaceID: pak.Namespace.ID,
MachineKey: machineKeyStr,
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err := h.RegisterMachine(
machineToRegister,
)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Err(err).
Str("machine", machine.Name). Msg("could not register machine")
Msg("Failed to find an available IP address") machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc() Inc()
ctx.String(
http.StatusInternalServerError,
"could not register machine",
)
return return
} }
log.Info().
Str("func", "handleAuthKey").
Str("machine", machine.Name).
Str("ips", strings.Join(ips.ToStringSlice(), ",")).
Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name)
machine.Expiry = &registerRequest.Expiry h.UsePreAuthKey(pak)
machine.AuthKeyID = uint(pak.ID)
machine.IPAddresses = ips
machine.NamespaceID = pak.NamespaceID
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
// we update it just in case
machine.Registered = true
machine.RegisterMethod = RegisterMethodAuthKey
h.db.Save(&machine)
h.ipAllocationMutex.Unlock()
}
pak.Used = true
h.db.Save(&pak)
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
@ -619,21 +608,21 @@ func (h *Headscale) handleAuthKey(
log.Error(). log.Error().
Caller(). Caller().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!") ctx.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

15
app.go
View File

@ -55,8 +55,8 @@ const (
HTTPReadTimeout = 30 * time.Second HTTPReadTimeout = 30 * time.Second
privateKeyFileMode = 0o600 privateKeyFileMode = 0o600
requestedExpiryCacheExpiration = time.Minute * 5 registerCacheExpiration = time.Minute * 15
requestedExpiryCacheCleanupInterval = time.Minute * 10 registerCacheCleanup = time.Minute * 20
errUnsupportedDatabase = Error("unsupported DB") errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error( errUnsupportedLetsEncryptChallengeType = Error(
@ -150,9 +150,8 @@ type Headscale struct {
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
requestedExpiryCache *cache.Cache registrationCache *cache.Cache
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
} }
@ -202,9 +201,9 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errUnsupportedDatabase return nil, errUnsupportedDatabase
} }
requestedExpiryCache := cache.New( registrationCache := cache.New(
requestedExpiryCacheExpiration, registerCacheExpiration,
requestedExpiryCacheCleanupInterval, registerCacheCleanup,
) )
app := Headscale{ app := Headscale{
@ -213,7 +212,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
dbString: dbString, dbString: dbString,
privateKey: privKey, privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
requestedExpiryCache: requestedExpiryCache, registrationCache: registrationCache,
} }
err = app.initDB() err = app.initDB()

View File

@ -5,7 +5,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/patrickmn/go-cache"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
) )
@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) {
cfg: cfg, cfg: cfg,
dbType: "sqlite3", dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db", dbString: tmpDir + "/headscale_test.db",
requestedExpiryCache: cache.New(
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
),
} }
err = app.initDB() err = app.initDB()
if err != nil { if err != nil {

View File

@ -1,41 +0,0 @@
package headscale
import (
"time"
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestRegisterMachine(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
now := time.Now().UTC()
machine := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: namespace.ID,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("10.0.0.1")},
Expiry: &now,
}
err = app.db.Save(&machine).Error
c.Assert(err, check.IsNil)
_, err = app.GetMachine(namespace.Name, machine.Name)
c.Assert(err, check.IsNil)
machineAfterRegistering, err := app.RegisterMachine(
machine.MachineKey,
namespace.Name,
)
c.Assert(err, check.IsNil)
c.Assert(machineAfterRegistering.Registered, check.Equals, true)
_, err = machineAfterRegistering.GetHostInfo()
c.Assert(err, check.IsNil)
}

33
db.go
View File

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"github.com/rs/zerolog/log"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@ -34,6 +35,38 @@ func (h *Headscale) initDB() error {
_ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") _ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
// If the Machine table has a column for registered,
// find all occourences of "false" and drop them. Then
// remove the column.
if db.Migrator().HasColumn(&Machine{}, "registered") {
log.Info().
Msg(`Database has legacy "registered" column in machine, removing...`)
machines := Machines{}
if err := h.db.Not("registered").Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
for _, machine := range machines {
log.Info().
Str("machine", machine.Name).
Str("machine_key", machine.MachineKey).
Msg("Deleting unregistered machine")
if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil {
log.Error().
Err(err).
Str("machine", machine.Name).
Str("machine_key", machine.MachineKey).
Msg("Error deleting unregistered machine")
}
}
err := db.Migrator().DropColumn(&Machine{}, "registered")
if err != nil {
log.Error().Err(err).Msg("Error dropping registered column")
}
}
err = db.AutoMigrate(&Machine{}) err = db.AutoMigrate(&Machine{})
if err != nil { if err != nil {
return err return err

View File

@ -164,7 +164,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
@ -182,7 +181,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: namespaceShared2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
@ -200,7 +198,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: namespaceShared3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
@ -218,7 +215,6 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID), AuthKeyID: uint(PreAuthKey2InShared1.ID),
@ -311,7 +307,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID), AuthKeyID: uint(preAuthKeyInShared1.ID),
@ -329,7 +324,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: namespaceShared2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID), AuthKeyID: uint(preAuthKeyInShared2.ID),
@ -347,7 +341,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: namespaceShared3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID), AuthKeyID: uint(preAuthKeyInShared3.ID),
@ -365,7 +358,6 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID), AuthKeyID: uint(preAuthKey2InShared1.ID),

View File

@ -85,13 +85,12 @@ type Machine struct {
IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"` IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"`
Name string `protobuf:"bytes,6,opt,name=name,proto3" json:"name,omitempty"` Name string `protobuf:"bytes,6,opt,name=name,proto3" json:"name,omitempty"`
Namespace *Namespace `protobuf:"bytes,7,opt,name=namespace,proto3" json:"namespace,omitempty"` Namespace *Namespace `protobuf:"bytes,7,opt,name=namespace,proto3" json:"namespace,omitempty"`
Registered bool `protobuf:"varint,8,opt,name=registered,proto3" json:"registered,omitempty"` LastSeen *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"`
RegisterMethod RegisterMethod `protobuf:"varint,9,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"` LastSuccessfulUpdate *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=last_successful_update,json=lastSuccessfulUpdate,proto3" json:"last_successful_update,omitempty"`
LastSeen *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` Expiry *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=expiry,proto3" json:"expiry,omitempty"`
LastSuccessfulUpdate *timestamppb.Timestamp `protobuf:"bytes,11,opt,name=last_successful_update,json=lastSuccessfulUpdate,proto3" json:"last_successful_update,omitempty"` PreAuthKey *PreAuthKey `protobuf:"bytes,11,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"`
Expiry *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=expiry,proto3" json:"expiry,omitempty"` CreatedAt *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"`
PreAuthKey *PreAuthKey `protobuf:"bytes,13,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` RegisterMethod RegisterMethod `protobuf:"varint,13,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"`
CreatedAt *timestamppb.Timestamp `protobuf:"bytes,14,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"`
} }
func (x *Machine) Reset() { func (x *Machine) Reset() {
@ -175,20 +174,6 @@ func (x *Machine) GetNamespace() *Namespace {
return nil return nil
} }
func (x *Machine) GetRegistered() bool {
if x != nil {
return x.Registered
}
return false
}
func (x *Machine) GetRegisterMethod() RegisterMethod {
if x != nil {
return x.RegisterMethod
}
return RegisterMethod_REGISTER_METHOD_UNSPECIFIED
}
func (x *Machine) GetLastSeen() *timestamppb.Timestamp { func (x *Machine) GetLastSeen() *timestamppb.Timestamp {
if x != nil { if x != nil {
return x.LastSeen return x.LastSeen
@ -224,6 +209,13 @@ func (x *Machine) GetCreatedAt() *timestamppb.Timestamp {
return nil return nil
} }
func (x *Machine) GetRegisterMethod() RegisterMethod {
if x != nil {
return x.RegisterMethod
}
return RegisterMethod_REGISTER_METHOD_UNSPECIFIED
}
type RegisterMachineRequest struct { type RegisterMachineRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -822,7 +814,7 @@ var file_headscale_v1_machine_proto_rawDesc = []byte{
0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70,
0x61, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73, 0x61, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73,
0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b,
0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xfd, 0x04, 0x0a, 0x07, 0x4d, 0x61, 0x63, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x04, 0x0a, 0x07, 0x4d, 0x61, 0x63,
0x68, 0x69, 0x6e, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x68, 0x69, 0x6e, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04,
0x52, 0x02, 0x69, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f,
0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x61, 0x63, 0x68, 0x69,
@ -836,33 +828,31 @@ var file_headscale_v1_machine_proto_rawDesc = []byte{
0x6e, 0x61, 0x6d, 0x65, 0x12, 0x35, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x35, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63,
0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63,
0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65,
0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x72, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x09, 0x6c,
0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x65, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a,
0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x65, 0x64, 0x12, 0x45, 0x0a, 0x0f, 0x72, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x09, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74,
0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x53, 0x65, 0x65, 0x6e, 0x12, 0x50, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x75, 0x63,
0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x5f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x09,
0x6f, 0x64, 0x52, 0x0e, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x64, 0x12, 0x37, 0x0a, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70,
0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x52, 0x14, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x32, 0x0a, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79,
0x70, 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x65, 0x65, 0x6e, 0x12, 0x50, 0x0a, 0x16, 0x6c, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e,
0x61, 0x73, 0x74, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x5f, 0x75, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61,
0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6d, 0x70, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x14, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x75, 0x63, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e,
0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x32, 0x0a, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41,
0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f,
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d,
0x79, 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41,
0x79, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x74, 0x12, 0x45, 0x0a, 0x0f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x6d, 0x65,
0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x68, 0x65, 0x61,
0x79, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x39, 0x0a, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74,
0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x52, 0x0e, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74,
0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x22, 0x48, 0x0a, 0x16, 0x52, 0x65, 0x67, 0x69,
0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63,
0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, 0x48, 0x0a, 0x16, 0x52, 0x65, 0x67, 0x69,
0x73, 0x74, 0x65, 0x72, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65,
@ -962,12 +952,12 @@ var file_headscale_v1_machine_proto_goTypes = []interface{}{
} }
var file_headscale_v1_machine_proto_depIdxs = []int32{ var file_headscale_v1_machine_proto_depIdxs = []int32{
14, // 0: headscale.v1.Machine.namespace:type_name -> headscale.v1.Namespace 14, // 0: headscale.v1.Machine.namespace:type_name -> headscale.v1.Namespace
0, // 1: headscale.v1.Machine.register_method:type_name -> headscale.v1.RegisterMethod 15, // 1: headscale.v1.Machine.last_seen:type_name -> google.protobuf.Timestamp
15, // 2: headscale.v1.Machine.last_seen:type_name -> google.protobuf.Timestamp 15, // 2: headscale.v1.Machine.last_successful_update:type_name -> google.protobuf.Timestamp
15, // 3: headscale.v1.Machine.last_successful_update:type_name -> google.protobuf.Timestamp 15, // 3: headscale.v1.Machine.expiry:type_name -> google.protobuf.Timestamp
15, // 4: headscale.v1.Machine.expiry:type_name -> google.protobuf.Timestamp 16, // 4: headscale.v1.Machine.pre_auth_key:type_name -> headscale.v1.PreAuthKey
16, // 5: headscale.v1.Machine.pre_auth_key:type_name -> headscale.v1.PreAuthKey 15, // 5: headscale.v1.Machine.created_at:type_name -> google.protobuf.Timestamp
15, // 6: headscale.v1.Machine.created_at:type_name -> google.protobuf.Timestamp 0, // 6: headscale.v1.Machine.register_method:type_name -> headscale.v1.RegisterMethod
1, // 7: headscale.v1.RegisterMachineResponse.machine:type_name -> headscale.v1.Machine 1, // 7: headscale.v1.RegisterMachineResponse.machine:type_name -> headscale.v1.Machine
1, // 8: headscale.v1.GetMachineResponse.machine:type_name -> headscale.v1.Machine 1, // 8: headscale.v1.GetMachineResponse.machine:type_name -> headscale.v1.Machine
1, // 9: headscale.v1.ExpireMachineResponse.machine:type_name -> headscale.v1.Machine 1, // 9: headscale.v1.ExpireMachineResponse.machine:type_name -> headscale.v1.Machine

View File

@ -885,12 +885,6 @@
"namespace": { "namespace": {
"$ref": "#/definitions/v1Namespace" "$ref": "#/definitions/v1Namespace"
}, },
"registered": {
"type": "boolean"
},
"registerMethod": {
"$ref": "#/definitions/v1RegisterMethod"
},
"lastSeen": { "lastSeen": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
@ -909,6 +903,9 @@
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
},
"registerMethod": {
"$ref": "#/definitions/v1RegisterMethod"
} }
} }
}, },

View File

@ -159,9 +159,11 @@ func (api headscaleV1APIServer) RegisterMachine(
Str("namespace", request.GetNamespace()). Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()). Str("machine_key", request.GetKey()).
Msg("Registering machine") Msg("Registering machine")
machine, err := api.h.RegisterMachine(
machine, err := api.h.RegisterMachineFromAuthCallback(
request.GetKey(), request.GetKey(),
request.GetNamespace(), request.GetNamespace(),
RegisterMethodCLI,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -398,11 +400,11 @@ func (api headscaleV1APIServer) DebugCreateMachine(
HostInfo: datatypes.JSON(hostinfoJson), HostInfo: datatypes.JSON(hostinfoJson),
} }
// log.Trace().Caller().Interface("machine", newMachine).Msg("") api.h.registrationCache.Set(
request.GetKey(),
if err := api.h.db.Create(&newMachine).Error; err != nil { newMachine,
return nil, err registerCacheExpiration,
} )
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
} }

View File

@ -621,12 +621,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Equal(s.T(), "machine-4", listAll[3].Name) assert.Equal(s.T(), "machine-4", listAll[3].Name)
assert.Equal(s.T(), "machine-5", listAll[4].Name) assert.Equal(s.T(), "machine-5", listAll[4].Name)
assert.True(s.T(), listAll[0].Registered)
assert.True(s.T(), listAll[1].Registered)
assert.True(s.T(), listAll[2].Registered)
assert.True(s.T(), listAll[3].Registered)
assert.True(s.T(), listAll[4].Registered)
otherNamespaceMachineKeys := []string{ otherNamespaceMachineKeys := []string{
"b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", "b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e",
"dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", "dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584",
@ -710,9 +704,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Equal(s.T(), "otherNamespace-machine-1", listAllWithotherNamespace[5].Name) assert.Equal(s.T(), "otherNamespace-machine-1", listAllWithotherNamespace[5].Name)
assert.Equal(s.T(), "otherNamespace-machine-2", listAllWithotherNamespace[6].Name) assert.Equal(s.T(), "otherNamespace-machine-2", listAllWithotherNamespace[6].Name)
assert.True(s.T(), listAllWithotherNamespace[5].Registered)
assert.True(s.T(), listAllWithotherNamespace[6].Registered)
// Test list all nodes after added otherNamespace // Test list all nodes after added otherNamespace
listOnlyotherNamespaceMachineNamespaceResult, err := ExecuteCommand( listOnlyotherNamespaceMachineNamespaceResult, err := ExecuteCommand(
&s.headscale, &s.headscale,
@ -752,9 +743,6 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
listOnlyotherNamespaceMachineNamespace[1].Name, listOnlyotherNamespaceMachineNamespace[1].Name,
) )
assert.True(s.T(), listOnlyotherNamespaceMachineNamespace[0].Registered)
assert.True(s.T(), listOnlyotherNamespaceMachineNamespace[1].Registered)
// Delete a machines // Delete a machines
_, err = ExecuteCommand( _, err = ExecuteCommand(
&s.headscale, &s.headscale,
@ -979,7 +967,6 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
assert.Equal(s.T(), uint64(1), machine.Id) assert.Equal(s.T(), uint64(1), machine.Id)
assert.Equal(s.T(), "route-machine", machine.Name) assert.Equal(s.T(), "route-machine", machine.Name)
assert.True(s.T(), machine.Registered)
listAllResult, err := ExecuteCommand( listAllResult, err := ExecuteCommand(
&s.headscale, &s.headscale,

View File

@ -21,9 +21,12 @@ import (
const ( const (
errMachineNotFound = Error("machine not found") errMachineNotFound = Error("machine not found")
errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineRouteIsNotAvailable = Error("route is not available on machine")
errMachineAddressesInvalid = Error("failed to parse machine addresses") errMachineAddressesInvalid = Error("failed to parse machine addresses")
errMachineNotFoundRegistrationCache = Error(
"machine not found in registration cache",
)
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long") errHostnameTooLong = Error("Hostname too long")
) )
@ -42,8 +45,9 @@ type Machine struct {
NamespaceID uint NamespaceID uint
Namespace Namespace `gorm:"foreignKey:NamespaceID"` Namespace Namespace `gorm:"foreignKey:NamespaceID"`
Registered bool // temp
RegisterMethod string RegisterMethod string
// TODO(kradalby): This seems like irrelevant information?
AuthKeyID uint AuthKeyID uint
AuthKey *PreAuthKey AuthKey *PreAuthKey
@ -51,6 +55,8 @@ type Machine struct {
LastSuccessfulUpdate *time.Time LastSuccessfulUpdate *time.Time
Expiry *time.Time Expiry *time.Time
// TODO(kradalby): Figure out a way to use tailcfg datatypes
// here and have gorm serialise them.
HostInfo datatypes.JSON HostInfo datatypes.JSON
Endpoints datatypes.JSON Endpoints datatypes.JSON
EnabledRoutes datatypes.JSON EnabledRoutes datatypes.JSON
@ -65,11 +71,6 @@ type (
MachinesP []*Machine MachinesP []*Machine
) )
// For the time being this method is rather naive.
func (machine Machine) isRegistered() bool {
return machine.Registered
}
type MachineAddresses []netaddr.IP type MachineAddresses []netaddr.IP
func (ma MachineAddresses) ToStringSlice() []string { func (ma MachineAddresses) ToStringSlice() []string {
@ -116,7 +117,7 @@ func (machine Machine) isExpired() bool {
// If Expiry is not set, the client has not indicated that // If Expiry is not set, the client has not indicated that
// it wants an expiry time, it is therefor considered // it wants an expiry time, it is therefor considered
// to mean "not expired" // to mean "not expired"
if machine.Expiry.IsZero() { if machine.Expiry == nil || machine.Expiry.IsZero() {
return false return false
} }
@ -232,7 +233,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
Msg("Finding direct peers") Msg("Finding direct peers")
machines := Machines{} machines := Machines{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ? AND registered", if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ?",
machine.MachineKey).Find(&machines).Error; err != nil { machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
@ -295,7 +296,7 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
} }
for _, peer := range peers { for _, peer := range peers {
if peer.isRegistered() && !peer.isExpired() { if !peer.isExpired() {
validPeers = append(validPeers, peer) validPeers = append(validPeers, peer)
} }
} }
@ -384,8 +385,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
// DeleteMachine softs deletes a Machine from the database. // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(machine *Machine) error { func (h *Headscale) DeleteMachine(machine *Machine) error {
machine.Registered = false
h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&machine).Error; err != nil { if err := h.db.Delete(&machine).Error; err != nil {
return err return err
} }
@ -653,7 +652,7 @@ func (machine Machine) toNode(
LastSeen: machine.LastSeen, LastSeen: machine.LastSeen,
KeepAlive: true, KeepAlive: true,
MachineAuthorized: machine.Registered, MachineAuthorized: !machine.isExpired(),
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
@ -671,8 +670,6 @@ func (machine *Machine) toProto() *v1.Machine {
Name: machine.Name, Name: machine.Name,
Namespace: machine.Namespace.toProto(), Namespace: machine.Namespace.toProto(),
Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter // TODO(kradalby): Implement register method enum converter
// RegisterMethod: , // RegisterMethod: ,
@ -700,74 +697,50 @@ func (machine *Machine) toProto() *v1.Machine {
return machineProto return machineProto
} }
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. func (h *Headscale) RegisterMachineFromAuthCallback(
func (h *Headscale) RegisterMachine(
machineKeyStr string, machineKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf(
"failed to find namespace in register machine from auth callback, %w",
err,
)
} }
var machineKey key.MachinePublic registrationMachine.NamespaceID = namespace.ID
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) registrationMachine.RegisterMethod = registrationMethod
if err != nil {
return nil, err machine, err := h.RegisterMachine(
registrationMachine,
)
return machine, err
} else {
return nil, errCouldNotConvertMachineInterface
}
} }
return nil, errMachineNotFoundRegistrationCache
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key_str", machineKeyStr). Str("machine_key", machine.MachineKey).
Str("machine_key", machineKey.String()).
Msg("Registering machine") Msg("Registering machine")
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
return nil, err
}
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("Expiry time found in cache, assigning to node")
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, requestedTime)
return machine, nil
}
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
if machine.isRegistered() {
err := errMachineAlreadyRegistered
log.Error().
Caller().
Err(err).
Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
}
h.ipAllocationMutex.Lock() h.ipAllocationMutex.Lock()
defer h.ipAllocationMutex.Unlock() defer h.ipAllocationMutex.Unlock()
@ -782,17 +755,8 @@ func (h *Headscale) RegisterMachine(
return nil, err return nil, err
} }
log.Trace().
Caller().
Str("machine", machine.Name).
Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Found IP for host")
machine.IPAddresses = ips machine.IPAddresses = ips
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = RegisterMethodCLI
machine.Expiry = &requestedTime
h.db.Save(&machine) h.db.Save(&machine)
log.Trace(). log.Trace().
@ -801,7 +765,7 @@ func (h *Headscale) RegisterMachine(
Str("ip", strings.Join(ips.ToStringSlice(), ",")). Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Machine registered with the database") Msg("Machine registered with the database")
return machine, nil return &machine, nil
} }
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {

View File

@ -29,7 +29,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -59,7 +58,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -82,7 +80,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(1), AuthKeyID: uint(1),
} }
@ -105,7 +102,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine3", Name: "testmachine3",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(1), AuthKeyID: uint(1),
} }
@ -136,7 +132,6 @@ func (s *Suite) TestListPeers(c *check.C) {
DiscoKey: "faa" + strconv.Itoa(index), DiscoKey: "faa" + strconv.Itoa(index),
Name: "testmachine" + strconv.Itoa(index), Name: "testmachine" + strconv.Itoa(index),
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -188,7 +183,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
}, },
Name: "testmachine" + strconv.Itoa(index), Name: "testmachine" + strconv.Itoa(index),
NamespaceID: stor[index%2].namespace.ID, NamespaceID: stor[index%2].namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID), AuthKeyID: uint(stor[index%2].key.ID),
} }
@ -258,7 +252,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
Expiry: &time.Time{}, Expiry: &time.Time{},

View File

@ -54,7 +54,6 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -146,7 +145,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Name: "test_get_shared_nodes_1", Name: "test_get_shared_nodes_1",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID), AuthKeyID: uint(preAuthKeyShared1.ID),
@ -164,7 +162,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Name: "test_get_shared_nodes_2", Name: "test_get_shared_nodes_2",
NamespaceID: namespaceShared2.ID, NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2, Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID), AuthKeyID: uint(preAuthKeyShared2.ID),
@ -182,7 +179,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Name: "test_get_shared_nodes_3", Name: "test_get_shared_nodes_3",
NamespaceID: namespaceShared3.ID, NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3, Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID), AuthKeyID: uint(preAuthKeyShared3.ID),
@ -200,7 +196,6 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Name: "test_get_shared_nodes_4", Name: "test_get_shared_nodes_4",
NamespaceID: namespaceShared1.ID, NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1, Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")}, IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID), AuthKeyID: uint(preAuthKey2Shared1.ID),

79
oidc.go
View File

@ -10,20 +10,15 @@ import (
"html/template" "html/template"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
const ( const (
oidcStateCacheExpiration = time.Minute * 5
oidcStateCacheCleanupInterval = time.Minute * 10
randomByteSize = 16 randomByteSize = 16
) )
@ -61,14 +56,6 @@ func (h *Headscale) initOIDC() error {
} }
} }
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(
oidcStateCacheExpiration,
oidcStateCacheCleanupInterval,
)
}
return nil return nil
} }
@ -101,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -125,7 +112,6 @@ var oidcCallbackTemplate = template.Must(
</html>`), </html>`),
) )
// TODO: Why is the entire machine registration logic duplicated here?
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace // Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
@ -197,7 +183,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
// retrieve machinekey from state cache // retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
@ -207,10 +193,12 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyStr, machineKeyOK := machineKeyIf.(string) machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
)
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse machine public key")
@ -229,33 +217,19 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set // retrieve machine information if it exist
requestedTime := time.Time{} // The error is not important, because if it does not
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found { // exist, then this is a new machine and we will move
if reqTime, ok := requestedTimeIf.(time.Time); ok { // on to registration.
requestedTime = reqTime machine, _ := h.GetMachineByMachineKey(machineKey)
}
}
// retrieve machine information if machine != nil {
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
log.Error().Msg("machine key not found in database")
ctx.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
return
}
if machine.isRegistered() {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("machine already registered, reauthenticating") Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, requestedTime) h.RefreshMachine(machine, *machine.Expiry)
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
@ -279,8 +253,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
now := time.Now().UTC()
namespaceName, err := NormalizeNamespaceName( namespaceName, err := NormalizeNamespaceName(
claims.Email, claims.Email,
h.cfg.OIDC.StripEmaildomain, h.cfg.OIDC.StripEmaildomain,
@ -294,12 +266,12 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// register the machine if it's new // register the machine if it's new
if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, errNamespaceNotFound) {
namespace, err = h.CreateNamespace(namespaceName) namespace, err = h.CreateNamespace(namespaceName)
if err != nil { if err != nil {
@ -328,29 +300,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
ips, err := h.getAvailableIPs() machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
_, err = h.RegisterMachineFromAuthCallback(
machineKeyStr,
namespace.Name,
RegisterMethodOIDC,
)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("could not get an IP from the pool") Msg("could not register machine")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get an IP from the pool", "could not register machine",
) )
return return
} }
machine.IPAddresses = ips
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = RegisterMethodOIDC
machine.LastSuccessfulUpdate = &now
machine.Expiry = &requestedTime
h.db.Save(&machine)
}
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email, User: claims.Email,

View File

@ -113,6 +113,12 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
return nil return nil
} }
// UsePreAuthKey marks a PreAuthKey as used.
func (h *Headscale) UsePreAuthKey(k *PreAuthKey) {
k.Used = true
h.db.Save(k)
}
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used. // If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {

View File

@ -80,7 +80,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -105,7 +104,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }
@ -143,7 +141,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testest", Name: "testest",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),

View File

@ -22,16 +22,16 @@ message Machine {
string name = 6; string name = 6;
Namespace namespace = 7; Namespace namespace = 7;
bool registered = 8;
RegisterMethod register_method = 9;
google.protobuf.Timestamp last_seen = 10; google.protobuf.Timestamp last_seen = 8;
google.protobuf.Timestamp last_successful_update = 11; google.protobuf.Timestamp last_successful_update = 9;
google.protobuf.Timestamp expiry = 12; google.protobuf.Timestamp expiry = 10;
PreAuthKey pre_auth_key = 13; PreAuthKey pre_auth_key = 11;
google.protobuf.Timestamp created_at = 14; google.protobuf.Timestamp created_at = 12;
RegisterMethod register_method = 13;
// google.protobuf.Timestamp updated_at = 14; // google.protobuf.Timestamp updated_at = 14;
// google.protobuf.Timestamp deleted_at = 15; // google.protobuf.Timestamp deleted_at = 15;

View File

@ -35,7 +35,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "test_get_route_machine", Name: "test_get_route_machine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo), HostInfo: datatypes.JSON(hostinfo),
@ -89,7 +88,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "test_enable_route_machine", Name: "test_enable_route_machine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo), HostInfo: datatypes.JSON(hostinfo),

View File

@ -36,7 +36,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
@ -85,7 +84,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
IPAddresses: ips, IPAddresses: ips,
@ -176,7 +174,6 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: namespace.ID, NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: RegisterMethodAuthKey, RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: uint(pak.ID),
} }