1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-30 00:09:42 +01:00

Improve ACLs by adding protocol parsing support

This commit is contained in:
Juan Font Alonso 2022-06-08 17:43:59 +02:00
parent 3e353004b8
commit ab1aac9f3e
3 changed files with 93 additions and 14 deletions

70
acls.go
View File

@ -23,6 +23,7 @@ const (
errInvalidGroup = Error("invalid group") errInvalidGroup = Error("invalid group")
errInvalidTag = Error("invalid tag") errInvalidTag = Error("invalid tag")
errInvalidPortFormat = Error("invalid port format") errInvalidPortFormat = Error("invalid port format")
errWildcardIsNeeded = Error("wildcard as port is required for the procotol")
) )
const ( const (
@ -134,9 +135,17 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
protocols, needsWildcard, err := parseProtocol(acl.Protocol)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d. protocol unknown %s", index, acl.Protocol)
return nil, err
}
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations { for innerIndex, dest := range acl.Destinations {
dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest) dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest, needsWildcard)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Destination %d", index, innerIndex) Msgf("Error parsing ACL %d, Destination %d", index, innerIndex)
@ -149,6 +158,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules = append(rules, tailcfg.FilterRule{ rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs, SrcIPs: srcIPs,
DstPorts: destPorts, DstPorts: destPorts,
IPProto: protocols,
}) })
} }
@ -167,6 +177,7 @@ func (h *Headscale) generateACLPolicyDest(
machines []Machine, machines []Machine,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
dest string, dest string,
needsWildcard bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(dest, ":") tokens := strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 { if len(tokens) < expectedTokenItems || len(tokens) > 3 {
@ -195,7 +206,7 @@ func (h *Headscale) generateACLPolicyDest(
if err != nil { if err != nil {
return nil, err return nil, err
} }
ports, err := expandPorts(tokens[len(tokens)-1]) ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,6 +225,54 @@ func (h *Headscale) generateACLPolicyDest(
return dests, nil return dests, nil
} }
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return []int{1, 58, 6, 17}, false, nil
case "igmp":
return []int{2}, true, nil
case "ipv4", "ip-in-ip":
return []int{4}, true, nil
case "tcp":
return []int{6}, false, nil
case "egp":
return []int{8}, true, nil
case "igp":
return []int{9}, true, nil
case "udp":
return []int{17}, false, nil
case "gre":
return []int{47}, true, nil
case "esp":
return []int{50}, true, nil
case "ah":
return []int{51}, true, nil
case "sctp":
return []int{132}, false, nil
case "icmp":
return []int{1, 58}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, err
}
needsWildcard := protocolNumber != 6 && protocolNumber != 17 && protocolNumber != 132
return []int{protocolNumber}, needsWildcard, nil
}
}
// expandalias has an input of either // expandalias has an input of either
// - a namespace // - a namespace
// - a group // - a group
@ -268,6 +327,7 @@ func expandAlias(
alias, alias,
) )
} }
return ips, nil return ips, nil
} else { } else {
return ips, err return ips, err
@ -359,13 +419,17 @@ func excludeCorrectlyTaggedNodes(
return out return out
} }
func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if portsStr == "*" { if portsStr == "*" {
return &[]tailcfg.PortRange{ return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd}, {First: portRangeBegin, Last: portRangeEnd},
}, nil }, nil
} }
if needsWildcard {
return nil, errWildcardIsNeeded
}
ports := []tailcfg.PortRange{} ports := []tailcfg.PortRange{}
for _, portStr := range strings.Split(portsStr, ",") { for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-") rang := strings.Split(portStr, "-")

View File

@ -629,6 +629,7 @@ func Test_expandTagOwners(t *testing.T) {
func Test_expandPorts(t *testing.T) { func Test_expandPorts(t *testing.T) {
type args struct { type args struct {
portsStr string portsStr string
needsWildcard bool
} }
tests := []struct { tests := []struct {
name string name string
@ -638,15 +639,29 @@ func Test_expandPorts(t *testing.T) {
}{ }{
{ {
name: "wildcard", name: "wildcard",
args: args{portsStr: "*"}, args: args{portsStr: "*", needsWildcard: true},
want: &[]tailcfg.PortRange{ want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd}, {First: portRangeBegin, Last: portRangeEnd},
}, },
wantErr: false, wantErr: false,
}, },
{
name: "needs wildcard but does not require it",
args: args{portsStr: "*", needsWildcard: false},
want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
},
wantErr: false,
},
{
name: "needs wildcard but gets port",
args: args{portsStr: "80,443", needsWildcard: true},
want: nil,
wantErr: true,
},
{ {
name: "two Destinations", name: "two Destinations",
args: args{portsStr: "80,443"}, args: args{portsStr: "80,443", needsWildcard: false},
want: &[]tailcfg.PortRange{ want: &[]tailcfg.PortRange{
{First: 80, Last: 80}, {First: 80, Last: 80},
{First: 443, Last: 443}, {First: 443, Last: 443},
@ -655,7 +670,7 @@ func Test_expandPorts(t *testing.T) {
}, },
{ {
name: "a range and a port", name: "a range and a port",
args: args{portsStr: "80-1024,443"}, args: args{portsStr: "80-1024,443", needsWildcard: false},
want: &[]tailcfg.PortRange{ want: &[]tailcfg.PortRange{
{First: 80, Last: 1024}, {First: 80, Last: 1024},
{First: 443, Last: 443}, {First: 443, Last: 443},
@ -664,38 +679,38 @@ func Test_expandPorts(t *testing.T) {
}, },
{ {
name: "out of bounds", name: "out of bounds",
args: args{portsStr: "854038"}, args: args{portsStr: "854038", needsWildcard: false},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong port", name: "wrong port",
args: args{portsStr: "85a38"}, args: args{portsStr: "85a38", needsWildcard: false},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong port in first", name: "wrong port in first",
args: args{portsStr: "a-80"}, args: args{portsStr: "a-80", needsWildcard: false},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong port in last", name: "wrong port in last",
args: args{portsStr: "80-85a38"}, args: args{portsStr: "80-85a38", needsWildcard: false},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong port format", name: "wrong port format",
args: args{portsStr: "80-85a38-3"}, args: args{portsStr: "80-85a38-3", needsWildcard: false},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := expandPorts(test.args.portsStr) got, err := expandPorts(test.args.portsStr, test.args.needsWildcard)
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr) t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr)

View File

@ -21,7 +21,7 @@ type ACLPolicy struct {
// ACL is a basic rule for the ACL Policy. // ACL is a basic rule for the ACL Policy.
type ACL struct { type ACL struct {
Action string `json:"action" yaml:"action"` Action string `json:"action" yaml:"action"`
Protocol string `json:"protocol" yaml:"protocol"` Protocol string `json:"proto" yaml:"proto"`
Sources []string `json:"src" yaml:"src"` Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"` Destinations []string `json:"dst" yaml:"dst"`
} }