mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-20 19:09:07 +01:00
make GenerateFilterRules take machine and peers
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
35770278f7
commit
db6cf4ac0a
@ -2,10 +2,13 @@ package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -77,7 +80,7 @@ func (s *Suite) TestSshRules(c *check.C) {
|
||||
},
|
||||
}
|
||||
|
||||
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
|
||||
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, &machine, types.Machines{}, false)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sshPolicy, check.NotNil)
|
||||
@ -94,15 +97,7 @@ func (s *Suite) TestSshRules(c *check.C) {
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Sources section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
func TestValidExpandTagOwnersInSources(t *testing.T) {
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
@ -110,19 +105,19 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 0,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
@ -136,85 +131,28 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Destinations section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"tag:test:*"},
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// need a test with:
|
||||
// tag on a host that isn't owned by a tag owners. So the user
|
||||
// of the host should be valid.
|
||||
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
func TestInvalidTagValidUser(t *testing.T) {
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
@ -222,19 +160,19 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
|
||||
@ -247,190 +185,38 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// tag on a host is owned by a tag owner, the tag is valid.
|
||||
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||
// host should be tied to the tag now.
|
||||
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "webserver")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "webserver",
|
||||
RequestTags: []string{"tag:webapp"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "webserver",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "user")
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
OS: "debian",
|
||||
Hostname: "Hostname",
|
||||
}
|
||||
c.Assert(err, check.NotNil)
|
||||
machine = types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "56789",
|
||||
NodeKey: "bar2",
|
||||
DiscoKey: "faab",
|
||||
Hostname: "user",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"tag:webapp:80,443"},
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 2)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortUser(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
func TestPortGroup(t *testing.T) {
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: 0,
|
||||
User: types.User{
|
||||
Name: "testuser",
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pol, check.NotNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortGroup(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.5")},
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
@ -459,22 +245,211 @@ func (s *Suite) TestPortGroup(c *check.C) {
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.5/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestPortGroup() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortUser(t *testing.T) {
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: 0,
|
||||
User: types.User{
|
||||
Name: "testuser",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.9")},
|
||||
}
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
assert.NoError(t, err)
|
||||
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.9/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestPortUser() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Destinations section.
|
||||
func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"tag:test:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false)
|
||||
// c.Assert(err, check.IsNil)
|
||||
//
|
||||
// c.Assert(rules, check.HasLen, 1)
|
||||
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// tag on a host is owned by a tag owner, the tag is valid.
|
||||
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||
// host should be tied to the tag now.
|
||||
func TestValidTagInvalidUser(t *testing.T) {
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "webserver",
|
||||
RequestTags: []string{"tag:webapp"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "webserver",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
OS: "debian",
|
||||
Hostname: "Hostname",
|
||||
}
|
||||
|
||||
machine2 := types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "56789",
|
||||
NodeKey: "bar2",
|
||||
DiscoKey: "faab",
|
||||
Hostname: "user",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"tag:webapp:80,443"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
@ -287,14 +287,20 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
adminPeers, err := db.ListPeers(adminMachine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||
testPeers, err := db.ListPeers(testMachine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules)
|
||||
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules)
|
||||
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
|
||||
peersOfTestMachine := policy.FilterMachinesByACL(testMachine, testPeers, testRules)
|
||||
|
||||
c.Log(peersOfTestMachine)
|
||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||
|
@ -101,8 +101,8 @@ func fullMapResponse(
|
||||
|
||||
rules, sshPolicy, err := policy.GenerateFilterRules(
|
||||
pol,
|
||||
// The policy is currently calculated for the entire Headscale network
|
||||
append(peers, *machine),
|
||||
machine,
|
||||
peers,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -360,7 +360,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: nil,
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
@ -393,7 +393,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: nil,
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
@ -442,7 +442,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
},
|
||||
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
|
||||
SSHPolicy: nil,
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/yaml.v3"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@ -54,8 +53,6 @@ const (
|
||||
ProtocolFC = 133 // Fibre Channel
|
||||
)
|
||||
|
||||
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
|
||||
|
||||
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
|
||||
func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
|
||||
log.Debug().
|
||||
@ -122,7 +119,8 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
|
||||
// per node and that should be taken into account.
|
||||
func GenerateFilterRules(
|
||||
policy *ACLPolicy,
|
||||
machines types.Machines,
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||
// If there is no policy defined, we default to allow all
|
||||
@ -130,7 +128,7 @@ func GenerateFilterRules(
|
||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||
}
|
||||
|
||||
rules, err := policy.generateFilterRules(machines, stripEmailDomain)
|
||||
rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
@ -138,19 +136,15 @@ func GenerateFilterRules(
|
||||
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
||||
|
||||
var sshPolicy *tailcfg.SSHPolicy
|
||||
if featureEnableSSH() {
|
||||
sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
|
||||
if sshPolicy == nil {
|
||||
sshPolicy = &tailcfg.SSHPolicy{}
|
||||
}
|
||||
sshPolicy.Rules = sshRules
|
||||
} else if policy != nil && len(policy.SSHs) > 0 {
|
||||
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
|
||||
sshRules, err := generateSSHRules(policy, append(peers, *machine), stripEmailDomain)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
|
||||
if sshPolicy == nil {
|
||||
sshPolicy = &tailcfg.SSHPolicy{}
|
||||
}
|
||||
sshPolicy.Rules = sshRules
|
||||
|
||||
return rules, sshPolicy, nil
|
||||
}
|
||||
|
@ -245,7 +245,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
|
||||
},
|
||||
},
|
||||
}
|
||||
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
|
||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
||||
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
||||
}
|
||||
|
||||
@ -264,7 +264,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
|
||||
},
|
||||
},
|
||||
}
|
||||
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
|
||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
||||
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
||||
}
|
||||
|
||||
@ -280,7 +280,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := GenerateFilterRules(pol, types.Machines{}, false)
|
||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
||||
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user