diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 54c80bac..d7deb0a5 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/fs" + "net/netip" "net/url" "os" "path" @@ -308,9 +309,15 @@ func (m *Mapper) PeerChangedResponse( resp.PeersChangedPatch = patches } + _, matchers := m.polMan.Filter() // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.primary, m.cfg) + tailnode, err := tailNode( + node, mapRequest.Version, m.polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers) + }, + m.cfg) if err != nil { return nil, err } @@ -347,7 +354,7 @@ func (m *Mapper) marshalMapResponse( } if debugDumpMapResponsePath != "" { - data := map[string]interface{}{ + data := map[string]any{ "Messages": messages, "MapRequest": mapRequest, "MapResponse": resp, @@ -457,7 +464,13 @@ func (m *Mapper) baseWithConfigMapResponse( ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, m.polMan, m.primary, m.cfg) + _, matchers := m.polMan.Filter() + tailnode, err := tailNode( + node, capVer, m.polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, m.primary.PrimaryRoutes(id), matchers) + }, + m.cfg) if err != nil { return nil, err } @@ -513,15 +526,10 @@ func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { return nodes, nil } -func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { - ret := make(types.Nodes, 0) - - for _, node := range nodes { - ret = append(ret, node) - } - - return ret -} +// routeFilterFunc is a function that takes a node ID and returns a list of +// netip.Prefixes that are allowed for that node. It is used to filter routes +// from the primary route manager to the node. +type routeFilterFunc func(id types.NodeID) []netip.Prefix // appendPeerChanges mutates a tailcfg.MapResponse with all the // necessary changes when peers have changed. @@ -553,7 +561,12 @@ func appendPeerChanges( dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, polMan, primary, cfg) + tailPeers, err := tailNodes( + changed, capVer, polMan, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, primary.PrimaryRoutes(id), matchers) + }, + cfg) if err != nil { return err } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index c9365f1a..eae70e96 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -5,7 +5,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "github.com/samber/lo" "tailscale.com/net/tsaddr" @@ -16,7 +15,7 @@ func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, polMan policy.PolicyManager, - primary *routes.PrimaryRoutes, + primaryRouteFunc routeFilterFunc, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -26,7 +25,7 @@ func tailNodes( node, capVer, polMan, - primary, + primaryRouteFunc, cfg, ) if err != nil { @@ -44,7 +43,7 @@ func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, polMan policy.PolicyManager, - primary *routes.PrimaryRoutes, + primaryRouteFunc routeFilterFunc, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,8 +80,7 @@ func tailNode( } tags = lo.Uniq(append(tags, node.ForcedTags...)) - _, matchers := polMan.Filter() - routes := policy.ReduceRoutes(node, primary.PrimaryRoutes(node.ID), matchers) + routes := primaryRouteFunc(node.ID) allowed := append(node.Prefixes(), routes...) allowed = append(allowed, node.ExitRoutes()...) tsaddr.SortPrefixes(allowed) @@ -101,7 +99,7 @@ func tailNode( Machine: node.MachineKey, DiscoKey: node.DiscoKey, Addresses: addrs, - PrimaryRoutes: primary.PrimaryRoutes(node.ID), + PrimaryRoutes: routes, AllowedIPs: allowed, Endpoints: node.Endpoints, HomeDERP: derp, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 30c6d4a8..cacc4930 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -219,7 +219,9 @@ func TestTailNode(t *testing.T) { tt.node, 0, polMan, - primary, + func(id types.NodeID) []netip.Prefix { + return primary.PrimaryRoutes(id) + }, cfg, ) @@ -266,6 +268,7 @@ func TestNodeExpiry(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { node := &types.Node{ + ID: 0, GivenName: "test", Expiry: tt.exp, } @@ -276,7 +279,9 @@ func TestNodeExpiry(t *testing.T) { node, 0, polMan, - nil, + func(id types.NodeID) []netip.Prefix { + return []netip.Prefix{} + }, &types.Config{}, ) if err != nil {