diff --git a/machine.go b/machine.go index 22be0da1..d2e55b03 100644 --- a/machine.go +++ b/machine.go @@ -350,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { return &m, nil } -// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. +// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (h *Headscale) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*Machine, error) { @@ -362,6 +362,19 @@ func (h *Headscale) GetMachineByMachineKey( return &m, nil } +// GetMachineByNodeKey finds a Machine by its current NodeKey +func (h *Headscale) GetMachineByNodeKey( + nodeKey key.NodePublic, +) (*Machine, error) { + machine := Machine{} + if result := h.db.Preload("Namespace").First(&machine, "node_key = ?", + NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { + return nil, result.Error + } + + return &machine, nil +} + // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { @@ -762,11 +775,11 @@ func getTags( } func (h *Headscale) RegisterMachineFromAuthCallback( - machineKeyStr string, + nodeKeyStr string, namespaceName string, registrationMethod string, ) (*Machine, error) { - if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { + if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { namespace, err := h.GetNamespace(namespaceName) if err != nil {