1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

use userID instead of username everywhere

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-11-17 19:40:06 -07:00 committed by Juan Font
parent 4b58dc6eb4
commit 7ba0c3d515
10 changed files with 178 additions and 151 deletions

View File

@ -121,12 +121,12 @@ func TestMigrations(t *testing.T) {
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite", dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) { wantFunc: func(t *testing.T, h *HSDatabase) {
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
kratest, err := ListPreAuthKeys(rx, "kratest") kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest
if err != nil { if err != nil {
return nil, err return nil, err
} }
testkra, err := ListPreAuthKeys(rx, "testkra") testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
}) })
} }
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name) return getNode(rx, uid, name)
}) })
} }
// getNode finds a Node by name and user and returns the Node struct. // getNode finds a Node by name and user and returns the Node struct.
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
nodes, err := ListNodesByUser(tx, user) nodes, err := ListNodesByUser(tx, uid)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -30,10 +30,10 @@ func (s *Suite) TestGetNode(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode") _, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -51,7 +51,7 @@ func (s *Suite) TestGetNode(c *check.C) {
trx := db.DB.Save(node) trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
_, err = db.getNode("test", "testnode") _, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
@ -59,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -88,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -136,7 +136,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(user.Name, "testnode3") _, err = db.getNode(types.UserID(user.ID), "testnode3")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
@ -144,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -190,7 +190,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for _, name := range []string{"test", "admin"} { for _, name := range []string{"test", "admin"} {
user, err := db.CreateUser(name) user, err := db.CreateUser(name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
stor = append(stor, base{user, pak}) stor = append(stor, base{user, pak})
} }
@ -282,10 +282,10 @@ func (s *Suite) TestExpireNode(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode") _, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -303,7 +303,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
} }
db.DB.Save(node) db.DB.Save(node)
nodeFromDB, err := db.getNode("test", "testnode") nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB, check.NotNil)
@ -313,7 +313,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
err = db.NodeSetExpiry(nodeFromDB.ID, now) err = db.NodeSetExpiry(nodeFromDB.ID, now)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
nodeFromDB, err = db.getNode("test", "testnode") nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, true) c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@ -323,10 +323,10 @@ func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "testnode") _, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -349,7 +349,7 @@ func (s *Suite) TestSetTags(c *check.C) {
sTags := []string{"tag:test", "tag:foo"} sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(node.ID, sTags) err = db.SetTags(node.ID, sTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode") node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, sTags) c.Assert(node.ForcedTags, check.DeepEquals, sTags)
@ -357,7 +357,7 @@ func (s *Suite) TestSetTags(c *check.C) {
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(node.ID, eTags) err = db.SetTags(node.ID, eTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode") node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert( c.Assert(
node.ForcedTags, node.ForcedTags,
@ -368,7 +368,7 @@ func (s *Suite) TestSetTags(c *check.C) {
// test removing tags // test removing tags
err = db.SetTags(node.ID, []string{}) err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode") node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, []string{}) c.Assert(node.ForcedTags, check.DeepEquals, []string{})
} }
@ -568,7 +568,7 @@ func TestAutoApproveRoutes(t *testing.T) {
user, err := adb.CreateUser("test") user, err := adb.CreateUser("test")
require.NoError(t, err) require.NoError(t, err)
pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
require.NoError(t, err) require.NoError(t, err)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -700,10 +700,10 @@ func TestListEphemeralNodes(t *testing.T) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
require.NoError(t, err) require.NoError(t, err)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
require.NoError(t, err) require.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
require.NoError(t, err) require.NoError(t, err)
node := types.Node{ node := types.Node{

View File

@ -23,29 +23,27 @@ var (
) )
func (hsdb *HSDatabase) CreatePreAuthKey( func (hsdb *HSDatabase) CreatePreAuthKey(
// TODO(kradalby): Should be ID, not name uid types.UserID,
userName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKey, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
}) })
} }
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey( func CreatePreAuthKey(
tx *gorm.DB, tx *gorm.DB,
// TODO(kradalby): Should be ID, not name uid types.UserID,
userName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKey, error) {
user, err := GetUserByUsername(tx, userName) user, err := GetUserByID(tx, uid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,15 +87,15 @@ func CreatePreAuthKey(
return &key, nil return &key, nil
} }
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeys(rx, userName) return ListPreAuthKeysByUser(rx, uid)
}) })
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a user. // ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
user, err := GetUserByUsername(tx, userName) user, err := GetUserByID(tx, uid)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -11,14 +11,14 @@ import (
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) // ID does not exist
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Did we get a valid key? // Did we get a valid key?
@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
c.Assert(len(key.Key), check.Equals, 48) c.Assert(len(key.Key), check.Equals, 48)
// Make sure the User association is populated // Make sure the User association is populated
c.Assert(key.User.Name, check.Equals, user.Name) c.Assert(key.User.ID, check.Equals, user.ID)
_, err = db.ListPreAuthKeys("bogus") // ID does not exist
_, err = db.ListPreAuthKeys(1000000)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
keys, err := db.ListPreAuthKeys(user.Name) keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
// Make sure the User association is populated // Make sure the User association is populated
c.Assert((keys)[0].User.Name, check.Equals, user.Name) c.Assert((keys)[0].User.ID, check.Equals, user.ID)
} }
func (*Suite) TestExpiredPreAuthKey(c *check.C) { func (*Suite) TestExpiredPreAuthKey(c *check.C) {
@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().Add(-5 * time.Second) now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -62,7 +63,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
user, err := db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -74,7 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
user, err := db.CreateUser("test4") user, err := db.CreateUser("test4")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -96,7 +97,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
user, err := db.CreateUser("test5") user, err := db.CreateUser("test5")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -118,7 +119,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
user, err := db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -130,7 +131,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil) c.Assert(pak.Expiration, check.IsNil)
@ -147,7 +148,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
user, err := db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak.Used = true pak.Used = true
db.DB.Save(&pak) db.DB.Save(&pak)
@ -160,15 +161,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
user, err := db.CreateUser("test8") user, err := db.CreateUser("test8")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"} tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
listedPaks, err := db.ListPreAuthKeys("test8") listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
gotTags := listedPaks[0].Proto().GetAclTags() gotTags := listedPaks[0].Proto().GetAclTags()
sort.Sort(sort.StringSlice(gotTags)) sort.Sort(sort.StringSlice(gotTags))

View File

@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "test_get_route_node") _, err = db.getNode(types.UserID(user.ID), "test_get_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24") route, err := netip.ParsePrefix("10.0.0.0/24")
@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "test_enable_route_node") _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "test_enable_route_node") _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode("test", "test_enable_route_node") _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(

View File

@ -40,21 +40,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
return &user, nil return &user, nil
} }
func (hsdb *HSDatabase) DestroyUser(name string) error { func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return DestroyUser(tx, name) return DestroyUser(tx, uid)
}) })
} }
// DestroyUser destroys a User. Returns error if the User does // DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it. // not exist or if there are nodes associated with it.
func DestroyUser(tx *gorm.DB, name string) error { func DestroyUser(tx *gorm.DB, uid types.UserID) error {
user, err := GetUserByUsername(tx, name) user, err := GetUserByID(tx, uid)
if err != nil { if err != nil {
return ErrUserNotFound return err
} }
nodes, err := ListNodesByUser(tx, name) nodes, err := ListNodesByUser(tx, uid)
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +62,7 @@ func DestroyUser(tx *gorm.DB, name string) error {
return ErrUserStillHasNodes return ErrUserStillHasNodes
} }
keys, err := ListPreAuthKeys(tx, name) keys, err := ListPreAuthKeysByUser(tx, uid)
if err != nil { if err != nil {
return err return err
} }
@ -80,17 +80,17 @@ func DestroyUser(tx *gorm.DB, name string) error {
return nil return nil
} }
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return RenameUser(tx, oldName, newName) return RenameUser(tx, uid, newName)
}) })
} }
// RenameUser renames a User. Returns error if the User does // RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name. // not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, oldName, newName string) error { func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
var err error var err error
oldUser, err := GetUserByUsername(tx, oldName) oldUser, err := GetUserByID(tx, uid)
if err != nil { if err != nil {
return err return err
} }
@ -98,50 +98,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
if err != nil { if err != nil {
return err return err
} }
_, err = GetUserByUsername(tx, newName)
if err == nil {
return ErrUserExists
}
if !errors.Is(err, ErrUserNotFound) {
return err
}
oldUser.Name = newName oldUser.Name = newName
if result := tx.Save(&oldUser); result.Error != nil { if err := tx.Save(&oldUser).Error; err != nil {
return result.Error return err
} }
return nil return nil
} }
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByUsername(rx, name) return GetUserByID(rx, uid)
}) })
} }
func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) { func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
user := types.User{} user := types.User{}
if result := tx.First(&user, "name = ?", name); errors.Is( if result := tx.First(&user, "id = ?", uid); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrUserNotFound
}
return &user, nil
}
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByID(rx, id)
})
}
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "id = ?", id); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -169,54 +144,65 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
return &user, nil return &user, nil
} }
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx) return ListUsers(rx, where...)
}) })
} }
// ListUsers gets all the existing users. // ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB) ([]types.User, error) { func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
if len(where) > 1 {
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
}
var user *types.User
if len(where) == 1 {
user = where[0]
}
users := []types.User{} users := []types.User{}
if err := tx.Find(&users).Error; err != nil { if err := tx.Where(user).Find(&users).Error; err != nil {
return nil, err return nil, err
} }
return users, nil return users, nil
} }
// ListNodesByUser gets all the nodes in a given user. // GetUserByName returns a user if the provided username is
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { // unique, and otherwise an error.
err := util.CheckForFQDNRules(name) func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
if err != nil { users, err := hsdb.ListUsers(&types.User{Name: name})
return nil, err
}
user, err := GetUserByUsername(tx, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(users) != 1 {
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
}
return &users[0], nil
}
// ListNodesByUser gets all the nodes in a given user.
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, node, username) return AssignNodeToUser(tx, node, uid)
}) })
} }
// AssignNodeToUser assigns a Node to a user. // AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error {
err := util.CheckForFQDNRules(username) user, err := GetUserByID(tx, uid)
if err != nil {
return err
}
user, err := GetUserByUsername(tx, username)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,8 @@
package db package db
import ( import (
"strings"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1) c.Assert(len(users), check.Equals, 1)
err = db.DestroyUser("test") err = db.DestroyUser(types.UserID(user.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetUserByName("test") _, err = db.GetUserByID(types.UserID(user.ID))
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestDestroyUserErrors(c *check.C) { func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := db.DestroyUser("test") err := db.DestroyUser(9998)
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.DestroyUser("test") err = db.DestroyUser(types.UserID(user.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
user, err = db.CreateUser("test") user, err = db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
err = db.DestroyUser("test") err = db.DestroyUser(types.UserID(user.ID))
c.Assert(err, check.Equals, ErrUserStillHasNodes) c.Assert(err, check.Equals, ErrUserStillHasNodes)
} }
@ -70,24 +72,28 @@ func (s *Suite) TestRenameUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1) c.Assert(len(users), check.Equals, 1)
err = db.RenameUser("test", "test-renamed") err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetUserByName("test") users, err = db.ListUsers(&types.User{Name: "test"})
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, nil)
c.Assert(len(users), check.Equals, 0)
_, err = db.GetUserByName("test-renamed") users, err = db.ListUsers(&types.User{Name: "test-renamed"})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.RenameUser("test-does-not-exit", "test") err = db.RenameUser(99988, "test")
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := db.CreateUser("test2") userTest2, err := db.CreateUser("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(userTest2.Name, check.Equals, "test2") c.Assert(userTest2.Name, check.Equals, "test2")
err = db.RenameUser("test2", "test-renamed") err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed")
c.Assert(err, check.Equals, ErrUserExists) if !strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.Fatalf("expected failure with unique constraint, got: %s", err.Error())
}
} }
func (s *Suite) TestSetMachineUser(c *check.C) { func (s *Suite) TestSetMachineUser(c *check.C) {
@ -97,7 +103,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
newUser, err := db.CreateUser("new") newUser, err := db.CreateUser("new")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -111,15 +117,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
c.Assert(node.UserID, check.Equals, oldUser.ID) c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, newUser.Name) err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name) c.Assert(node.User.Name, check.Equals, newUser.Name)
err = db.AssignNodeToUser(&node, "non-existing-user") err = db.AssignNodeToUser(&node, 9584849)
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
err = db.AssignNodeToUser(&node, newUser.Name) err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name) c.Assert(node.User.Name, check.Equals, newUser.Name)

