1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

make GenerateFilterRules take machine and peers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-08 19:10:09 +02:00 committed by Kristoffer Dalby
parent 35770278f7
commit db6cf4ac0a
6 changed files with 291 additions and 316 deletions

View File

@ -2,10 +2,13 @@ package db
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
@ -77,7 +80,7 @@ func (s *Suite) TestSshRules(c *check.C) {
},
}
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, &machine, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(sshPolicy, check.NotNil)
@ -94,15 +97,7 @@ func (s *Suite) TestSshRules(c *check.C) {
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Sources section.
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
func TestValidExpandTagOwnersInSources(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
@ -116,13 +111,13 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
UserID: 0,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
@ -136,85 +131,28 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
want := []tailcfg.FilterRule{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
SrcIPs: []string{"100.64.0.1/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
}
}
// need a test with:
// tag on a host that isn't owned by a tag owners. So the user
// of the host should be valid.
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "testmachine")
c.Assert(err, check.NotNil)
func TestInvalidTagValidUser(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
@ -228,13 +166,13 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
@ -247,190 +185,38 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
}
// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
user, err := db.CreateUser("user1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "webserver")
c.Assert(err, check.NotNil)
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("user1", "user")
hostInfo2 := tailcfg.Hostinfo{
OS: "debian",
Hostname: "Hostname",
}
c.Assert(err, check.NotNil)
machine = types.Machine{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
ACLs: []policy.ACL{
want := []tailcfg.FilterRule{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"tag:webapp:80,443"},
SrcIPs: []string{"100.64.0.1/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
c.Assert(rules[0].DstPorts, check.HasLen, 2)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
}
func (s *Suite) TestPortUser(c *check.C) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
machine := types.Machine{
ID: 0,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips,
AuthKeyID: uint(pak.ID),
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(`
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
}
func (s *Suite) TestPortGroup(c *check.C) {
user, err := db.CreateUser("testuser")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetMachine("testuser", "testmachine")
c.Assert(err, check.NotNil)
ips, _ := db.getAvailableIPs()
func TestPortGroup(t *testing.T) {
machine := types.Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
UserID: 0,
User: types.User{
Name: "testuser",
},
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: ips,
AuthKeyID: uint(pak.ID),
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.5")},
}
err = db.MachineSave(&machine)
c.Assert(err, check.IsNil)
acl := []byte(`
{
@ -459,22 +245,211 @@ func (s *Suite) TestPortGroup(c *check.C) {
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
assert.NoError(t, err)
machines, err := db.ListMachines()
c.Assert(err, check.IsNil)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
c.Assert(err, check.IsNil)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.5/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
c.Assert(rules, check.NotNil)
c.Assert(rules, check.HasLen, 1)
c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestPortGroup() unexpected result (-want +got):\n%s", diff)
}
}
func TestPortUser(t *testing.T) {
machine := types.Machine{
ID: 0,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: 0,
User: types.User{
Name: "testuser",
},
RegisterMethod: util.RegisterMethodAuthKey,
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.9")},
}
acl := []byte(`
{
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"testuser",
],
"dst": [
"host-1:*",
],
},
],
}
`)
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.9/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestPortUser() unexpected result (-want +got):\n%s", diff)
}
}
// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func TestValidExpandTagOwnersInDestinations(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "testmachine",
RequestTags: []string{"tag:test"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
}
pol := &policy.ACLPolicy{
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
},
},
}
// rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false)
// c.Assert(err, check.IsNil)
//
// c.Assert(rules, check.HasLen, 1)
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", diff)
}
}
// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func TestValidTagInvalidUser(t *testing.T) {
hostInfo := tailcfg.Hostinfo{
OS: "centos",
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
machine := types.Machine{
ID: 1,
MachineKey: "12345",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
}
hostInfo2 := tailcfg.Hostinfo{
OS: "debian",
Hostname: "Hostname",
}
machine2 := types.Machine{
ID: 2,
MachineKey: "56789",
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
UserID: 1,
User: types.User{
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo2),
}
pol := &policy.ACLPolicy{
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"tag:webapp:80,443"},
},
},
}
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
assert.NoError(t, err)
want := []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
}
}

View File

@ -287,14 +287,20 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
c.Assert(err, check.IsNil)
machines, err := db.ListMachines()
adminPeers, err := db.ListPeers(adminMachine)
c.Assert(err, check.IsNil)
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
testPeers, err := db.ListPeers(testMachine)
c.Assert(err, check.IsNil)
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules)
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules)
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false)
c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false)
c.Assert(err, check.IsNil)
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, testPeers, testRules)
c.Log(peersOfTestMachine)
c.Assert(len(peersOfTestMachine), check.Equals, 9)

View File

@ -101,8 +101,8 @@ func fullMapResponse(
rules, sshPolicy, err := policy.GenerateFilterRules(
pol,
// The policy is currently calculated for the entire Headscale network
append(peers, *machine),
machine,
peers,
stripEmailDomain,
)
if err != nil {

View File

@ -360,7 +360,7 @@ func Test_fullMapResponse(t *testing.T) {
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil,
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
@ -393,7 +393,7 @@ func Test_fullMapResponse(t *testing.T) {
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil,
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
@ -442,7 +442,7 @@ func Test_fullMapResponse(t *testing.T) {
},
},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: nil,
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,

View File

@ -18,7 +18,6 @@ import (
"github.com/tailscale/hujson"
"go4.org/netipx"
"gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
)
@ -54,8 +53,6 @@ const (
ProtocolFC = 133 // Fibre Channel
)
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
log.Debug().
@ -122,7 +119,8 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
// per node and that should be taken into account.
func GenerateFilterRules(
policy *ACLPolicy,
machines types.Machines,
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all
@ -130,7 +128,7 @@ func GenerateFilterRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}
rules, err := policy.generateFilterRules(machines, stripEmailDomain)
rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@ -138,8 +136,7 @@ func GenerateFilterRules(
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
var sshPolicy *tailcfg.SSHPolicy
if featureEnableSSH() {
sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
sshRules, err := generateSSHRules(policy, append(peers, *machine), stripEmailDomain)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@ -148,9 +145,6 @@ func GenerateFilterRules(
sshPolicy = &tailcfg.SSHPolicy{}
}
sshPolicy.Rules = sshRules
} else if policy != nil && len(policy.SSHs) > 0 {
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
}
return rules, sshPolicy, nil
}

View File

@ -245,7 +245,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
},
},
}
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
}
@ -264,7 +264,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
},
},
}
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
}
@ -280,7 +280,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
},
}
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
}