mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-08 00:11:42 +01:00
Enhance route command with ptables and multiple routes
This commit rewrites the `routes list` command to use ptables to present a slightly nicer list, including a new field if the route is enabled or not (which is quite useful). In addition, it reworks the enable command to support enabling multiple routes (not only one route as per removed TODO). This allows users to actually take advantage of exit-nodes and subnet relays.
This commit is contained in:
parent
47b61c0cea
commit
c883e79884
@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/pterm/pterm"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -44,19 +45,25 @@ var listRoutesCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
log.Fatalf("Error initializing: %s", err)
|
||||
}
|
||||
routes, err := h.GetNodeRoutes(n, args[0])
|
||||
|
||||
if strings.HasPrefix(o, "json") {
|
||||
JsonOutput(routes, err, o)
|
||||
return
|
||||
}
|
||||
|
||||
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(routes)
|
||||
if strings.HasPrefix(o, "json") {
|
||||
// TODO: Add enable/disabled information to this interface
|
||||
JsonOutput(availableRoutes, err, o)
|
||||
return
|
||||
}
|
||||
|
||||
d := h.RoutesToPtables(n, args[0], *availableRoutes)
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -80,9 +87,10 @@ var enableRouteCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
log.Fatalf("Error initializing: %s", err)
|
||||
}
|
||||
route, err := h.EnableNodeRoute(n, args[0], args[1])
|
||||
|
||||
err = h.EnableNodeRoute(n, args[0], args[1])
|
||||
if strings.HasPrefix(o, "json") {
|
||||
JsonOutput(route, err, o)
|
||||
JsonOutput(args[1], err, o)
|
||||
return
|
||||
}
|
||||
|
||||
@ -90,6 +98,6 @@ var enableRouteCmd = &cobra.Command{
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("Enabled route %s\n", route)
|
||||
fmt.Printf("Enabled route %s\n", args[1])
|
||||
},
|
||||
}
|
||||
|
113
routes.go
113
routes.go
@ -2,55 +2,140 @@ package headscale
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/pterm/pterm"
|
||||
"gorm.io/datatypes"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
|
||||
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
||||
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hi, err := m.GetHostInfo()
|
||||
hostInfo, err := m.GetHostInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &hi.RoutableIPs, nil
|
||||
return &hostInfo.RoutableIPs, nil
|
||||
}
|
||||
|
||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
||||
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
|
||||
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hi, err := m.GetHostInfo()
|
||||
|
||||
data, err := m.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routesStr := []string{}
|
||||
err = json.Unmarshal(data, &routesStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]netaddr.IPPrefix, len(routesStr))
|
||||
for index, routeStr := range routesStr {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routes[index] = route
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, enabledRoute := range enabledRoutes {
|
||||
if route == enabledRoute {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
available := false
|
||||
for _, availableRoute := range *availableRoutes {
|
||||
// If the route is available, and not yet enabled, add it to the new routing table
|
||||
if route == availableRoute {
|
||||
available = true
|
||||
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
|
||||
enabledRoutes = append(enabledRoutes, route)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !available {
|
||||
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
|
||||
}
|
||||
|
||||
routes, err := json.Marshal(enabledRoutes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rIP := range hi.RoutableIPs {
|
||||
if rIP == route {
|
||||
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
|
||||
m.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&m)
|
||||
|
||||
err = h.RequestMapUpdates(m.NamespaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
return &rIP, nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
|
||||
d := pterm.TableData{{"Route", "Enabled"}}
|
||||
|
||||
for _, route := range availableRoutes {
|
||||
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
|
||||
|
||||
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
|
||||
}
|
||||
return nil, errors.New("could not find routable range")
|
||||
return d
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
Name: "test_get_route_machine",
|
||||
NamespaceID: n.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
r, err := h.GetNodeRoutes("test", "testmachine")
|
||||
r, err := h.GetAdvertisedNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(*r), check.Equals, 1)
|
||||
|
||||
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
|
||||
err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
|
||||
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netaddr.ParseIPPrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
route2, err := netaddr.ParseIPPrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hi := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netaddr.IPPrefix{route, route2},
|
||||
}
|
||||
hostinfo, err := json.Marshal(hi)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "test_enable_route_machine",
|
||||
NamespaceID: n.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: datatypes.JSON(hostinfo),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(*availableRoutes), check.Equals, 2)
|
||||
|
||||
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes), check.Equals, 0)
|
||||
|
||||
err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
|
||||
err = h.EnableNodeRoute("test", "testmachine", "150.0.10.0/25")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes3), check.Equals, 2)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user