1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-14 13:51:01 +02:00

split policy -> policy and v1

This commit split out the common policy logic and policy implementation
into separate packages.

policy contains functions that are independent of the policy implementation,
this typically means logic that works on tailcfg types and generic formats.
In addition, it defines the PolicyManager interface which the v1 implements.

v1 is a subpackage which implements the PolicyManager using the "original"
policy implementation.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-26 19:15:04 +01:00
parent 1dd50c6916
commit e9ffc4f2df
No known key found for this signature in database
12 changed files with 1874 additions and 1742 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -246,7 +247,7 @@ func Test_fullMapResponse(t *testing.T) {
tests := []struct {
name string
pol *policy.ACLPolicy
pol []byte
node *types.Node
peers types.Nodes
@ -258,7 +259,7 @@ func Test_fullMapResponse(t *testing.T) {
// {
// name: "empty-node",
// node: types.Node{},
// pol: &policy.ACLPolicy{},
// pol: &policyv1.ACLPolicy{},
// dnsConfig: &tailcfg.DNSConfig{},
// baseDomain: "",
// want: nil,
@ -266,7 +267,6 @@ func Test_fullMapResponse(t *testing.T) {
// },
{
name: "no-pol-no-peers-map-response",
pol: &policy.ACLPolicy{},
node: mini,
peers: types.Nodes{},
derpMap: &tailcfg.DERPMap{},
@ -284,10 +284,15 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
UserProfiles: []tailcfg.UserProfile{
{
ID: tailcfg.UserID(user1.ID),
LoginName: "user1",
DisplayName: "user1",
},
},
PacketFilter: tailcfg.FilterAllowAll,
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
@ -296,7 +301,6 @@ func Test_fullMapResponse(t *testing.T) {
},
{
name: "no-pol-with-peer-map-response",
pol: &policy.ACLPolicy{},
node: mini,
peers: types.Nodes{
peer1,
@ -318,13 +322,12 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
PacketFilter: tailcfg.FilterAllowAll,
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
@ -333,18 +336,17 @@ func Test_fullMapResponse(t *testing.T) {
},
{
name: "with-pol-map-response",
pol: &policy.ACLPolicy{
Hosts: policy.Hosts{
"mini": netip.MustParsePrefix("100.64.0.1/32"),
},
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"100.64.0.2"},
Destinations: []string{"mini:*"},
},
},
},
pol: []byte(`
{
"acls": [
{
"action": "accept",
"src": ["100.64.0.2"],
"dst": ["user1:*"],
},
],
}
`),
node: mini,
peers: types.Nodes{
peer1,
@ -374,11 +376,11 @@ func Test_fullMapResponse(t *testing.T) {
},
},
},
SSHPolicy: &tailcfg.SSHPolicy{},
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
@ -390,7 +392,8 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
require.NoError(t, err)
primary := routes.New()
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)

View File

@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -49,7 +50,7 @@ func TestTailNode(t *testing.T) {
tests := []struct {
name string
node *types.Node
pol *policy.ACLPolicy
pol []byte
dnsConfig *tailcfg.DNSConfig
baseDomain string
want *tailcfg.Node
@ -61,7 +62,6 @@ func TestTailNode(t *testing.T) {
GivenName: "empty",
Hostinfo: &tailcfg.Hostinfo{},
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: &tailcfg.Node{
@ -117,7 +117,6 @@ func TestTailNode(t *testing.T) {
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: &tailcfg.Node{
@ -179,7 +178,8 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
require.NoError(t, err)
primary := routes.New()
cfg := &types.Config{
BaseDomain: tt.baseDomain,
@ -248,7 +248,7 @@ func TestNodeExpiry(t *testing.T) {
tn, err := tailNode(
node,
0,
&policy.PolicyManagerV1{},
nil, // TODO(kradalby): removed in merge but error?
nil,
&types.Config{},
)

View File

@ -1,219 +1,81 @@
package policy
import (
"fmt"
"io"
"net/netip"
"os"
"sync"
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
var (
polv2 = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")
)
type PolicyManager interface {
Filter() []tailcfg.FilterRule
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
Tags(*types.Node) []string
ApproversForRoute(netip.Prefix) []string
ExpandAlias(string) (*netipx.IPSet, error)
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
// NodeCanHaveTag reports whether the given node can have the given tag.
NodeCanHaveTag(*types.Node, string) bool
// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
Version() int
DebugString() string
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
policyFile, err := os.Open(path)
if err != nil {
return nil, err
}
defer policyFile.Close()
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return nil, err
}
return NewPolicyManager(policyBytes, users, nodes)
}
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
var pol *ACLPolicy
// NewPolicyManager returns a new policy manager, the version is determined by
// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2".
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
var polMan PolicyManager
var err error
if polB != nil && len(polB) > 0 {
pol, err = LoadACLPolicyFromBytes(polB)
if polv2 {
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
return nil, err
}
} else {
polMan, err = policyv1.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, err
}
}
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
return polMan, err
}
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
// PolicyManagersForTest returns all available PostureManagers to be used
// in tests to validate them in tests that try to determine that they
// behave the same.
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) {
var polMans []PolicyManager
_, err := pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
type PolicyManagerV1 struct {
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
if filterHash == pm.filterHash {
return false, nil
}
pm.filter = filter
pm.filterHash = filterHash
return true, nil
}
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
if len(polB) == 0 {
return false, nil
}
pol, err := LoadACLPolicyFromBytes(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
defer pm.mu.Unlock()
pm.pol = pol
return pm.updateLocked()
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
if pm == nil {
return nil
}
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
return tags
}
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
// TODO(kradalby): This can be a parse error of the address in the policy,
// in the new policy this will be typed and not a problem, in this policy
// we will just return empty list
if pm.pol == nil {
return nil
}
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
return approvers
}
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
if err != nil {
return nil, err
}
return ips, nil
}
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
if pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
for _, approvedAlias := range approvers {
if approvedAlias == node.User.Username() {
return true
} else {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
if err != nil {
return false
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if ips.Contains(*node.IPv4) {
return true
}
for _, pmf := range PolicyManagerFuncsForTest(pol) {
pm, err := pmf(users, nodes)
if err != nil {
return nil, err
}
polMans = append(polMans, pm)
}
return false
return polMans, nil
}
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) {
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error)
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
return policyv1.NewPolicyManager(pol, u, n)
})
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
return policyv2.NewPolicyManager(pol, u, n)
})
return polmanFuncs
}

109
hscontrol/policy/policy.go Normal file
View File

@ -0,0 +1,109 @@
package policy
import (
"net/netip"
"slices"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
func FilterNodesByACL(
node *types.Node,
nodes types.Nodes,
filter []tailcfg.FilterRule,
) types.Nodes {
var result types.Nodes
for index, peer := range nodes {
if peer.ID == node.ID {
continue
}
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
result = append(result, peer)
}
}
return result
}
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if node.Hostinfo != nil {
if len(node.Hostinfo.RoutableIPs) > 0 {
for _, routableIP := range node.Hostinfo.RoutableIPs {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}
// AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy.
// It reports true if any routes were approved.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil {
return false
}
var newApproved []netip.Prefix
for _, route := range node.AnnouncedRoutes() {
if pm.NodeCanApproveRoute(node, route) {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
node.ApprovedRoutes = newApproved
return true
}
return false
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,10 @@
package policy
package v1
import (
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"net/netip"
"os"
"slices"
@ -18,7 +17,6 @@ import (
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@ -37,38 +35,6 @@ const (
expectedTokenItems = 2
)
var theInternetSet *netipx.IPSet
// theInternet returns the IPSet for the Internet.
// https://www.youtube.com/watch?v=iDbyYGrswtg
func theInternet() *netipx.IPSet {
if theInternetSet != nil {
return theInternetSet
}
var internetBuilder netipx.IPSetBuilder
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
internetBuilder.AddPrefix(tsaddr.AllIPv4())
// Delete Private network addresses
// https://datatracker.ietf.org/doc/html/rfc1918
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
// Delete Tailscale networks
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
// Delete "can't find DHCP networks"
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
@ -240,53 +206,6 @@ func (pol *ACLPolicy) CompileFilterRules(
return rules, nil
}
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
// TODO(kradalby): Make this nil and not alloc unless needed
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if len(node.SubnetRoutes()) > 0 {
for _, routableIP := range node.SubnetRoutes() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}
func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node,
users []types.User,
@ -418,7 +337,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
}
for addr := range ipSetAll(ips) {
for addr := range util.IPSetAddrIter(ips) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
@ -441,19 +360,6 @@ func (pol *ACLPolicy) CompileSSHPolicy(
}, nil
}
// ipSetAll returns a function that iterates over all the IPs in the IPSet.
func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
return func(yield func(netip.Addr) bool) {
for _, rng := range ipSet.Ranges() {
for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
if !yield(ip) {
return
}
}
}
}
}
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
sessionLength, err := time.ParseDuration(duration)
if err != nil {
@ -950,7 +856,7 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix(
func expandAutoGroup(alias string) (*netipx.IPSet, error) {
switch {
case strings.HasPrefix(alias, "autogroup:internet"):
return theInternet(), nil
return util.TheInternet(), nil
default:
return nil, fmt.Errorf("unknown autogroup %q", alias)
@ -1084,24 +990,3 @@ func findUserFromToken(users []types.User, token string) (types.User, error) {
return potentialUsers[0], nil
}
// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
func FilterNodesByACL(
node *types.Node,
nodes types.Nodes,
filter []tailcfg.FilterRule,
) types.Nodes {
var result types.Nodes
for index, peer := range nodes {
if peer.ID == node.ID {
continue
}
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
result = append(result, peer)
}
}
return result
}

View File

@ -1,4 +1,4 @@
package policy
package v1
import (
"encoding/json"

View File

@ -0,0 +1,187 @@
package v1
import (
"fmt"
"io"
"net/netip"
"os"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policyFile, err := os.Open(path)
if err != nil {
return nil, err
}
defer policyFile.Close()
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return nil, err
}
return NewPolicyManager(policyBytes, users, nodes)
}
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
var pol *ACLPolicy
var err error
if polB != nil && len(polB) > 0 {
pol, err = LoadACLPolicyFromBytes(polB)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
}
pm := PolicyManager{
pol: pol,
users: users,
nodes: nodes,
}
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
type PolicyManager struct {
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() (bool, error) {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
if filterHash == pm.filterHash {
return false, nil
}
pm.filter = filter
pm.filterHash = filterHash
return true, nil
}
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
if len(polB) == 0 {
return false, nil
}
pol, err := LoadACLPolicyFromBytes(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
defer pm.mu.Unlock()
pm.pol = pol
return pm.updateLocked()
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
if pm == nil || pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
for _, t := range tags {
if t == tag {
return true
}
}
return false
}
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
if pm == nil || pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
for _, approvedAlias := range approvers {
if approvedAlias == node.User.Username() {
return true
} else {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
if err != nil {
return false
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if ips.Contains(*node.IPv4) {
return true
}
}
}
return false
}
func (pm *PolicyManager) Version() int {
return 1
}
func (pm *PolicyManager) DebugString() string {
return "not implemented for v1"
}

View File

@ -1,4 +1,4 @@
package policy
package v1
import (
"testing"

View File

@ -10,10 +10,9 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"tailscale.com/net/tsaddr"
@ -459,25 +458,10 @@ func (m *mapSession) handleEndpointUpdate() {
// TODO(kradalby): I am not sure if we need this?
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
// Take all the routes presented to us by the node and check
// if any of them should be auto approved by the policy.
// If any of them are, add them to the approved routes of the node.
// Keep all the old entries and compact the list to remove duplicates.
var newApproved []netip.Prefix
for _, route := range m.node.Hostinfo.RoutableIPs {
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, m.node.ApprovedRoutes...)
slices.SortFunc(newApproved, util.ComparePrefix)
slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
m.node.ApprovedRoutes = newApproved
// Approve routes if they are auto-approved by the policy.
// If any of them are approved, report them to the primary route tracker
// and send updates accordingly.
if policy.AutoApproveRoutes(m.h.polMan, m.node) {
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())

View File

@ -1,10 +1,13 @@
package util
import (
"cmp"
"context"
"net"
"net/netip"
"sync"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
)
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
@ -49,3 +52,29 @@ func MustStringsToPrefixes(strings []string) []netip.Prefix {
return ret
}
// TheInternet returns the IPSet for the Internet.
// https://www.youtube.com/watch?v=iDbyYGrswtg
var TheInternet = sync.OnceValue(func() *netipx.IPSet {
var internetBuilder netipx.IPSetBuilder
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
internetBuilder.AddPrefix(tsaddr.AllIPv4())
// Delete Private network addresses
// https://datatracker.ietf.org/doc/html/rfc1918
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
// Delete Tailscale networks
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
// Delete "can't find DHCP networks"
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
})