From d671462811972cf317d5388dbefb439cb5f74cdf Mon Sep 17 00:00:00 2001 From: hopleus Date: Wed, 16 Oct 2024 19:58:21 +0300 Subject: [PATCH] Corrected existing tests. New tests added --- cmd/headscale/headscale_test.go | 1 + hscontrol/db/db_test.go | 15 ++-- hscontrol/db/node_test.go | 109 +++++++++++++++++++++++++++--- hscontrol/db/preauth_keys_test.go | 32 ++++++--- hscontrol/db/routes_test.go | 9 +-- hscontrol/db/suite_test.go | 1 + hscontrol/db/users_test.go | 6 +- hscontrol/mapper/mapper_test.go | 3 + hscontrol/mapper/tail_test.go | 3 +- 9 files changed, 145 insertions(+), 34 deletions(-) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 00c4a276..323434c1 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -111,4 +111,5 @@ func (*Suite) TestConfigLoading(c *check.C) { ) c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false) + c.Assert(viper.GetBool("node_management.manual_approve_new_node"), check.Equals, false) } diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 68ea2ac1..48a93d45 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -203,12 +203,17 @@ func TestMigrations(t *testing.T) { t.Fatalf("copying db for test: %s", err) } - hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{ - Type: "sqlite3", - Sqlite: types.SqliteConfig{ - Path: dbPath, + hsdb, err := NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, }, - }, "", emptyCache()) + "", + types.NodeManagement{}, + emptyCache(), + ) if err != nil && tt.wantErr != err.Error() { t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..0392e3b0 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -29,7 +29,7 @@ func (s *Suite) TestGetNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "testnode") @@ -58,7 +58,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -87,7 +87,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -143,7 +143,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -161,6 +161,7 @@ func (s *Suite) TestListPeers(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), + Approved: true, } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -176,6 +177,53 @@ func (s *Suite) TestListPeers(c *check.C) { c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") + c.Assert(peersOfNode0[0].IsApproved(), check.Equals, true) + c.Assert(peersOfNode0[5].IsApproved(), check.Equals, true) + c.Assert(peersOfNode0[8].IsApproved(), check.Equals, true) +} + +func (s *Suite) TestListPeersWithoutNonAuthorized(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 4; index++ { + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + var approved bool + if index == 4 { + approved = true + } + + node := types.Node{ + ID: types.NodeID(index), + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode" + strconv.Itoa(index), + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + Approved: approved, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + } + + node0ByID, err := db.GetNodeByID(0) + c.Assert(err, check.IsNil) + + peersOfNode0, err := db.ListPeers(node0ByID.ID) + c.Assert(err, check.IsNil) + + c.Assert(len(peersOfNode0), check.Equals, 1) + c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode4") + c.Assert(peersOfNode0[0].IsApproved(), check.Equals, true) } func (s *Suite) TestGetACLFilteredPeers(c *check.C) { @@ -189,7 +237,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for _, name := range []string{"test", "admin"} { user, err := db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } @@ -210,6 +258,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, + Approved: true, AuthKeyID: ptr.To(stor[index%2].key.ID), } trx := db.DB.Save(&node) @@ -277,11 +326,51 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") } +func (s *Suite) TestApprovedNode(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, true, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode("test", "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + Expiry: &time.Time{}, + } + db.DB.Save(node) + + nodeFromDB, err := db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + c.Assert(nodeFromDB, check.NotNil) + + c.Assert(nodeFromDB.IsApproved(), check.Equals, false) + + err = db.NodeSetApprove(nodeFromDB.ID, true) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + + c.Assert(nodeFromDB.IsApproved(), check.Equals, true) +} + func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "testnode") @@ -322,7 +411,7 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "testnode") @@ -567,7 +656,7 @@ func TestAutoApproveRoutes(t *testing.T) { user, err := adb.CreateUser("test") assert.NoError(t, err) - pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := adb.CreatePreAuthKey(user.Name, false, true, false, nil, nil) assert.NoError(t, err) nodeKey := key.NewNode() @@ -699,10 +788,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser("test") assert.NoError(t, err) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) assert.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(user.Name, false, true, true, nil, nil) assert.NoError(t, err) node := types.Node{ diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index ec3f6441..7894eb33 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,14 +11,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) + _, err := db.CreatePreAuthKey("bogus", true, true, false, nil, nil) c.Assert(err, check.NotNil) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := db.CreatePreAuthKey(user.Name, true, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -44,7 +44,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, true, false, &now, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -52,6 +52,16 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(key, check.IsNil) } +func (*Suite) TestPreApprovedPreAuthKey(c *check.C) { + user, err := db.CreateUser("test2") + c.Assert(err, check.IsNil) + + now := time.Now().Add(-5 * time.Second) + pak, err := db.CreatePreAuthKey(user.Name, true, true, false, &now, nil) + c.Assert(err, check.IsNil) + c.Assert(pak.PreApproved, check.Equals, true) +} + func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { key, err := db.ValidatePreAuthKey("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) @@ -62,7 +72,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -74,7 +84,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -96,7 +106,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -118,7 +128,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -130,7 +140,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) @@ -147,7 +157,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true db.DB.Save(&pak) @@ -160,12 +170,12 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(user.Name, false, true, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(user.Name, false, true, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) listedPaks, err := db.ListPreAuthKeys("test8") diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 5071077c..6474e5f7 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -35,7 +35,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "test_get_route_node") @@ -79,7 +79,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "test_enable_route_node") @@ -153,7 +153,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "test_enable_route_node") @@ -234,7 +234,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode("test", "test_enable_route_node") @@ -336,6 +336,7 @@ func dbForTest(t *testing.T, testName string) *HSDatabase { }, }, "", + types.NodeManagement{}, emptyCache(), ) if err != nil { diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 6cc46d3d..237ae290 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -59,6 +59,7 @@ func newTestDB() (*HSDatabase, error) { }, }, "", + types.NodeManagement{}, emptyCache(), ) if err != nil { diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 54399664..f276c0d1 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -31,7 +31,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) err = db.DestroyUser("test") @@ -44,7 +44,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err = db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err = db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -97,7 +97,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { newUser, err := db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(oldUser.Name, false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 37ed5c42..d5f01284 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -178,6 +178,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ @@ -257,6 +258,7 @@ func Test_fullMapResponse(t *testing.T) { User: types.User{Name: "mini"}, ForcedTags: []string{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, @@ -313,6 +315,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, + Approved: true, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, CreatedAt: created, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index b6692c16..59c8b646 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -72,7 +72,7 @@ func TestTailNode(t *testing.T) { Hostinfo: hiview(tailcfg.Hostinfo{}), Tags: []string{}, PrimaryRoutes: []netip.Prefix{}, - MachineAuthorized: true, + MachineAuthorized: false, CapMap: tailcfg.NodeCapMap{ tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, @@ -105,6 +105,7 @@ func TestTailNode(t *testing.T) { ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{