View File

@ -65,24 +65,34 @@ func (api headscaleV1APIServer) RenameUser(
ctx context.Context, ctx context.Context,
request *v1.RenameUserRequest, request *v1.RenameUserRequest,
) (*v1.RenameUserResponse, error) { ) (*v1.RenameUserResponse, error) {
err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) oldUser, err := api.h.db.GetUserByName(request.GetOldName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := api.h.db.GetUserByName(request.GetNewName()) err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &v1.RenameUserResponse{User: user.Proto()}, nil newUser, err := api.h.db.GetUserByName(request.GetNewName())
if err != nil {
return nil, err
}
return &v1.RenameUserResponse{User: newUser.Proto()}, nil
} }
func (api headscaleV1APIServer) DeleteUser( func (api headscaleV1APIServer) DeleteUser(
ctx context.Context, ctx context.Context,
request *v1.DeleteUserRequest, request *v1.DeleteUserRequest,
) (*v1.DeleteUserResponse, error) { ) (*v1.DeleteUserResponse, error) {
err := api.h.db.DestroyUser(request.GetName()) user, err := api.h.db.GetUserByName(request.GetName())
if err != nil {
return nil, err
}
err = api.h.db.DestroyUser(types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -131,8 +141,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
} }
} }
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
preAuthKey, err := api.h.db.CreatePreAuthKey( preAuthKey, err := api.h.db.CreatePreAuthKey(
request.GetUser(), types.UserID(user.ID),
request.GetReusable(), request.GetReusable(),
request.GetEphemeral(), request.GetEphemeral(),
&expiration, &expiration,
@ -168,7 +183,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context, ctx context.Context,
request *v1.ListPreAuthKeysRequest, request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) { ) (*v1.ListPreAuthKeysResponse, error) {
preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -406,10 +426,20 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context, ctx context.Context,
request *v1.ListNodesRequest, request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) { ) (*v1.ListNodesResponse, error) {
// TODO(kradalby): it looks like this can be simplified a lot,
// the filtering of nodes by user, vs nodes as a whole can
// probably be done once.
// TODO(kradalby): This should be done in one tx.
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
if request.GetUser() != "" { if request.GetUser() != "" {
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, request.GetUser()) return db.ListNodesByUser(rx, types.UserID(user.ID))
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -465,12 +495,18 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context, ctx context.Context,
request *v1.MoveNodeRequest, request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) { ) (*v1.MoveNodeResponse, error) {
// TODO(kradalby): This should be done in one tx.
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = api.h.db.AssignNodeToUser(node, request.GetUser()) user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -26,7 +26,7 @@ type User struct {
// Username for the user, is used if email is empty // Username for the user, is used if email is empty
// Should not be used, please use Username(). // Should not be used, please use Username().
Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"`
// Typically the full name of the user // Typically the full name of the user
DisplayName string DisplayName string