From db6cf4ac0a8cbf8350038bb04238b2cfdcdcbaed Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 8 Jun 2023 19:10:09 +0200 Subject: [PATCH] make GenerateFilterRules take machine and peers Signed-off-by: Kristoffer Dalby --- hscontrol/db/acls_test.go | 549 +++++++++++++++----------------- hscontrol/db/machine_test.go | 14 +- hscontrol/mapper/mapper.go | 4 +- hscontrol/mapper/mapper_test.go | 6 +- hscontrol/policy/acls.go | 28 +- hscontrol/policy/acls_test.go | 6 +- 6 files changed, 291 insertions(+), 316 deletions(-) diff --git a/hscontrol/db/acls_test.go b/hscontrol/db/acls_test.go index 884b6c5c..e7d09a5f 100644 --- a/hscontrol/db/acls_test.go +++ b/hscontrol/db/acls_test.go @@ -2,10 +2,13 @@ package db import ( "net/netip" + "testing" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "tailscale.com/envknob" "tailscale.com/tailcfg" @@ -77,7 +80,7 @@ func (s *Suite) TestSshRules(c *check.C) { }, } - _, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) + _, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, &machine, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(sshPolicy, check.NotNil) @@ -94,15 +97,7 @@ func (s *Suite) TestSshRules(c *check.C) { // this test should validate that we can expand a group in a TagOWner section and // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Sources section. -func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { - user, err := db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) +func TestValidExpandTagOwnersInSources(t *testing.T) { hostInfo := tailcfg.Hostinfo{ OS: "centos", Hostname: "testmachine", @@ -110,19 +105,19 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { } machine := types.Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: 0, + User: types.User{ + Name: "user1", + }, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), HostInfo: types.HostInfo(hostInfo), } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) pol := &policy.ACLPolicy{ Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, @@ -136,85 +131,28 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { }, } - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + assert.NoError(t, err) - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Destinations section. -func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { - user, err := db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := types.Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo), - } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) - - pol := &policy.ACLPolicy{ - Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"tag:test:*"}, + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, + {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, }, }, } - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) - - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff) + } } // need a test with: // tag on a host that isn't owned by a tag owners. So the user // of the host should be valid. -func (s *Suite) TestInvalidTagValidUser(c *check.C) { - user, err := db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) +func TestInvalidTagValidUser(t *testing.T) { hostInfo := tailcfg.Hostinfo{ OS: "centos", Hostname: "testmachine", @@ -222,19 +160,19 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { } machine := types.Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: 1, + User: types.User{ + Name: "user1", + }, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), HostInfo: types.HostInfo(hostInfo), } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) pol := &policy.ACLPolicy{ TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, @@ -247,190 +185,38 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { }, } - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + assert.NoError(t, err) - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// tag on a host is owned by a tag owner, the tag is valid. -// an ACL rule is matching the tag to a user. It should not be valid since the -// host should be tied to the tag now. -func (s *Suite) TestValidTagInvalidUser(c *check.C) { - user, err := db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("user1", "webserver") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "webserver", - RequestTags: []string{"tag:webapp"}, - } - - machine := types.Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "webserver", - IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo), - } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("user1", "user") - hostInfo2 := tailcfg.Hostinfo{ - OS: "debian", - Hostname: "Hostname", - } - c.Assert(err, check.NotNil) - machine = types.Machine{ - ID: 2, - MachineKey: "56789", - NodeKey: "bar2", - DiscoKey: "faab", - Hostname: "user", - IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo2), - } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) - - pol := &policy.ACLPolicy{ - TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"tag:webapp:80,443"}, + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, + {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, }, }, } - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) - - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") - c.Assert(rules[0].DstPorts, check.HasLen, 2) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) - c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) - c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) - c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff) + } } -func (s *Suite) TestPortUser(c *check.C) { - user, err := db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := db.getAvailableIPs() +func TestPortGroup(t *testing.T) { machine := types.Machine{ - ID: 0, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), - } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) - - acl := []byte(` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: 0, + User: types.User{ + Name: "testuser", }, - ], -} - `) - pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - c.Assert(pol, check.NotNil) - - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) - - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) - - c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") -} - -func (s *Suite) TestPortGroup(c *check.C) { - user, err := db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := db.getAvailableIPs() - machine := types.Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.5")}, } - err = db.MachineSave(&machine) - c.Assert(err, check.IsNil) acl := []byte(` { @@ -459,22 +245,211 @@ func (s *Suite) TestPortGroup(c *check.C) { } `) pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) + assert.NoError(t, err) - machines, err := db.ListMachines() - c.Assert(err, check.IsNil) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + assert.NoError(t, err) - rules, _, err := policy.GenerateFilterRules(pol, machines, false) - c.Assert(err, check.IsNil) + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.5/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}}, + }, + }, + } - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestPortGroup() unexpected result (-want +got):\n%s", diff) + } +} + +func TestPortUser(t *testing.T) { + machine := types.Machine{ + ID: 0, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: 0, + User: types.User{ + Name: "testuser", + }, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.9")}, + } + + acl := []byte(` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} + `) + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + assert.NoError(t, err) + + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + assert.NoError(t, err) + + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.9/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}}, + }, + }, + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestPortUser() unexpected result (-want +got):\n%s", diff) + } +} + +// this test should validate that we can expand a group in a TagOWner section and +// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. +// the tag is matched in the Destinations section. +func TestValidExpandTagOwnersInDestinations(t *testing.T) { + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: 1, + User: types.User{ + Name: "user1", + }, + RegisterMethod: util.RegisterMethodAuthKey, + HostInfo: types.HostInfo(hostInfo), + } + + pol := &policy.ACLPolicy{ + Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"tag:test:*"}, + }, + }, + } + + // rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false) + // c.Assert(err, check.IsNil) + // + // c.Assert(rules, check.HasLen, 1) + // c.Assert(rules[0].DstPorts, check.HasLen, 1) + // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") + + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + assert.NoError(t, err) + + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}}, + }, + }, + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", diff) + } +} + +// tag on a host is owned by a tag owner, the tag is valid. +// an ACL rule is matching the tag to a user. It should not be valid since the +// host should be tied to the tag now. +func TestValidTagInvalidUser(t *testing.T) { + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "webserver", + RequestTags: []string{"tag:webapp"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "webserver", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: 1, + User: types.User{ + Name: "user1", + }, + RegisterMethod: util.RegisterMethodAuthKey, + HostInfo: types.HostInfo(hostInfo), + } + + hostInfo2 := tailcfg.Hostinfo{ + OS: "debian", + Hostname: "Hostname", + } + + machine2 := types.Machine{ + ID: 2, + MachineKey: "56789", + NodeKey: "bar2", + DiscoKey: "faab", + Hostname: "user", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + UserID: 1, + User: types.User{ + Name: "user1", + }, + RegisterMethod: util.RegisterMethodAuthKey, + HostInfo: types.HostInfo(hostInfo2), + } + + pol := &policy.ACLPolicy{ + TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"tag:webapp:80,443"}, + }, + }, + } + + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false) + assert.NoError(t, err) + + want := []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, + {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, + }, + }, + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff) + } } diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index deab3967..92021e7d 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -287,14 +287,20 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) c.Assert(err, check.IsNil) - machines, err := db.ListMachines() + adminPeers, err := db.ListPeers(adminMachine) c.Assert(err, check.IsNil) - aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) + testPeers, err := db.ListPeers(testMachine) c.Assert(err, check.IsNil) - peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules) - peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules) + adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false) + c.Assert(err, check.IsNil) + + testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false) + c.Assert(err, check.IsNil) + + peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules) + peersOfTestMachine := policy.FilterMachinesByACL(testMachine, testPeers, testRules) c.Log(peersOfTestMachine) c.Assert(len(peersOfTestMachine), check.Equals, 9) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 88e3ada0..6f9498ea 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -101,8 +101,8 @@ func fullMapResponse( rules, sshPolicy, err := policy.GenerateFilterRules( pol, - // The policy is currently calculated for the entire Headscale network - append(peers, *machine), + machine, + peers, stripEmailDomain, ) if err != nil { diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 851c8df4..919a22b2 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -360,7 +360,7 @@ func Test_fullMapResponse(t *testing.T) { CollectServices: "false", PacketFilter: []tailcfg.FilterRule{}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: nil, + SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, ControlTime: &time.Time{}, Debug: &tailcfg.Debug{ DisableLogTail: true, @@ -393,7 +393,7 @@ func Test_fullMapResponse(t *testing.T) { CollectServices: "false", PacketFilter: []tailcfg.FilterRule{}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: nil, + SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, ControlTime: &time.Time{}, Debug: &tailcfg.Debug{ DisableLogTail: true, @@ -442,7 +442,7 @@ func Test_fullMapResponse(t *testing.T) { }, }, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: nil, + SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, ControlTime: &time.Time{}, Debug: &tailcfg.Debug{ DisableLogTail: true, diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 675dd46e..deaf8a62 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -18,7 +18,6 @@ import ( "github.com/tailscale/hujson" "go4.org/netipx" "gopkg.in/yaml.v3" - "tailscale.com/envknob" "tailscale.com/tailcfg" ) @@ -54,8 +53,6 @@ const ( ProtocolFC = 133 // Fibre Channel ) -var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") - // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { log.Debug(). @@ -122,7 +119,8 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { // per node and that should be taken into account. func GenerateFilterRules( policy *ACLPolicy, - machines types.Machines, + machine *types.Machine, + peers types.Machines, stripEmailDomain bool, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all @@ -130,7 +128,7 @@ func GenerateFilterRules( return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.generateFilterRules(machines, stripEmailDomain) + rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -138,19 +136,15 @@ func GenerateFilterRules( log.Trace().Interface("ACL", rules).Msg("ACL rules generated") var sshPolicy *tailcfg.SSHPolicy - if featureEnableSSH() { - sshRules, err := generateSSHRules(policy, machines, stripEmailDomain) - if err != nil { - return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err - } - log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") - if sshPolicy == nil { - sshPolicy = &tailcfg.SSHPolicy{} - } - sshPolicy.Rules = sshRules - } else if policy != nil && len(policy.SSHs) > 0 { - log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") + sshRules, err := generateSSHRules(policy, append(peers, *machine), stripEmailDomain) + if err != nil { + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } + log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") + if sshPolicy == nil { + sshPolicy = &tailcfg.SSHPolicy{} + } + sshPolicy.Rules = sshRules return rules, sshPolicy, nil } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 2093d939..e9352d2a 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -245,7 +245,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -264,7 +264,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -280,7 +280,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) }