1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-14 13:51:01 +02:00

mapper: compare peer routes against node

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-05-04 15:55:22 +02:00
parent 8b11ab319d
commit 9b14563617
No known key found for this signature in database
3 changed files with 38 additions and 22 deletions

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/fs" "io/fs"
"net/netip"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -308,9 +309,15 @@ func (m *Mapper) PeerChangedResponse(
resp.PeersChangedPatch = patches resp.PeersChangedPatch = patches
} }
_, matchers := m.polMan.Filter()
// Add the node itself, it might have changed, and particularly // Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update. // if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.primary, m.cfg) tailnode, err := tailNode(
node, mapRequest.Version, m.polMan,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -347,7 +354,7 @@ func (m *Mapper) marshalMapResponse(
} }
if debugDumpMapResponsePath != "" { if debugDumpMapResponsePath != "" {
data := map[string]interface{}{ data := map[string]any{
"Messages": messages, "Messages": messages,
"MapRequest": mapRequest, "MapRequest": mapRequest,
"MapResponse": resp, "MapResponse": resp,
@ -457,7 +464,13 @@ func (m *Mapper) baseWithConfigMapResponse(
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, m.polMan, m.primary, m.cfg) _, matchers := m.polMan.Filter()
tailnode, err := tailNode(
node, capVer, m.polMan,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -513,15 +526,10 @@ func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return nodes, nil return nodes, nil
} }
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { // routeFilterFunc is a function that takes a node ID and returns a list of
ret := make(types.Nodes, 0) // netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
for _, node := range nodes { type routeFilterFunc func(id types.NodeID) []netip.Prefix
ret = append(ret, node)
}
return ret
}
// appendPeerChanges mutates a tailcfg.MapResponse with all the // appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed. // necessary changes when peers have changed.
@ -553,7 +561,12 @@ func appendPeerChanges(
dnsConfig := generateDNSConfig(cfg, node) dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(changed, capVer, polMan, primary, cfg) tailPeers, err := tailNodes(
changed, capVer, polMan,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, primary.PrimaryRoutes(id), matchers)
},
cfg)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -16,7 +15,7 @@ func tailNodes(
nodes types.Nodes, nodes types.Nodes,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager, polMan policy.PolicyManager,
primary *routes.PrimaryRoutes, primaryRouteFunc routeFilterFunc,
cfg *types.Config, cfg *types.Config,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes)) tNodes := make([]*tailcfg.Node, len(nodes))
@ -26,7 +25,7 @@ func tailNodes(
node, node,
capVer, capVer,
polMan, polMan,
primary, primaryRouteFunc,
cfg, cfg,
) )
if err != nil { if err != nil {
@ -44,7 +43,7 @@ func tailNode(
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager, polMan policy.PolicyManager,
primary *routes.PrimaryRoutes, primaryRouteFunc routeFilterFunc,
cfg *types.Config, cfg *types.Config,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
addrs := node.Prefixes() addrs := node.Prefixes()
@ -81,8 +80,7 @@ func tailNode(
} }
tags = lo.Uniq(append(tags, node.ForcedTags...)) tags = lo.Uniq(append(tags, node.ForcedTags...))
_, matchers := polMan.Filter() routes := primaryRouteFunc(node.ID)
routes := policy.ReduceRoutes(node, primary.PrimaryRoutes(node.ID), matchers)
allowed := append(node.Prefixes(), routes...) allowed := append(node.Prefixes(), routes...)
allowed = append(allowed, node.ExitRoutes()...) allowed = append(allowed, node.ExitRoutes()...)
tsaddr.SortPrefixes(allowed) tsaddr.SortPrefixes(allowed)
@ -101,7 +99,7 @@ func tailNode(
Machine: node.MachineKey, Machine: node.MachineKey,
DiscoKey: node.DiscoKey, DiscoKey: node.DiscoKey,
Addresses: addrs, Addresses: addrs,
PrimaryRoutes: primary.PrimaryRoutes(node.ID), PrimaryRoutes: routes,
AllowedIPs: allowed, AllowedIPs: allowed,
Endpoints: node.Endpoints, Endpoints: node.Endpoints,
HomeDERP: derp, HomeDERP: derp,

View File

@ -219,7 +219,9 @@ func TestTailNode(t *testing.T) {
tt.node, tt.node,
0, 0,
polMan, polMan,
primary, func(id types.NodeID) []netip.Prefix {
return primary.PrimaryRoutes(id)
},
cfg, cfg,
) )
@ -266,6 +268,7 @@ func TestNodeExpiry(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
node := &types.Node{ node := &types.Node{
ID: 0,
GivenName: "test", GivenName: "test",
Expiry: tt.exp, Expiry: tt.exp,
} }
@ -276,7 +279,9 @@ func TestNodeExpiry(t *testing.T) {
node, node,
0, 0,
polMan, polMan,
nil, func(id types.NodeID) []netip.Prefix {
return []netip.Prefix{}
},
&types.Config{}, &types.Config{},
) )
if err != nil { if err != nil {