1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

make parse destination string into a func

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-12 15:59:05 +02:00 committed by Kristoffer Dalby
parent 717abe89c1
commit 2675ff4b94
2 changed files with 97 additions and 29 deletions

View File

@ -375,9 +375,39 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
machines types.Machines, machines types.Machines,
needsWildcard bool, needsWildcard bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
var tokens []string alias, port, err := parseDestination(dest)
if err != nil {
return nil, err
}
log.Trace().Str("destination", dest).Msg("generating policy destination") expanded, err := pol.ExpandAlias(
machines,
alias,
)
if err != nil {
return nil, err
}
ports, err := expandPorts(port, needsWildcard)
if err != nil {
return nil, err
}
dests := []tailcfg.NetPortRange{}
for _, dest := range expanded.Prefixes() {
for _, port := range *ports {
pr := tailcfg.NetPortRange{
IP: dest.String(),
Ports: port,
}
dests = append(dests, pr)
}
}
return dests, nil
}
func parseDestination(dest string) (string, string, error) {
var tokens []string
// Check if there is a IPv4/6:Port combination, IPv6 has more than // Check if there is a IPv4/6:Port combination, IPv6 has more than
// three ":". // three ":".
@ -397,7 +427,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() { if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
log.Trace().Err(err).Msg("trying to parse as IPv6") log.Trace().Err(err).Msg("trying to parse as IPv6")
return nil, fmt.Errorf( return "", "", fmt.Errorf(
"failed to parse destination, tokens %v: %w", "failed to parse destination, tokens %v: %w",
tokens, tokens,
ErrInvalidPortFormat, ErrInvalidPortFormat,
@ -407,8 +437,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
} }
} }
log.Trace().Strs("tokens", tokens).Msg("generating policy destination")
var alias string var alias string
// We can have here stuff like: // We can have here stuff like:
// git-server:* // git-server:*
@ -424,30 +452,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
} }
expanded, err := pol.ExpandAlias( return alias, tokens[len(tokens)-1], nil
machines,
alias,
)
if err != nil {
return nil, err
}
ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
if err != nil {
return nil, err
}
dests := []tailcfg.NetPortRange{}
for _, dest := range expanded.Prefixes() {
for _, port := range *ports {
pr := tailcfg.NetPortRange{
IP: dest.String(),
Ports: port,
}
dests = append(dests, pr)
}
}
return dests, nil
} }
// parseProtocol reads the proto field of the ACL and generates a list of // parseProtocol reads the proto field of the ACL and generates a list of

View File

@ -2557,3 +2557,66 @@ func TestSSHRules(t *testing.T) {
}) })
} }
} }
func TestParseDestination(t *testing.T) {
tests := []struct {
dest string
wantAlias string
wantPort string
}{
{
dest: "git-server:*",
wantAlias: "git-server",
wantPort: "*",
},
{
dest: "192.168.1.0/24:22",
wantAlias: "192.168.1.0/24",
wantPort: "22",
},
{
dest: "192.168.1.1:22",
wantAlias: "192.168.1.1",
wantPort: "22",
},
{
dest: "fd7a:115c:a1e0::2:22",
wantAlias: "fd7a:115c:a1e0::2",
wantPort: "22",
},
{
dest: "fd7a:115c:a1e0::2/128:22",
wantAlias: "fd7a:115c:a1e0::2/128",
wantPort: "22",
},
{
dest: "tag:montreal-webserver:80,443",
wantAlias: "tag:montreal-webserver",
wantPort: "80,443",
},
{
dest: "tag:api-server:443",
wantAlias: "tag:api-server",
wantPort: "443",
},
{
dest: "example-host-1:*",
wantAlias: "example-host-1",
wantPort: "*",
},
}
for _, tt := range tests {
t.Run(tt.dest, func(t *testing.T) {
alias, port, _ := parseDestination(tt.dest)
if alias != tt.wantAlias {
t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias)
}
if port != tt.wantPort {
t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port)
}
})
}
}