1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-22 00:11:47 +01:00

Expand the signature of policy.ExpandAlias() to support the implementation of autogroups

This commit is contained in:
Juan Font 2023-08-12 11:47:23 +00:00
parent 043be13e6d
commit 28354cc651
3 changed files with 48 additions and 37 deletions

View File

@ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} else { } else {
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias) approvedIps, err := aclPolicy.ExpandAlias(*machine, types.Machines{}, approvedAlias)
if err != nil { if err != nil {
log.Err(err). log.Err(err).
Str("alias", approvedAlias). Str("alias", approvedAlias).

View File

@ -157,7 +157,6 @@ func (pol *ACLPolicy) generateFilterRules(
peers types.Machines, peers types.Machines,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
machines := append(peers, *machine)
for index, acl := range pol.ACLs { for index, acl := range pol.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
@ -166,7 +165,7 @@ func (pol *ACLPolicy) generateFilterRules(
srcIPs := []string{} srcIPs := []string{}
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, machines) srcs, err := pol.expandSource(src, *machine, peers)
if err != nil { if err != nil {
log.Error(). log.Error().
Interface("src", src). Interface("src", src).
@ -195,7 +194,8 @@ func (pol *ACLPolicy) generateFilterRules(
} }
expanded, err := pol.ExpandAlias( expanded, err := pol.ExpandAlias(
machines, *machine,
peers,
alias, alias,
) )
if err != nil { if err != nil {
@ -293,7 +293,7 @@ func (pol *ACLPolicy) generateSSHRules(
for index, sshACL := range pol.SSHs { for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations { for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, *machine), src) expanded, err := pol.ExpandAlias(*machine, peers, src)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -350,6 +350,7 @@ func (pol *ACLPolicy) generateSSHRules(
} }
} else { } else {
expandedSrcs, err := pol.ExpandAlias( expandedSrcs, err := pol.ExpandAlias(
*machine,
peers, peers,
rawSrc, rawSrc,
) )
@ -501,9 +502,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias. // with the given src alias.
func (pol *ACLPolicy) expandSource( func (pol *ACLPolicy) expandSource(
src string, src string,
machines types.Machines, machine types.Machine,
peers types.Machines,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src) ipSet, err := pol.ExpandAlias(machine, peers, src)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -526,7 +528,8 @@ func (pol *ACLPolicy) expandSource(
// - a cidr // - a cidr
// and transform these in IPAddresses. // and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias( func (pol *ACLPolicy) ExpandAlias(
machines types.Machines, machine types.Machine,
peers types.Machines,
alias string, alias string,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
@ -535,22 +538,24 @@ func (pol *ACLPolicy) ExpandAlias(
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
allMachines := append(peers, machine)
log.Debug(). log.Debug().
Str("alias", alias). Str("alias", alias).
Msg("Expanding") Msg("Expanding")
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.expandIPsFromGroup(alias, machines) return pol.expandIPsFromGroup(alias, allMachines)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.expandIPsFromTag(alias, machines) return pol.expandIPsFromTag(alias, allMachines)
} }
// if alias is a user // if alias is a user
if ips, err := pol.expandIPsFromUser(alias, machines); ips != nil { if ips, err := pol.expandIPsFromUser(alias, allMachines); ips != nil {
return ips, err return ips, err
} }
@ -559,17 +564,17 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(machines, h.String()) return pol.ExpandAlias(machine, peers, h.String())
} }
// if alias is an IP // if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil { if ip, err := netip.ParseAddr(alias); err == nil {
return pol.expandIPsFromSingleIP(ip, machines) return pol.expandIPsFromSingleIP(ip, allMachines)
} }
// if alias is an IP Prefix (CIDR) // if alias is an IP Prefix (CIDR)
if prefix, err := netip.ParsePrefix(alias); err == nil { if prefix, err := netip.ParsePrefix(alias); err == nil {
return pol.expandIPsFromIPPrefix(prefix, machines) return pol.expandIPsFromIPPrefix(prefix, allMachines)
} }
log.Warn().Msgf("No IPs found with the alias %v", alias) log.Warn().Msgf("No IPs found with the alias %v", alias)
@ -871,6 +876,10 @@ func isTag(str string) bool {
return strings.HasPrefix(str, "tag:") return strings.HasPrefix(str, "tag:")
} }
func isAutogroup(str string) bool {
return strings.HasPrefix(str, "autogroup:")
}
// TagsOfMachine will return the tags of the current machine. // TagsOfMachine will return the tags of the current machine.
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. // Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.

View File

@ -979,7 +979,8 @@ func Test_expandAlias(t *testing.T) {
pol ACLPolicy pol ACLPolicy
} }
type args struct { type args struct {
machines types.Machines machine types.Machine
peers types.Machines
aclPolicy ACLPolicy aclPolicy ACLPolicy
alias string alias string
} }
@ -997,7 +998,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "*", alias: "*",
machines: types.Machines{ peers: types.Machines{
{IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, {IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}},
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
@ -1021,7 +1022,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "group:accountant", alias: "group:accountant",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1062,7 +1063,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "group:hr", alias: "group:hr",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1098,8 +1099,8 @@ func Test_expandAlias(t *testing.T) {
pol: ACLPolicy{}, pol: ACLPolicy{},
}, },
args: args{ args: args{
alias: "10.0.0.3", alias: "10.0.0.3",
machines: types.Machines{}, peers: types.Machines{},
}, },
want: set([]string{ want: set([]string{
"10.0.0.3", "10.0.0.3",
@ -1112,8 +1113,8 @@ func Test_expandAlias(t *testing.T) {
pol: ACLPolicy{}, pol: ACLPolicy{},
}, },
args: args{ args: args{
alias: "10.0.0.1", alias: "10.0.0.1",
machines: types.Machines{}, peers: types.Machines{},
}, },
want: set([]string{ want: set([]string{
"10.0.0.1", "10.0.0.1",
@ -1127,7 +1128,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "10.0.0.1", alias: "10.0.0.1",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.1"),
@ -1148,7 +1149,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "10.0.0.1", alias: "10.0.0.1",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.1"),
@ -1170,7 +1171,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.1"),
@ -1195,8 +1196,8 @@ func Test_expandAlias(t *testing.T) {
}, },
}, },
args: args{ args: args{
alias: "testy", alias: "testy",
machines: types.Machines{}, peers: types.Machines{},
}, },
want: set([]string{}, []string{"10.0.0.132/32"}), want: set([]string{}, []string{"10.0.0.132/32"}),
wantErr: false, wantErr: false,
@ -1211,8 +1212,8 @@ func Test_expandAlias(t *testing.T) {
}, },
}, },
args: args{ args: args{
alias: "homeNetwork", alias: "homeNetwork",
machines: types.Machines{}, peers: types.Machines{},
}, },
want: set([]string{}, []string{"192.168.1.0/24"}), want: set([]string{}, []string{"192.168.1.0/24"}),
wantErr: false, wantErr: false,
@ -1224,7 +1225,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "10.0.0.0/16", alias: "10.0.0.0/16",
machines: types.Machines{}, peers: types.Machines{},
aclPolicy: ACLPolicy{}, aclPolicy: ACLPolicy{},
}, },
want: set([]string{}, []string{"10.0.0.0/16"}), want: set([]string{}, []string{"10.0.0.0/16"}),
@ -1239,7 +1240,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "tag:hr-webserver", alias: "tag:hr-webserver",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1293,7 +1294,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "tag:hr-webserver", alias: "tag:hr-webserver",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1330,7 +1331,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "tag:hr-webserver", alias: "tag:hr-webserver",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1373,7 +1374,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "tag:hr-webserver", alias: "tag:hr-webserver",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1418,7 +1419,7 @@ func Test_expandAlias(t *testing.T) {
}, },
args: args{ args: args{
alias: "joe", alias: "joe",
machines: types.Machines{ peers: types.Machines{
{ {
IPAddresses: types.MachineAddresses{ IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("100.64.0.1"),
@ -1462,7 +1463,8 @@ func Test_expandAlias(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias( got, err := test.field.pol.ExpandAlias(
test.args.machines, test.args.machine,
test.args.peers,
test.args.alias, test.args.alias,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
@ -2421,7 +2423,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
}, },
{ {
// Investigating 699 // Investigating 699
// Found some machines: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa // Found some peers: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa
// ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}]
// ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}}
name: "issue-699-broken-star", name: "issue-699-broken-star",