mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-28 13:49:04 +02:00
introduce policy v2 package
policy v2 is built from the ground up to be stricter and follow the same pattern for all types of resolvers. TODO introduce aliass resolver Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
c0f856256a
commit
72049e94b1
169
hscontrol/policy/v2/filter.go
Normal file
169
hscontrol/policy/v2/filter.go
Normal file
@ -0,0 +1,169 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
)
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) compileFilterRules(
|
||||
users types.Users,
|
||||
nodes types.Nodes,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
if pol == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
|
||||
var rules []tailcfg.FilterRule
|
||||
|
||||
for _, acl := range pol.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, ErrInvalidAction
|
||||
}
|
||||
|
||||
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
}
|
||||
|
||||
if len(srcIPs.Prefixes()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(kradalby): integrate type into schema
|
||||
// TODO(kradalby): figure out the _ is wildcard stuff
|
||||
protocols, _, err := parseProtocol(acl.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
|
||||
}
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
for _, port := range dest.Ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: pref.String(),
|
||||
Ports: port,
|
||||
}
|
||||
destPorts = append(destPorts, pr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(destPorts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, tailcfg.FilterRule{
|
||||
SrcIPs: ipSetToPrefixStringList(srcIPs),
|
||||
DstPorts: destPorts,
|
||||
IPProto: protocols,
|
||||
})
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
return tailcfg.SSHAction{
|
||||
Reject: !accept,
|
||||
Accept: accept,
|
||||
SessionDuration: duration,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
var dest netipx.IPSetBuilder
|
||||
for _, src := range rule.Destinations {
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
||||
destSet, err := dest.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !node.InIPSet(destSet) {
|
||||
continue
|
||||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
switch rule.Action {
|
||||
case "accept":
|
||||
action = sshAction(true, 0)
|
||||
case "check":
|
||||
action = sshAction(true, rule.CheckPeriod)
|
||||
default:
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
}
|
||||
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
}
|
||||
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
userMap := make(map[string]string, len(rule.Users))
|
||||
for _, user := range rule.Users {
|
||||
userMap[user.String()] = "="
|
||||
}
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
|
||||
return &tailcfg.SSHPolicy{
|
||||
Rules: rules,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
var out []string
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
out = append(out, pref.String())
|
||||
}
|
||||
return out
|
||||
}
|
378
hscontrol/policy/v2/filter_test.go
Normal file
378
hscontrol/policy/v2/filter_test.go
Normal file
@ -0,0 +1,378 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestParsing(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
acl string
|
||||
want []tailcfg.FilterRule
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid-hujson",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
`,
|
||||
want: []tailcfg.FilterRule{},
|
||||
wantErr: true,
|
||||
},
|
||||
// The new parser will ignore all that is irrelevant
|
||||
// {
|
||||
// name: "valid-hujson-invalid-content",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {
|
||||
// "valid_json": true,
|
||||
// "but_a_policy_though": false
|
||||
// }
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
// {
|
||||
// name: "invalid-cidr",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {"example-host-1": "100.100.100.100/42"}
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
{
|
||||
name: "basic-rule",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
"192.168.1.0/24"
|
||||
],
|
||||
"dst": [
|
||||
"*:22,3389",
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "parse-protocol",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "tcp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "udp",
|
||||
"dst": [
|
||||
"host-1:53",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "icmp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{protocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolICMP, protocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-wildcard",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-range",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:5400-5500",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "100.100.100.100/32",
|
||||
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-group",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser@",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-user",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser@",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100/32",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol, err := policyFromBytes([]byte(tt.acl))
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := pol.compileFilterRules(
|
||||
users,
|
||||
types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.100.100.100"),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("200.200.200.200"),
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, rules); diff != "" {
|
||||
t.Errorf("parsing() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
283
hscontrol/policy/v2/policy.go
Normal file
283
hscontrol/policy/v2/policy.go
Normal file
@ -0,0 +1,283 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *Policy
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
|
||||
tagOwnerMapHash deephash.Sum
|
||||
tagOwnerMap map[Tag]*netipx.IPSet
|
||||
|
||||
autoApproveMapHash deephash.Sum
|
||||
autoApproveMap map[netip.Prefix]*netipx.IPSet
|
||||
|
||||
// Lazy map of SSH policies
|
||||
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
|
||||
}
|
||||
|
||||
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||
// It returns an error if the policy file is invalid.
|
||||
// The policy manager will update the filter rules based on the users and nodes.
|
||||
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
policy, err := policyFromBytes(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm := PolicyManager{
|
||||
pol: policy,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
filterChanged := filterHash == pm.filterHash
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
// Order matters, tags might be used in autoapprovers, so we need to ensure
|
||||
// that the map for tag owners is resolved before resolving autoapprovers.
|
||||
// TODO(kradalby): Order might not matter after #2417
|
||||
tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("resolving tag owners map: %w", err)
|
||||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("resolving auto approvers map: %w", err)
|
||||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
// If neither of the calculated values changed, no need to update nodes
|
||||
if !filterChanged && !tagOwnerChanged && !autoApproveChanged {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear the SSH policy map to ensure it's recalculated with the new policy.
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
pm.sshPolicyMap[node.ID] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := policyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// Filter returns the current filter rules for the entire tailnet.
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if ips.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
// The fast path is that a node requests to approve a prefix
|
||||
// where there is an exact entry, e.g. 10.0.0.0/8, then
|
||||
// check and return quickly
|
||||
if _, ok := pm.autoApproveMap[route]; ok {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if pm.autoApproveMap[route].Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The slow path is that the node tries to approve
|
||||
// 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
|
||||
// cannot just lookup in the prefix map and have to check
|
||||
// if there is a "parent" prefix available.
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
// We do not want the exit node entry to approve all
|
||||
// sorts of routes. The logic here is that it would be
|
||||
// unexpected behaviour to have specific routes approved
|
||||
// just because the node is allowed to designate itself as
|
||||
// an exit.
|
||||
if tsaddr.IsExitRoute(prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if prefix is larger (so containing) and then overlaps
|
||||
// the route to see if the node can approve a subset of an autoapprover
|
||||
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
|
||||
for _, nodeAddr := range node.IPs() {
|
||||
if approveAddrs.Contains(nodeAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Version() int {
|
||||
return 2
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) DebugString() string {
|
||||
var sb strings.Builder
|
||||
|
||||
fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version())
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.pol != nil {
|
||||
pol, err := json.MarshalIndent(pm.pol, "", " ")
|
||||
if err == nil {
|
||||
sb.WriteString("Policy:\n")
|
||||
sb.Write(pol)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
sb.WriteString("Compiled filter:\n")
|
||||
sb.Write(filter)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
58
hscontrol/policy/v2/policy_test.go
Normal file
58
hscontrol/policy/v2/policy_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
IPv4: ap(ipv4),
|
||||
IPv6: ap(ipv6),
|
||||
User: user,
|
||||
UserID: user.ID,
|
||||
Hostinfo: hostinfo,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyManager(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol string
|
||||
nodes types.Nodes
|
||||
wantFilter []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "empty-policy",
|
||||
pol: "{}",
|
||||
nodes: types.Nodes{},
|
||||
wantFilter: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
filter := pm.Filter()
|
||||
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Test SSH Policy
|
||||
})
|
||||
}
|
||||
}
|
1005
hscontrol/policy/v2/types.go
Normal file
1005
hscontrol/policy/v2/types.go
Normal file
File diff suppressed because it is too large
Load Diff
1162
hscontrol/policy/v2/types_test.go
Normal file
1162
hscontrol/policy/v2/types_test.go
Normal file
File diff suppressed because it is too large
Load Diff
164
hscontrol/policy/v2/utils.go
Normal file
164
hscontrol/policy/v2/utils.go
Normal file
@ -0,0 +1,164 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||
func splitDestinationAndPort(input string) (string, string, error) {
|
||||
// Find the last occurrence of the colon character
|
||||
lastColonIndex := strings.LastIndex(input, ":")
|
||||
|
||||
// Check if the colon character is present and not at the beginning or end of the string
|
||||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
}
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
}
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
}
|
||||
|
||||
// Split the string into destination and port based on the last colon
|
||||
destination := input[:lastColonIndex]
|
||||
port := input[lastColonIndex+1:]
|
||||
|
||||
return destination, port, nil
|
||||
}
|
||||
|
||||
// parsePortRange parses a port definition string and returns a slice of PortRange structs.
|
||||
func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
if portDef == "*" {
|
||||
return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil
|
||||
}
|
||||
|
||||
var portRanges []tailcfg.PortRange
|
||||
parts := strings.Split(portDef, ",")
|
||||
|
||||
for _, part := range parts {
|
||||
if strings.Contains(part, "-") {
|
||||
rangeParts := strings.Split(part, "-")
|
||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||
return e == ""
|
||||
})
|
||||
if len(rangeParts) != 2 {
|
||||
return nil, errors.New("invalid port range format")
|
||||
}
|
||||
|
||||
first, err := parsePort(rangeParts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
last, err := parsePort(rangeParts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if first > last {
|
||||
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||
} else {
|
||||
port, err := parsePort(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||
}
|
||||
}
|
||||
|
||||
return portRanges, nil
|
||||
}
|
||||
|
||||
// parsePort parses a single port number from a string.
|
||||
func parsePort(portStr string) (uint16, error) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid port number")
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, errors.New("port number out of range")
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
protocolIGMP = 2 // Internet Group Management
|
||||
protocolIPv4 = 4 // IPv4 encapsulation
|
||||
protocolTCP = 6 // Transmission Control
|
||||
protocolEGP = 8 // Exterior Gateway Protocol
|
||||
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
|
||||
protocolUDP = 17 // User Datagram
|
||||
protocolGRE = 47 // Generic Routing Encapsulation
|
||||
protocolESP = 50 // Encap Security Payload
|
||||
protocolAH = 51 // Authentication Header
|
||||
protocolIPv6ICMP = 58 // ICMP for IPv6
|
||||
protocolSCTP = 132 // Stream Control Transmission Protocol
|
||||
ProtocolFC = 133 // Fibre Channel
|
||||
)
|
||||
|
||||
// parseProtocol reads the proto field of the ACL and generates a list of
|
||||
// protocols that will be allowed, following the IANA IP protocol number
|
||||
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
|
||||
//
|
||||
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
|
||||
// as per Tailscale behaviour (see tailcfg.FilterRule).
|
||||
//
|
||||
// Also returns a boolean indicating if the protocol
|
||||
// requires all the destinations to use wildcard as port number (only TCP,
|
||||
// UDP and SCTP support specifying ports).
|
||||
func parseProtocol(protocol string) ([]int, bool, error) {
|
||||
switch protocol {
|
||||
case "":
|
||||
return nil, false, nil
|
||||
case "igmp":
|
||||
return []int{protocolIGMP}, true, nil
|
||||
case "ipv4", "ip-in-ip":
|
||||
return []int{protocolIPv4}, true, nil
|
||||
case "tcp":
|
||||
return []int{protocolTCP}, false, nil
|
||||
case "egp":
|
||||
return []int{protocolEGP}, true, nil
|
||||
case "igp":
|
||||
return []int{protocolIGP}, true, nil
|
||||
case "udp":
|
||||
return []int{protocolUDP}, false, nil
|
||||
case "gre":
|
||||
return []int{protocolGRE}, true, nil
|
||||
case "esp":
|
||||
return []int{protocolESP}, true, nil
|
||||
case "ah":
|
||||
return []int{protocolAH}, true, nil
|
||||
case "sctp":
|
||||
return []int{protocolSCTP}, false, nil
|
||||
case "icmp":
|
||||
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
|
||||
|
||||
default:
|
||||
protocolNumber, err := strconv.Atoi(protocol)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): What is this?
|
||||
needsWildcard := protocolNumber != protocolTCP &&
|
||||
protocolNumber != protocolUDP &&
|
||||
protocolNumber != protocolSCTP
|
||||
|
||||
return []int{protocolNumber}, needsWildcard, nil
|
||||
}
|
||||
}
|
102
hscontrol/policy/v2/utils_test.go
Normal file
102
hscontrol/policy/v2/utils_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests.
|
||||
func TestParseDestinationAndPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedDst string
|
||||
expectedPort string
|
||||
expectedErr error
|
||||
}{
|
||||
{"git-server:*", "git-server", "*", nil},
|
||||
{"192.168.1.0/24:22", "192.168.1.0/24", "22", nil},
|
||||
{"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil},
|
||||
{"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil},
|
||||
{"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil},
|
||||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected uint16
|
||||
err string
|
||||
}{
|
||||
{"80", 80, ""},
|
||||
{"0", 0, ""},
|
||||
{"65535", 65535, ""},
|
||||
{"-1", 0, "port number out of range"},
|
||||
{"65536", 0, "port number out of range"},
|
||||
{"abc", 0, "invalid port number"},
|
||||
{"", 0, "invalid port number"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePort(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
{"80-90,", nil, "invalid port number"},
|
||||
{"80,90-", nil, "invalid port range format"},
|
||||
{"80-90,abc", nil, "invalid port number"},
|
||||
{"80-90,65536", nil, "port number out of range"},
|
||||
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePortRange(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user