diff --git a/api.go b/api.go index ff738712..aadc604e 100644 --- a/api.go +++ b/api.go @@ -75,7 +75,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { machineKeyStr := ctx.Param("id") var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(machineKeyStr)) + err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Caller(). diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 5adc7f59..26ead6dc 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -486,7 +486,9 @@ func nodesToPtables( } var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) + err := nodeKey.UnmarshalText( + []byte(headscale.NodePublicKeyEnsurePrefix(machine.NodeKey)), + ) if err != nil { return nil, err } diff --git a/machine.go b/machine.go index 03caa5a2..d58c9c5c 100644 --- a/machine.go +++ b/machine.go @@ -439,7 +439,7 @@ func (machine Machine) toNode( includeRoutes bool, ) (*tailcfg.Node, error) { var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) + err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey))) if err != nil { log.Trace(). Caller(). @@ -450,14 +450,18 @@ func (machine Machine) toNode( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(machine.MachineKey)) + err = machineKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) if err != nil { return nil, fmt.Errorf("failed to parse machine public key: %w", err) } var discoKey key.DiscoPublic if machine.DiscoKey != "" { - err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey)) + err := discoKey.UnmarshalText( + []byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), + ) if err != nil { return nil, fmt.Errorf("failed to parse disco public key: %w", err) } @@ -634,7 +638,7 @@ func (h *Headscale) RegisterMachine( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(machineKeyStr)) + err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { return nil, err } diff --git a/oidc.go b/oidc.go index 48ad7187..d481e028 100644 --- a/oidc.go +++ b/oidc.go @@ -192,7 +192,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { machineKeyStr, machineKeyOK := machineKeyIf.(string) var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(machineKeyStr)) + err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Msg("could not parse machine public key") diff --git a/poll.go b/poll.go index 70bacc6f..1d2db944 100644 --- a/poll.go +++ b/poll.go @@ -38,7 +38,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { machineKeyStr := ctx.Param("id") var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(machineKeyStr)) + err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Str("handler", "PollNetMap"). diff --git a/utils.go b/utils.go index fa9f028d..c9971390 100644 --- a/utils.go +++ b/utils.go @@ -60,6 +60,30 @@ func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) } +func MachinePublicKeyEnsurePrefix(machineKey string) string { + if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { + return machinePublicHexPrefix + machineKey + } + + return machineKey +} + +func NodePublicKeyEnsurePrefix(nodeKey string) string { + if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { + return nodePublicHexPrefix + nodeKey + } + + return nodeKey +} + +func DiscoPublicKeyEnsurePrefix(discoKey string) string { + if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { + return discoPublicHexPrefix + discoKey + } + + return discoKey +} + // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors type Error string