diff --git a/machine.go b/machine.go index aebfbcef..b436fd87 100644 --- a/machine.go +++ b/machine.go @@ -375,6 +375,19 @@ func (h *Headscale) GetMachineByNodeKey( return &machine, nil } +// GetMachineByAnyNodeKey finds a Machine by its current NodeKey or the old one, and returns the Machine struct. +func (h *Headscale) GetMachineByAnyNodeKey( + nodeKey key.NodePublic, oldNodeKey key.NodePublic, +) (*Machine, error) { + machine := Machine{} + if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?", + NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { + return nil, result.Error + } + + return &machine, nil +} + // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { diff --git a/machine_test.go b/machine_test.go index 53d065ff..5da0906f 100644 --- a/machine_test.go +++ b/machine_test.go @@ -11,6 +11,7 @@ import ( "gopkg.in/check.v1" "inet.af/netaddr" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) func (s *Suite) TestGetMachine(c *check.C) { @@ -65,6 +66,63 @@ func (s *Suite) TestGetMachineByID(c *check.C) { c.Assert(err, check.IsNil) } +func (s *Suite) TestGetMachineByNodeKey(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + + machine := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + app.db.Save(&machine) + + _, err = app.GetMachineByNodeKey(nodeKey.Public()) + c.Assert(err, check.IsNil) +} + +func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + oldNodeKey := key.NewNode() + + machine := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + app.db.Save(&machine) + + _, err = app.GetMachineByAnyNodeKey(nodeKey.Public(), oldNodeKey.Public()) + c.Assert(err, check.IsNil) +} + func (s *Suite) TestDeleteMachine(c *check.C) { namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil)