diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 889e60d5..c92a4497 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -147,105 +147,6 @@ func (s *Suite) TestListPeers(c *check.C) { c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") } -func (s *Suite) TestGetACLFilteredPeers(c *check.C) { - type base struct { - user *types.User - key *types.PreAuthKey - } - - stor := make([]base, 0) - - for _, name := range []string{"test", "admin"} { - user, err := db.CreateUser(types.User{Name: name}) - c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - stor = append(stor, base{user, pak}) - } - - _, err := db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1)) - node := types.Node{ - ID: types.NodeID(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &v4, - Hostname: "testnode" + strconv.Itoa(index), - UserID: stor[index%2].user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(stor[index%2].key.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - } - - aclPolicy := &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:test": {"admin"}, - }, - Hosts: map[string]netip.Prefix{}, - TagOwners: map[string][]string{}, - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"admin"}, - Destinations: []string{"*:*"}, - }, - { - Action: "accept", - Sources: []string{"test"}, - Destinations: []string{"test:*"}, - }, - }, - Tests: []policy.ACLTest{}, - } - - adminNode, err := db.GetNodeByID(1) - c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User) - c.Assert(adminNode.IPv4, check.NotNil) - c.Assert(adminNode.IPv6, check.IsNil) - c.Assert(err, check.IsNil) - - testNode, err := db.GetNodeByID(2) - c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) - c.Assert(err, check.IsNil) - - adminPeers, err := db.ListPeers(adminNode.ID) - c.Assert(err, check.IsNil) - c.Assert(len(adminPeers), check.Equals, 9) - - testPeers, err := db.ListPeers(testNode.ID) - c.Assert(err, check.IsNil) - c.Assert(len(testPeers), check.Equals, 9) - - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) - c.Assert(err, check.IsNil) - - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) - c.Assert(err, check.IsNil) - - peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) - peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules) - c.Log(peersOfAdminNode) - c.Log(peersOfTestNode) - - c.Assert(len(peersOfTestNode), check.Equals, 9) - c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1") - c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3") - c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5") - - c.Assert(len(peersOfAdminNode), check.Equals, 9) - c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4") - c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") -} - func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) @@ -457,143 +358,171 @@ func TestHeadscale_generateGivenName(t *testing.T) { } } -// TODO(kradalby): replace this test -// func TestAutoApproveRoutes(t *testing.T) { -// tests := []struct { -// name string -// acl string -// routes []netip.Prefix -// want []netip.Prefix -// }{ -// { -// name: "2068-approve-issue-sub", -// acl: ` -// { -// "groups": { -// "group:k8s": ["test"] -// }, +func TestAutoApproveRoutes(t *testing.T) { + tests := []struct { + name string + acl string + routes []netip.Prefix + want []netip.Prefix + want2 []netip.Prefix + }{ + { + name: "2068-approve-issue-sub-kube", + acl: ` +{ + "groups": { + "group:k8s": ["test@"] + }, // "acls": [ // {"action": "accept", "users": ["*"], "ports": ["*:*"]}, // ], -// "autoApprovers": { -// "routes": { -// "10.42.0.0/16": ["test"], -// } -// } -// }`, -// routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, -// want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, -// }, -// { -// name: "2068-approve-issue-sub", -// acl: ` -// { -// "tagOwners": { -// "tag:exit": ["test"], -// }, + "autoApprovers": { + "routes": { + "10.42.0.0/16": ["test@"], + } + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + }, + { + name: "2068-approve-issue-sub-exit-tag", + acl: ` +{ + "tagOwners": { + "tag:exit": ["test@"], + }, -// "groups": { -// "group:test": ["test"] -// }, + "groups": { + "group:test": ["test@"] + }, // "acls": [ // {"action": "accept", "users": ["*"], "ports": ["*:*"]}, // ], -// "autoApprovers": { -// "exitNode": ["tag:exit"], -// "routes": { -// "10.10.0.0/16": ["group:test"], -// "10.11.0.0/16": ["test"], -// } -// } -// }`, -// routes: []netip.Prefix{ -// tsaddr.AllIPv4(), -// tsaddr.AllIPv6(), -// netip.MustParsePrefix("10.10.0.0/16"), -// netip.MustParsePrefix("10.11.0.0/24"), -// }, -// want: []netip.Prefix{ -// tsaddr.AllIPv4(), -// netip.MustParsePrefix("10.10.0.0/16"), -// netip.MustParsePrefix("10.11.0.0/24"), -// tsaddr.AllIPv6(), -// }, -// }, -// } + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test@"], + "8.11.0.0/24": ["test2@"], // No nodes + } + } +}`, + routes: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + netip.MustParsePrefix("10.10.0.0/16"), + netip.MustParsePrefix("10.11.0.0/24"), -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// adb, err := newSQLiteTestDB() -// require.NoError(t, err) -// pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) + // Not approved + netip.MustParsePrefix("8.11.0.0/24"), + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.0.0/16"), + netip.MustParsePrefix("10.11.0.0/24"), + }, + want2: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + }, + }, + } -// require.NoError(t, err) -// require.NotNil(t, pol) + for _, tt := range tests { + pmfs := policy.PolicyManagerFuncsForTest([]byte(tt.acl)) + for i, pmf := range pmfs { + version := i + 1 + t.Run(fmt.Sprintf("%s-policyv%d", tt.name, version), func(t *testing.T) { + adb, err := newSQLiteTestDB() + require.NoError(t, err) -// user, err := adb.CreateUser(types.User{Name: "test"}) -// require.NoError(t, err) + suffix := "" + if version == 1 { + suffix = "@" + } -// pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, nil, nil) -// require.NoError(t, err) + user, err := adb.CreateUser(types.User{Name: "test" + suffix}) + require.NoError(t, err) + _, err = adb.CreateUser(types.User{Name: "test2" + suffix}) + require.NoError(t, err) + taggedUser, err := adb.CreateUser(types.User{Name: "tagged" + suffix}) + require.NoError(t, err) -// nodeKey := key.NewNode() -// machineKey := key.NewMachine() + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.routes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + } -// v4 := netip.MustParseAddr("100.64.0.1") -// node := types.Node{ -// ID: 0, -// MachineKey: machineKey.Public(), -// NodeKey: nodeKey.Public(), -// Hostname: "test", -// UserID: user.ID, -// RegisterMethod: util.RegisterMethodAuthKey, -// AuthKeyID: ptr.To(pak.ID), -// Hostinfo: &tailcfg.Hostinfo{ -// RequestTags: []string{"tag:exit"}, -// RoutableIPs: tt.routes, -// }, -// IPv4: &v4, -// } + err = adb.DB.Save(&node).Error + require.NoError(t, err) -// trx := adb.DB.Save(&node) -// require.NoError(t, trx.Error) + nodeTagged := types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "taggednode", + UserID: taggedUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.routes, + }, + ForcedTags: []string{"tag:exit"}, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } -// sendUpdate, err := adb.SaveNodeRoutes(&node) -// require.NoError(t, err) -// assert.False(t, sendUpdate) + err = adb.DB.Save(&nodeTagged).Error + require.NoError(t, err) -// node0ByID, err := adb.GetNodeByID(0) -// require.NoError(t, err) + users, err := adb.ListUsers() + assert.NoError(t, err) -// users, err := adb.ListUsers() -// assert.NoError(t, err) + nodes, err := adb.ListNodes() + assert.NoError(t, err) -// nodes, err := adb.ListNodes() -// assert.NoError(t, err) + pm, err := pmf(users, nodes) + require.NoError(t, err) + require.NotNil(t, pm) -// pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) -// assert.NoError(t, err) + changed1 := policy.AutoApproveRoutes(pm, &node) + assert.True(t, changed1) -// // TODO(kradalby): Check state update -// err = adb.EnableAutoApprovedRoutes(pm, node0ByID) -// require.NoError(t, err) + err = adb.DB.Save(&node).Error + require.NoError(t, err) -// enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) -// require.NoError(t, err) -// assert.Len(t, enabledRoutes, len(tt.want)) + _ = policy.AutoApproveRoutes(pm, &nodeTagged) -// tsaddr.SortPrefixes(enabledRoutes) + err = adb.DB.Save(&nodeTagged).Error + require.NoError(t, err) -// if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" { -// t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) -// } -// }) -// } -// } + node1ByID, err := adb.GetNodeByID(1) + require.NoError(t, err) + + if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { + t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) + } + + node2ByID, err := adb.GetNodeByID(2) + require.NoError(t, err) + + if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { + t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) + } + }) + } + } +} func TestEphemeralGarbageCollectorOrder(t *testing.T) { want := []types.NodeID{1, 3}