mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-28 13:49:04 +02:00
introduce policy v2 package
policy v2 is built from the ground up to be stricter and follow the same pattern for all types of resolvers. TODO introduce aliass resolver Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
c0f856256a
commit
72049e94b1
169
hscontrol/policy/v2/filter.go
Normal file
169
hscontrol/policy/v2/filter.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidAction = errors.New("invalid action")
|
||||||
|
)
|
||||||
|
|
||||||
|
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||||
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
|
func (pol *Policy) compileFilterRules(
|
||||||
|
users types.Users,
|
||||||
|
nodes types.Nodes,
|
||||||
|
) ([]tailcfg.FilterRule, error) {
|
||||||
|
if pol == nil {
|
||||||
|
return tailcfg.FilterAllowAll, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []tailcfg.FilterRule
|
||||||
|
|
||||||
|
for _, acl := range pol.ACLs {
|
||||||
|
if acl.Action != "accept" {
|
||||||
|
return nil, ErrInvalidAction
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
log.Trace().Err(err).Msgf("resolving source ips")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(srcIPs.Prefixes()) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): integrate type into schema
|
||||||
|
// TODO(kradalby): figure out the _ is wildcard stuff
|
||||||
|
protocols, _, err := parseProtocol(acl.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var destPorts []tailcfg.NetPortRange
|
||||||
|
for _, dest := range acl.Destinations {
|
||||||
|
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pref := range ips.Prefixes() {
|
||||||
|
for _, port := range dest.Ports {
|
||||||
|
pr := tailcfg.NetPortRange{
|
||||||
|
IP: pref.String(),
|
||||||
|
Ports: port,
|
||||||
|
}
|
||||||
|
destPorts = append(destPorts, pr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(destPorts) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, tailcfg.FilterRule{
|
||||||
|
SrcIPs: ipSetToPrefixStringList(srcIPs),
|
||||||
|
DstPorts: destPorts,
|
||||||
|
IPProto: protocols,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||||
|
return tailcfg.SSHAction{
|
||||||
|
Reject: !accept,
|
||||||
|
Accept: accept,
|
||||||
|
SessionDuration: duration,
|
||||||
|
AllowAgentForwarding: true,
|
||||||
|
AllowLocalPortForwarding: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pol *Policy) compileSSHPolicy(
|
||||||
|
users types.Users,
|
||||||
|
node *types.Node,
|
||||||
|
nodes types.Nodes,
|
||||||
|
) (*tailcfg.SSHPolicy, error) {
|
||||||
|
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []*tailcfg.SSHRule
|
||||||
|
|
||||||
|
for index, rule := range pol.SSHs {
|
||||||
|
var dest netipx.IPSetBuilder
|
||||||
|
for _, src := range rule.Destinations {
|
||||||
|
ips, err := src.Resolve(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||||
|
}
|
||||||
|
dest.AddSet(ips)
|
||||||
|
}
|
||||||
|
|
||||||
|
destSet, err := dest.IPSet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !node.InIPSet(destSet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var action tailcfg.SSHAction
|
||||||
|
switch rule.Action {
|
||||||
|
case "accept":
|
||||||
|
action = sshAction(true, 0)
|
||||||
|
case "check":
|
||||||
|
action = sshAction(true, rule.CheckPeriod)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var principals []*tailcfg.SSHPrincipal
|
||||||
|
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
log.Trace().Err(err).Msgf("resolving source ips")
|
||||||
|
}
|
||||||
|
|
||||||
|
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||||
|
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||||
|
NodeIP: addr.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
userMap := make(map[string]string, len(rule.Users))
|
||||||
|
for _, user := range rule.Users {
|
||||||
|
userMap[user.String()] = "="
|
||||||
|
}
|
||||||
|
rules = append(rules, &tailcfg.SSHRule{
|
||||||
|
Principals: principals,
|
||||||
|
SSHUsers: userMap,
|
||||||
|
Action: &action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tailcfg.SSHPolicy{
|
||||||
|
Rules: rules,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||||
|
var out []string
|
||||||
|
|
||||||
|
for _, pref := range ips.Prefixes() {
|
||||||
|
out = append(out, pref.String())
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
378
hscontrol/policy/v2/filter_test.go
Normal file
378
hscontrol/policy/v2/filter_test.go
Normal file
@ -0,0 +1,378 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParsing(t *testing.T) {
|
||||||
|
users := types.Users{
|
||||||
|
{Model: gorm.Model{ID: 1}, Name: "testuser"},
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
format string
|
||||||
|
acl string
|
||||||
|
want []tailcfg.FilterRule
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid-hujson",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
// The new parser will ignore all that is irrelevant
|
||||||
|
// {
|
||||||
|
// name: "valid-hujson-invalid-content",
|
||||||
|
// format: "hujson",
|
||||||
|
// acl: `
|
||||||
|
// {
|
||||||
|
// "valid_json": true,
|
||||||
|
// "but_a_policy_though": false
|
||||||
|
// }
|
||||||
|
// `,
|
||||||
|
// want: []tailcfg.FilterRule{},
|
||||||
|
// wantErr: true,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "invalid-cidr",
|
||||||
|
// format: "hujson",
|
||||||
|
// acl: `
|
||||||
|
// {"example-host-1": "100.100.100.100/42"}
|
||||||
|
// `,
|
||||||
|
// want: []tailcfg.FilterRule{},
|
||||||
|
// wantErr: true,
|
||||||
|
// },
|
||||||
|
{
|
||||||
|
name: "basic-rule",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"subnet-1",
|
||||||
|
"192.168.1.0/24"
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"*:22,3389",
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||||
|
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||||
|
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||||
|
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parse-protocol",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"Action": "accept",
|
||||||
|
"src": [
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
"proto": "tcp",
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Action": "accept",
|
||||||
|
"src": [
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
"proto": "udp",
|
||||||
|
"dst": [
|
||||||
|
"host-1:53",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Action": "accept",
|
||||||
|
"src": [
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
"proto": "icmp",
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
IPProto: []int{protocolTCP},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||||
|
},
|
||||||
|
IPProto: []int{protocolUDP},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
IPProto: []int{protocolICMP, protocolIPv6ICMP},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port-wildcard",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"Action": "accept",
|
||||||
|
"src": [
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port-range",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"subnet-1",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:5400-5500",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.100.101.0/24"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{
|
||||||
|
IP: "100.100.100.100/32",
|
||||||
|
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port-group",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:example": [
|
||||||
|
"testuser@",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"group:example",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"200.200.200.200/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port-user",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"testuser@",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"200.200.200.200/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
format: "hujson",
|
||||||
|
acl: `
|
||||||
|
{
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.100.100.100/32",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{
|
||||||
|
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pol, err := policyFromBytes([]byte(tt.acl))
|
||||||
|
if tt.wantErr && err == nil {
|
||||||
|
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
} else if !tt.wantErr && err != nil {
|
||||||
|
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := pol.compileFilterRules(
|
||||||
|
users,
|
||||||
|
types.Nodes{
|
||||||
|
&types.Node{
|
||||||
|
IPv4: ap("100.100.100.100"),
|
||||||
|
},
|
||||||
|
&types.Node{
|
||||||
|
IPv4: ap("200.200.200.200"),
|
||||||
|
User: users[0],
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, rules); diff != "" {
|
||||||
|
t.Errorf("parsing() unexpected result (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
283
hscontrol/policy/v2/policy.go
Normal file
283
hscontrol/policy/v2/policy.go
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"tailscale.com/net/tsaddr"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/util/deephash"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PolicyManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
pol *Policy
|
||||||
|
users []types.User
|
||||||
|
nodes types.Nodes
|
||||||
|
|
||||||
|
filterHash deephash.Sum
|
||||||
|
filter []tailcfg.FilterRule
|
||||||
|
|
||||||
|
tagOwnerMapHash deephash.Sum
|
||||||
|
tagOwnerMap map[Tag]*netipx.IPSet
|
||||||
|
|
||||||
|
autoApproveMapHash deephash.Sum
|
||||||
|
autoApproveMap map[netip.Prefix]*netipx.IPSet
|
||||||
|
|
||||||
|
// Lazy map of SSH policies
|
||||||
|
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||||
|
// It returns an error if the policy file is invalid.
|
||||||
|
// The policy manager will update the filter rules based on the users and nodes.
|
||||||
|
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||||
|
policy, err := policyFromBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := PolicyManager{
|
||||||
|
pol: policy,
|
||||||
|
users: users,
|
||||||
|
nodes: nodes,
|
||||||
|
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pm.updateLocked()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||||
|
// It must be called with the lock held.
|
||||||
|
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||||
|
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filterHash := deephash.Hash(&filter)
|
||||||
|
filterChanged := filterHash == pm.filterHash
|
||||||
|
pm.filter = filter
|
||||||
|
pm.filterHash = filterHash
|
||||||
|
|
||||||
|
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
||||||
|
// that the map for tag owners is resolved before resolving autoapprovers.
|
||||||
|
// TODO(kradalby): Order might not matter after #2417
|
||||||
|
tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("resolving tag owners map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||||
|
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||||
|
pm.tagOwnerMap = tagMap
|
||||||
|
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||||
|
|
||||||
|
autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("resolving auto approvers map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||||
|
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||||
|
pm.autoApproveMap = autoMap
|
||||||
|
pm.autoApproveMapHash = autoApproveMapHash
|
||||||
|
|
||||||
|
// If neither of the calculated values changed, no need to update nodes
|
||||||
|
if !filterChanged && !tagOwnerChanged && !autoApproveChanged {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the SSH policy map to ensure it's recalculated with the new policy.
|
||||||
|
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||||
|
// policies for nodes that have changed. Particularly if the only difference is
|
||||||
|
// that nodes has been added or removed.
|
||||||
|
clear(pm.sshPolicyMap)
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
|
||||||
|
return sshPol, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||||
|
}
|
||||||
|
pm.sshPolicyMap[node.ID] = sshPol
|
||||||
|
|
||||||
|
return sshPol, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||||
|
if len(polB) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pol, err := policyFromBytes(polB)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("parsing policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
pm.pol = pol
|
||||||
|
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter returns the current filter rules for the entire tailnet.
|
||||||
|
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
return pm.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
pm.users = users
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
pm.nodes = nodes
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||||
|
if pm == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
|
||||||
|
for _, nodeAddr := range node.IPs() {
|
||||||
|
if ips.Contains(nodeAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||||
|
if pm == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
// The fast path is that a node requests to approve a prefix
|
||||||
|
// where there is an exact entry, e.g. 10.0.0.0/8, then
|
||||||
|
// check and return quickly
|
||||||
|
if _, ok := pm.autoApproveMap[route]; ok {
|
||||||
|
for _, nodeAddr := range node.IPs() {
|
||||||
|
if pm.autoApproveMap[route].Contains(nodeAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The slow path is that the node tries to approve
|
||||||
|
// 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
|
||||||
|
// cannot just lookup in the prefix map and have to check
|
||||||
|
// if there is a "parent" prefix available.
|
||||||
|
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||||
|
// We do not want the exit node entry to approve all
|
||||||
|
// sorts of routes. The logic here is that it would be
|
||||||
|
// unexpected behaviour to have specific routes approved
|
||||||
|
// just because the node is allowed to designate itself as
|
||||||
|
// an exit.
|
||||||
|
if tsaddr.IsExitRoute(prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prefix is larger (so containing) and then overlaps
|
||||||
|
// the route to see if the node can approve a subset of an autoapprover
|
||||||
|
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
|
||||||
|
for _, nodeAddr := range node.IPs() {
|
||||||
|
if approveAddrs.Contains(nodeAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) Version() int {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManager) DebugString() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version())
|
||||||
|
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
|
||||||
|
if pm.pol != nil {
|
||||||
|
pol, err := json.MarshalIndent(pm.pol, "", " ")
|
||||||
|
if err == nil {
|
||||||
|
sb.WriteString("Policy:\n")
|
||||||
|
sb.Write(pol)
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||||
|
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||||
|
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||||
|
for _, iprange := range approveAddrs.Ranges() {
|
||||||
|
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
|
||||||
|
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||||
|
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||||
|
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||||
|
for _, iprange := range tagOwners.Ranges() {
|
||||||
|
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
if pm.filter != nil {
|
||||||
|
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||||
|
if err == nil {
|
||||||
|
sb.WriteString("Compiled filter:\n")
|
||||||
|
sb.Write(filter)
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
58
hscontrol/policy/v2/policy_test.go
Normal file
58
hscontrol/policy/v2/policy_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||||
|
return &types.Node{
|
||||||
|
ID: 0,
|
||||||
|
Hostname: name,
|
||||||
|
IPv4: ap(ipv4),
|
||||||
|
IPv6: ap(ipv6),
|
||||||
|
User: user,
|
||||||
|
UserID: user.ID,
|
||||||
|
Hostinfo: hostinfo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyManager(t *testing.T) {
|
||||||
|
users := types.Users{
|
||||||
|
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
|
||||||
|
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pol string
|
||||||
|
nodes types.Nodes
|
||||||
|
wantFilter []tailcfg.FilterRule
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty-policy",
|
||||||
|
pol: "{}",
|
||||||
|
nodes: types.Nodes{},
|
||||||
|
wantFilter: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
filter := pm.Filter()
|
||||||
|
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||||
|
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Test SSH Policy
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
1005
hscontrol/policy/v2/types.go
Normal file
1005
hscontrol/policy/v2/types.go
Normal file
File diff suppressed because it is too large
Load Diff
1162
hscontrol/policy/v2/types_test.go
Normal file
1162
hscontrol/policy/v2/types_test.go
Normal file
File diff suppressed because it is too large
Load Diff
164
hscontrol/policy/v2/utils.go
Normal file
164
hscontrol/policy/v2/utils.go
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||||
|
func splitDestinationAndPort(input string) (string, string, error) {
|
||||||
|
// Find the last occurrence of the colon character
|
||||||
|
lastColonIndex := strings.LastIndex(input, ":")
|
||||||
|
|
||||||
|
// Check if the colon character is present and not at the beginning or end of the string
|
||||||
|
if lastColonIndex == -1 {
|
||||||
|
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||||
|
}
|
||||||
|
if lastColonIndex == 0 {
|
||||||
|
return "", "", errors.New("input cannot start with a colon character")
|
||||||
|
}
|
||||||
|
if lastColonIndex == len(input)-1 {
|
||||||
|
return "", "", errors.New("input cannot end with a colon character")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split the string into destination and port based on the last colon
|
||||||
|
destination := input[:lastColonIndex]
|
||||||
|
port := input[lastColonIndex+1:]
|
||||||
|
|
||||||
|
return destination, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePortRange parses a port definition string and returns a slice of PortRange structs.
|
||||||
|
func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||||
|
if portDef == "*" {
|
||||||
|
return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var portRanges []tailcfg.PortRange
|
||||||
|
parts := strings.Split(portDef, ",")
|
||||||
|
|
||||||
|
for _, part := range parts {
|
||||||
|
if strings.Contains(part, "-") {
|
||||||
|
rangeParts := strings.Split(part, "-")
|
||||||
|
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||||
|
return e == ""
|
||||||
|
})
|
||||||
|
if len(rangeParts) != 2 {
|
||||||
|
return nil, errors.New("invalid port range format")
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := parsePort(rangeParts[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
last, err := parsePort(rangeParts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if first > last {
|
||||||
|
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||||
|
}
|
||||||
|
|
||||||
|
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||||
|
} else {
|
||||||
|
port, err := parsePort(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return portRanges, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePort parses a single port number from a string.
|
||||||
|
func parsePort(portStr string) (uint16, error) {
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.New("invalid port number")
|
||||||
|
}
|
||||||
|
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return 0, errors.New("port number out of range")
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||||
|
const (
|
||||||
|
protocolICMP = 1 // Internet Control Message
|
||||||
|
protocolIGMP = 2 // Internet Group Management
|
||||||
|
protocolIPv4 = 4 // IPv4 encapsulation
|
||||||
|
protocolTCP = 6 // Transmission Control
|
||||||
|
protocolEGP = 8 // Exterior Gateway Protocol
|
||||||
|
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||||
|
protocolUDP = 17 // User Datagram
|
||||||
|
protocolGRE = 47 // Generic Routing Encapsulation
|
||||||
|
protocolESP = 50 // Encap Security Payload
|
||||||
|
protocolAH = 51 // Authentication Header
|
||||||
|
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||||
|
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||||
|
ProtocolFC = 133 // Fibre Channel
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseProtocol reads the proto field of the ACL and generates a list of
|
||||||
|
// protocols that will be allowed, following the IANA IP protocol number
|
||||||
|
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||||
|
//
|
||||||
|
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
|
||||||
|
// as per Tailscale behaviour (see tailcfg.FilterRule).
|
||||||
|
//
|
||||||
|
// Also returns a boolean indicating if the protocol
|
||||||
|
// requires all the destinations to use wildcard as port number (only TCP,
|
||||||
|
// UDP and SCTP support specifying ports).
|
||||||
|
func parseProtocol(protocol string) ([]int, bool, error) {
|
||||||
|
switch protocol {
|
||||||
|
case "":
|
||||||
|
return nil, false, nil
|
||||||
|
case "igmp":
|
||||||
|
return []int{protocolIGMP}, true, nil
|
||||||
|
case "ipv4", "ip-in-ip":
|
||||||
|
return []int{protocolIPv4}, true, nil
|
||||||
|
case "tcp":
|
||||||
|
return []int{protocolTCP}, false, nil
|
||||||
|
case "egp":
|
||||||
|
return []int{protocolEGP}, true, nil
|
||||||
|
case "igp":
|
||||||
|
return []int{protocolIGP}, true, nil
|
||||||
|
case "udp":
|
||||||
|
return []int{protocolUDP}, false, nil
|
||||||
|
case "gre":
|
||||||
|
return []int{protocolGRE}, true, nil
|
||||||
|
case "esp":
|
||||||
|
return []int{protocolESP}, true, nil
|
||||||
|
case "ah":
|
||||||
|
return []int{protocolAH}, true, nil
|
||||||
|
case "sctp":
|
||||||
|
return []int{protocolSCTP}, false, nil
|
||||||
|
case "icmp":
|
||||||
|
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
protocolNumber, err := strconv.Atoi(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): What is this?
|
||||||
|
needsWildcard := protocolNumber != protocolTCP &&
|
||||||
|
protocolNumber != protocolUDP &&
|
||||||
|
protocolNumber != protocolSCTP
|
||||||
|
|
||||||
|
return []int{protocolNumber}, needsWildcard, nil
|
||||||
|
}
|
||||||
|
}
|
102
hscontrol/policy/v2/utils_test.go
Normal file
102
hscontrol/policy/v2/utils_test.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests.
|
||||||
|
func TestParseDestinationAndPort(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
input string
|
||||||
|
expectedDst string
|
||||||
|
expectedPort string
|
||||||
|
expectedErr error
|
||||||
|
}{
|
||||||
|
{"git-server:*", "git-server", "*", nil},
|
||||||
|
{"192.168.1.0/24:22", "192.168.1.0/24", "22", nil},
|
||||||
|
{"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil},
|
||||||
|
{"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil},
|
||||||
|
{"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil},
|
||||||
|
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||||
|
{"example-host-1:*", "example-host-1", "*", nil},
|
||||||
|
{"hostname:80-90", "hostname", "80-90", nil},
|
||||||
|
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||||
|
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||||
|
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||||
|
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||||
|
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||||
|
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected uint16
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"80", 80, ""},
|
||||||
|
{"0", 0, ""},
|
||||||
|
{"65535", 65535, ""},
|
||||||
|
{"-1", 0, "port number out of range"},
|
||||||
|
{"65536", 0, "port number out of range"},
|
||||||
|
{"abc", 0, "invalid port number"},
|
||||||
|
{"", 0, "invalid port number"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
result, err := parsePort(test.input)
|
||||||
|
if err != nil && err.Error() != test.err {
|
||||||
|
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||||
|
}
|
||||||
|
if err == nil && test.err != "" {
|
||||||
|
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||||
|
}
|
||||||
|
if result != test.expected {
|
||||||
|
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortRange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected []tailcfg.PortRange
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||||
|
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||||
|
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||||
|
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||||
|
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||||
|
{"80-", nil, "invalid port range format"},
|
||||||
|
{"-90", nil, "invalid port range format"},
|
||||||
|
{"80-90,", nil, "invalid port number"},
|
||||||
|
{"80,90-", nil, "invalid port range format"},
|
||||||
|
{"80-90,abc", nil, "invalid port number"},
|
||||||
|
{"80-90,65536", nil, "port number out of range"},
|
||||||
|
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
result, err := parsePortRange(test.input)
|
||||||
|
if err != nil && err.Error() != test.err {
|
||||||
|
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||||
|
}
|
||||||
|
if err == nil && test.err != "" {
|
||||||
|
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||||
|
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user