mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-14 13:51:01 +02:00
split policy -> policy and v1
This commit split out the common policy logic and policy implementation into separate packages. policy contains functions that are independent of the policy implementation, this typically means logic that works on tailcfg types and generic formats. In addition, it defines the PolicyManager interface which the v1 implements. v1 is a subpackage which implements the PolicyManager using the "original" policy implementation. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
1dd50c6916
commit
e9ffc4f2df
@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -246,7 +247,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
node *types.Node
|
||||
peers types.Nodes
|
||||
|
||||
@ -258,7 +259,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
// {
|
||||
// name: "empty-node",
|
||||
// node: types.Node{},
|
||||
// pol: &policy.ACLPolicy{},
|
||||
// pol: &policyv1.ACLPolicy{},
|
||||
// dnsConfig: &tailcfg.DNSConfig{},
|
||||
// baseDomain: "",
|
||||
// want: nil,
|
||||
@ -266,7 +267,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
// },
|
||||
{
|
||||
name: "no-pol-no-peers-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{},
|
||||
derpMap: &tailcfg.DERPMap{},
|
||||
@ -284,10 +284,15 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{
|
||||
ID: tailcfg.UserID(user1.ID),
|
||||
LoginName: "user1",
|
||||
DisplayName: "user1",
|
||||
},
|
||||
},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
@ -296,7 +301,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-pol-with-peer-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
@ -318,13 +322,12 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
@ -333,18 +336,17 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "with-pol-map-response",
|
||||
pol: &policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"mini": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"100.64.0.2"},
|
||||
Destinations: []string{"mini:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
pol: []byte(`
|
||||
{
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["100.64.0.2"],
|
||||
"dst": ["user1:*"],
|
||||
},
|
||||
],
|
||||
}
|
||||
`),
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
@ -374,11 +376,11 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
@ -390,7 +392,8 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
|
||||
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@ -49,7 +50,7 @@ func TestTailNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
dnsConfig *tailcfg.DNSConfig
|
||||
baseDomain string
|
||||
want *tailcfg.Node
|
||||
@ -61,7 +62,6 @@ func TestTailNode(t *testing.T) {
|
||||
GivenName: "empty",
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
@ -117,7 +117,6 @@ func TestTailNode(t *testing.T) {
|
||||
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
|
||||
CreatedAt: created,
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
@ -179,7 +178,8 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
cfg := &types.Config{
|
||||
BaseDomain: tt.baseDomain,
|
||||
@ -248,7 +248,7 @@ func TestNodeExpiry(t *testing.T) {
|
||||
tn, err := tailNode(
|
||||
node,
|
||||
0,
|
||||
&policy.PolicyManagerV1{},
|
||||
nil, // TODO(kradalby): removed in merge but error?
|
||||
nil,
|
||||
&types.Config{},
|
||||
)
|
||||
|
@ -1,219 +1,81 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
var (
|
||||
polv2 = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
Filter() []tailcfg.FilterRule
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
Tags(*types.Node) []string
|
||||
ApproversForRoute(netip.Prefix) []string
|
||||
ExpandAlias(string) (*netipx.IPSet, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(*types.Node, string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
}
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
// NewPolicyManager returns a new policy manager, the version is determined by
|
||||
// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2".
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if polv2 {
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
polMan, err = policyv1.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
return polMan, err
|
||||
}
|
||||
|
||||
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
// PolicyManagersForTest returns all available PostureManagers to be used
|
||||
// in tests to validate them in tests that try to determine that they
|
||||
// behave the same.
|
||||
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) {
|
||||
var polMans []PolicyManager
|
||||
|
||||
_, err := pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManagerV1 struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
|
||||
if pm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
|
||||
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
|
||||
return tags
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
|
||||
// TODO(kradalby): This can be a parse error of the address in the policy,
|
||||
// in the new policy this will be typed and not a problem, in this policy
|
||||
// we will just return empty list
|
||||
if pm.pol == nil {
|
||||
return nil
|
||||
}
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
return approvers
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
for _, pmf := range PolicyManagerFuncsForTest(pol) {
|
||||
pm, err := pmf(users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
return false
|
||||
return polMans, nil
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv1.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
|
||||
return polmanFuncs
|
||||
}
|
||||
|
109
hscontrol/policy/policy.go
Normal file
109
hscontrol/policy/policy.go
Normal file
@ -0,0 +1,109 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
|
||||
func FilterNodesByACL(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
filter []tailcfg.FilterRule,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
|
||||
for index, peer := range nodes {
|
||||
if peer.ID == node.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
// Fail closed, if we can't parse it, then we should not allow
|
||||
// access.
|
||||
if err != nil {
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if node.Hostinfo != nil {
|
||||
if len(node.Hostinfo.RoutableIPs) > 0 {
|
||||
for _, routableIP := range node.Hostinfo.RoutableIPs {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
ret = append(ret, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: dests,
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// AutoApproveRoutes approves any route that can be autoapproved from
|
||||
// the nodes perspective according to the given policy.
|
||||
// It reports true if any routes were approved.
|
||||
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range node.AnnouncedRoutes() {
|
||||
if pm.NodeCanApproveRoute(node, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
node.ApprovedRoutes = newApproved
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
1455
hscontrol/policy/policy_test.go
Normal file
1455
hscontrol/policy/policy_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,11 +1,10 @@
|
||||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@ -18,7 +17,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@ -37,38 +35,6 @@ const (
|
||||
expectedTokenItems = 2
|
||||
)
|
||||
|
||||
var theInternetSet *netipx.IPSet
|
||||
|
||||
// theInternet returns the IPSet for the Internet.
|
||||
// https://www.youtube.com/watch?v=iDbyYGrswtg
|
||||
func theInternet() *netipx.IPSet {
|
||||
if theInternetSet != nil {
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "can't find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
@ -240,53 +206,6 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
// TODO(kradalby): Make this nil and not alloc unless needed
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
// Fail closed, if we can't parse it, then we should not allow
|
||||
// access.
|
||||
if err != nil {
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if len(node.SubnetRoutes()) > 0 {
|
||||
for _, routableIP := range node.SubnetRoutes() {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
ret = append(ret, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: dests,
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
node *types.Node,
|
||||
users []types.User,
|
||||
@ -418,7 +337,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
|
||||
}
|
||||
for addr := range ipSetAll(ips) {
|
||||
for addr := range util.IPSetAddrIter(ips) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
@ -441,19 +360,6 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipSetAll returns a function that iterates over all the IPs in the IPSet.
|
||||
func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
|
||||
return func(yield func(netip.Addr) bool) {
|
||||
for _, rng := range ipSet.Ranges() {
|
||||
for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
|
||||
if !yield(ip) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||
sessionLength, err := time.ParseDuration(duration)
|
||||
if err != nil {
|
||||
@ -950,7 +856,7 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix(
|
||||
func expandAutoGroup(alias string) (*netipx.IPSet, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(alias, "autogroup:internet"):
|
||||
return theInternet(), nil
|
||||
return util.TheInternet(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown autogroup %q", alias)
|
||||
@ -1084,24 +990,3 @@ func findUserFromToken(users []types.User, token string) (types.User, error) {
|
||||
|
||||
return potentialUsers[0], nil
|
||||
}
|
||||
|
||||
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
|
||||
func FilterNodesByACL(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
filter []tailcfg.FilterRule,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
|
||||
for index, peer := range nodes {
|
||||
if peer.ID == node.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
187
hscontrol/policy/v1/policy.go
Normal file
187
hscontrol/policy/v1/policy.go
Normal file
@ -0,0 +1,187 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManager{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
if pm == nil || pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
|
||||
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
|
||||
|
||||
for _, t := range tags {
|
||||
if t == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm == nil || pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) Version() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) DebugString() string {
|
||||
return "not implemented for v1"
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
@ -10,10 +10,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@ -459,25 +458,10 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
||||
|
||||
// Take all the routes presented to us by the node and check
|
||||
// if any of them should be auto approved by the policy.
|
||||
// If any of them are, add them to the approved routes of the node.
|
||||
// Keep all the old entries and compact the list to remove duplicates.
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range m.node.Hostinfo.RoutableIPs {
|
||||
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, m.node.ApprovedRoutes...)
|
||||
slices.SortFunc(newApproved, util.ComparePrefix)
|
||||
slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
m.node.ApprovedRoutes = newApproved
|
||||
|
||||
// Approve routes if they are auto-approved by the policy.
|
||||
// If any of them are approved, report them to the primary route tracker
|
||||
// and send updates accordingly.
|
||||
if policy.AutoApproveRoutes(m.h.polMan, m.node) {
|
||||
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
|
@ -1,10 +1,13 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
@ -49,3 +52,29 @@ func MustStringsToPrefixes(strings []string) []netip.Prefix {
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// TheInternet returns the IPSet for the Internet.
|
||||
// https://www.youtube.com/watch?v=iDbyYGrswtg
|
||||
var TheInternet = sync.OnceValue(func() *netipx.IPSet {
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "can't find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
return theInternetSet
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user