From eeed3ab4f5bb5643d8a38bd4b91777c2396ed8e9 Mon Sep 17 00:00:00 2001 From: hopleus Date: Wed, 16 Oct 2024 19:58:21 +0300 Subject: [PATCH] Corrected existing tests. New tests added --- cmd/headscale/headscale_test.go | 1 + hscontrol/db/db_test.go | 15 ++-- hscontrol/db/node_test.go | 109 +++++++++++++++++++++++++++--- hscontrol/db/preauth_keys_test.go | 32 ++++++--- hscontrol/db/routes_test.go | 9 +-- hscontrol/db/suite_test.go | 1 + hscontrol/db/users_test.go | 6 +- hscontrol/mapper/mapper_test.go | 3 + hscontrol/mapper/tail_test.go | 3 +- 9 files changed, 145 insertions(+), 34 deletions(-) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 00c4a276..323434c1 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -111,4 +111,5 @@ func (*Suite) TestConfigLoading(c *check.C) { ) c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false) + c.Assert(viper.GetBool("node_management.manual_approve_new_node"), check.Equals, false) } diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index bafe1e1b..8b7d3888 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -206,12 +206,17 @@ func TestMigrations(t *testing.T) { t.Fatalf("copying db for test: %s", err) } - hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{ - Type: "sqlite3", - Sqlite: types.SqliteConfig{ - Path: dbPath, + hsdb, err := NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, }, - }, "", emptyCache()) + "", + types.NodeManagement{}, + emptyCache(), + ) if err != nil && tt.wantErr != err.Error() { t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 7c83c1be..83490ac2 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -30,7 +30,7 @@ func (s *Suite) TestGetNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -59,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -88,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -144,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -162,6 +162,7 @@ func (s *Suite) TestListPeers(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), + Approved: true, } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -177,6 +178,53 @@ func (s *Suite) TestListPeers(c *check.C) { c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") + c.Assert(peersOfNode0[0].IsApproved(), check.Equals, true) + c.Assert(peersOfNode0[5].IsApproved(), check.Equals, true) + c.Assert(peersOfNode0[8].IsApproved(), check.Equals, true) +} + +func (s *Suite) TestListPeersWithoutNonAuthorized(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 4; index++ { + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + var approved bool + if index == 4 { + approved = true + } + + node := types.Node{ + ID: types.NodeID(index), + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode" + strconv.Itoa(index), + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + Approved: approved, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + } + + node0ByID, err := db.GetNodeByID(0) + c.Assert(err, check.IsNil) + + peersOfNode0, err := db.ListPeers(node0ByID.ID) + c.Assert(err, check.IsNil) + + c.Assert(len(peersOfNode0), check.Equals, 1) + c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode4") + c.Assert(peersOfNode0[0].IsApproved(), check.Equals, true) } func (s *Suite) TestGetACLFilteredPeers(c *check.C) { @@ -190,7 +238,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for _, name := range []string{"test", "admin"} { user, err := db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } @@ -211,6 +259,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, + Approved: true, AuthKeyID: ptr.To(stor[index%2].key.ID), } trx := db.DB.Save(&node) @@ -278,11 +327,51 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") } +func (s *Suite) TestApprovedNode(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, true, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode(types.UserID(user.ID), "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + Expiry: &time.Time{}, + } + db.DB.Save(node) + + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") + c.Assert(err, check.IsNil) + c.Assert(nodeFromDB, check.NotNil) + + c.Assert(nodeFromDB.IsApproved(), check.Equals, false) + + err = db.NodeSetApprove(nodeFromDB.ID, true) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") + c.Assert(err, check.IsNil) + + c.Assert(nodeFromDB.IsApproved(), check.Equals, true) +} + func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -323,7 +412,7 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -568,7 +657,7 @@ func TestAutoApproveRoutes(t *testing.T) { user, err := adb.CreateUser("test") require.NoError(t, err) - pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) require.NoError(t, err) nodeKey := key.NewNode() @@ -709,10 +798,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser("test") require.NoError(t, err) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, true, nil, nil) require.NoError(t, err) node := types.Node{ diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 3c56a35e..618c3d89 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -12,13 +12,13 @@ import ( func (*Suite) TestCreatePreAuthKey(c *check.C) { // ID does not exist - _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) + _, err := db.CreatePreAuthKey(12345, true, true, false, nil, nil) c.Assert(err, check.NotNil) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -45,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, &now, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -53,6 +53,16 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(key, check.IsNil) } +func (*Suite) TestPreApprovedPreAuthKey(c *check.C) { + user, err := db.CreateUser("test2") + c.Assert(err, check.IsNil) + + now := time.Now().Add(-5 * time.Second) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, &now, nil) + c.Assert(err, check.IsNil) + c.Assert(pak.PreApproved, check.Equals, true) +} + func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { key, err := db.ValidatePreAuthKey("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) @@ -63,7 +73,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -75,7 +85,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -97,7 +107,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -119,7 +129,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -131,7 +141,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) @@ -148,7 +158,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true db.DB.Save(&pak) @@ -161,12 +171,12 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 7b11e136..57984015 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -35,7 +35,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "test_get_route_node") @@ -79,7 +79,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") @@ -153,7 +153,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") @@ -234,7 +234,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") @@ -336,6 +336,7 @@ func dbForTest(t *testing.T, testName string) *HSDatabase { }, }, "", + types.NodeManagement{}, emptyCache(), ) if err != nil { diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index fb7ce1df..cfced36e 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -66,6 +66,7 @@ func newSQLiteTestDB() (*HSDatabase, error) { }, }, "", + types.NodeManagement{}, emptyCache(), ) if err != nil { diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 06073762..05cbb505 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -33,7 +33,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) err = db.DestroyUser(types.UserID(user.ID)) @@ -46,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err = db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -104,7 +104,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { newUser, err := db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 4ee8c644..b1d61c52 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -181,6 +181,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ @@ -260,6 +261,7 @@ func Test_fullMapResponse(t *testing.T) { User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, @@ -316,6 +318,7 @@ func Test_fullMapResponse(t *testing.T) { ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, + Approved: true, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{}, CreatedAt: created, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 9d7f1fed..a49e0469 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -72,7 +72,7 @@ func TestTailNode(t *testing.T) { Hostinfo: hiview(tailcfg.Hostinfo{}), Tags: []string{}, PrimaryRoutes: []netip.Prefix{}, - MachineAuthorized: true, + MachineAuthorized: false, CapMap: tailcfg.NodeCapMap{ tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, @@ -105,6 +105,7 @@ func TestTailNode(t *testing.T) { ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, + Approved: true, Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{