mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	initial work on db migration
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									fe9557a729
								
							
						
					
					
						commit
						df85797954
					
				@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@ -622,6 +623,62 @@ AND auth_key_id NOT IN (
 | 
			
		||||
				},
 | 
			
		||||
				Rollback: func(db *gorm.DB) error { return nil },
 | 
			
		||||
			},
 | 
			
		||||
			// Migrate all routes from the Route table to the new field ApprovedRoutes
 | 
			
		||||
			// in the Node table. Then drop the Route table.
 | 
			
		||||
			{
 | 
			
		||||
				ID: "202502131714",
 | 
			
		||||
				Migrate: func(tx *gorm.DB) error {
 | 
			
		||||
					if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") {
 | 
			
		||||
						err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes")
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							return fmt.Errorf("adding column types.Node: %w", err)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					// Ensure the ApprovedRoutes exist.
 | 
			
		||||
					// err := tx.AutoMigrate(&types.Node{})
 | 
			
		||||
					// if err != nil {
 | 
			
		||||
					// 	return fmt.Errorf("automigrating types.Node: %w", err)
 | 
			
		||||
					// }
 | 
			
		||||
 | 
			
		||||
					nodeRoutes := map[uint64][]netip.Prefix{}
 | 
			
		||||
 | 
			
		||||
					var routes []types.Route
 | 
			
		||||
					err = tx.Find(&routes).Error
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return fmt.Errorf("fetching routes: %w", err)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					for _, route := range routes {
 | 
			
		||||
						if route.Enabled {
 | 
			
		||||
							nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					for nodeID, routes := range nodeRoutes {
 | 
			
		||||
						slices.SortFunc(routes, util.ComparePrefix)
 | 
			
		||||
						slices.Compact(routes)
 | 
			
		||||
 | 
			
		||||
						data, err := json.Marshal(routes)
 | 
			
		||||
 | 
			
		||||
						err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							return fmt.Errorf("saving approved routes to new column: %w", err)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					return nil
 | 
			
		||||
				},
 | 
			
		||||
				Rollback: func(db *gorm.DB) error { return nil },
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				ID: "202502171819",
 | 
			
		||||
				Migrate: func(tx *gorm.DB) error {
 | 
			
		||||
					_ = tx.Migrator().DropColumn(&types.Node{}, "last_seen")
 | 
			
		||||
 | 
			
		||||
					return nil
 | 
			
		||||
				},
 | 
			
		||||
				Rollback: func(db *gorm.DB) error { return nil },
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
 | 
			
		||||
			wantFunc: func(t *testing.T, h *HSDatabase) {
 | 
			
		||||
				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | 
			
		||||
					return GetRoutes(rx)
 | 
			
		||||
				nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
 | 
			
		||||
					n1, err := GetNodeByID(rx, 1)
 | 
			
		||||
					n26, err := GetNodeByID(rx, 26)
 | 
			
		||||
					n31, err := GetNodeByID(rx, 31)
 | 
			
		||||
					n32, err := GetNodeByID(rx, 32)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return nil, err
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					return types.Nodes{n1, n26, n31, n32}, nil
 | 
			
		||||
				})
 | 
			
		||||
				require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
				assert.Len(t, routes, 10)
 | 
			
		||||
				want := types.Routes{
 | 
			
		||||
					r(1, "0.0.0.0/0", true, true, false),
 | 
			
		||||
					r(1, "::/0", true, true, false),
 | 
			
		||||
					r(1, "10.9.110.0/24", true, true, true),
 | 
			
		||||
					r(26, "172.100.100.0/24", true, true, true),
 | 
			
		||||
					r(26, "172.100.100.0/24", true, false, false),
 | 
			
		||||
					r(31, "0.0.0.0/0", true, true, false),
 | 
			
		||||
					r(31, "0.0.0.0/0", true, false, false),
 | 
			
		||||
					r(31, "::/0", true, true, false),
 | 
			
		||||
					r(31, "::/0", true, false, false),
 | 
			
		||||
					r(32, "192.168.0.24/32", true, true, true),
 | 
			
		||||
				// want := types.Routes{
 | 
			
		||||
				// 	r(1, "0.0.0.0/0", true, true, false),
 | 
			
		||||
				// 	r(1, "::/0", true, true, false),
 | 
			
		||||
				// 	r(1, "10.9.110.0/24", true, true, true),
 | 
			
		||||
				// 	r(26, "172.100.100.0/24", true, true, true),
 | 
			
		||||
				// 	r(26, "172.100.100.0/24", true, false, false),
 | 
			
		||||
				// 	r(31, "0.0.0.0/0", true, true, false),
 | 
			
		||||
				// 	r(31, "0.0.0.0/0", true, false, false),
 | 
			
		||||
				// 	r(31, "::/0", true, true, false),
 | 
			
		||||
				// 	r(31, "::/0", true, false, false),
 | 
			
		||||
				// 	r(32, "192.168.0.24/32", true, true, true),
 | 
			
		||||
				// }
 | 
			
		||||
				want := [][]netip.Prefix{
 | 
			
		||||
					{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.9.110.0/24")},
 | 
			
		||||
					{ipp("172.100.100.0/24")},
 | 
			
		||||
					{ipp("0.0.0.0/0"), ipp("::/0")},
 | 
			
		||||
					{ipp("192.168.0.24/32")},
 | 
			
		||||
				}
 | 
			
		||||
				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
 | 
			
		||||
				var got [][]netip.Prefix
 | 
			
		||||
				for _, node := range nodes {
 | 
			
		||||
					got = append(got, node.ApprovedRoutes)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" {
 | 
			
		||||
					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | 
			
		||||
				}
 | 
			
		||||
			},
 | 
			
		||||
@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
 | 
			
		||||
			wantFunc: func(t *testing.T, h *HSDatabase) {
 | 
			
		||||
				routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
 | 
			
		||||
					return GetRoutes(rx)
 | 
			
		||||
				node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) {
 | 
			
		||||
					return GetNodeByID(rx, 13)
 | 
			
		||||
				})
 | 
			
		||||
				require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
				assert.Len(t, routes, 4)
 | 
			
		||||
				want := types.Routes{
 | 
			
		||||
				assert.Len(t, node.ApprovedRoutes, 3)
 | 
			
		||||
				_ = types.Routes{
 | 
			
		||||
					// These routes exists, but have no nodes associated with them
 | 
			
		||||
					// when the migration starts.
 | 
			
		||||
					// r(1, "0.0.0.0/0", true, true, false),
 | 
			
		||||
@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) {
 | 
			
		||||
					r(13, "::/0", true, true, false),
 | 
			
		||||
					r(13, "10.18.80.2/32", true, true, true),
 | 
			
		||||
				}
 | 
			
		||||
				if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
 | 
			
		||||
				want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.18.80.2/32")}
 | 
			
		||||
				if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" {
 | 
			
		||||
					t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
 | 
			
		||||
				}
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,9 @@ type PolicyManager interface {
 | 
			
		||||
	SetPolicy([]byte) (bool, error)
 | 
			
		||||
	SetUsers(users []types.User) (bool, error)
 | 
			
		||||
	SetNodes(nodes types.Nodes) (bool, error)
 | 
			
		||||
 | 
			
		||||
	// NodeCanApproveRoute reports whether the given node can approve the given route.
 | 
			
		||||
	NodeCanApproveRoute(*types.Node, netip.Prefix) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
 | 
			
		||||
@ -185,3 +188,31 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return ips, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
 | 
			
		||||
	if pm.pol == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pm.mu.Lock()
 | 
			
		||||
	defer pm.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
 | 
			
		||||
 | 
			
		||||
	for _, approvedAlias := range approvers {
 | 
			
		||||
		if approvedAlias == node.User.Username() {
 | 
			
		||||
			return true
 | 
			
		||||
		} else {
 | 
			
		||||
			ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
 | 
			
		||||
			if ips.Contains(*node.IPv4) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,12 +7,12 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/db"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/mapper"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/types"
 | 
			
		||||
	"github.com/juanfont/headscale/hscontrol/util"
 | 
			
		||||
	"github.com/rs/zerolog/log"
 | 
			
		||||
	"github.com/sasha-s/go-deadlock"
 | 
			
		||||
	xslices "golang.org/x/exp/slices"
 | 
			
		||||
@ -205,7 +205,15 @@ func (m *mapSession) serveLongPoll() {
 | 
			
		||||
		if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
 | 
			
		||||
			// Failover the node's routes if any.
 | 
			
		||||
			m.h.updateNodeOnlineStatus(false, m.node)
 | 
			
		||||
			m.pollFailoverRoutes("node closing connection", m.node)
 | 
			
		||||
 | 
			
		||||
			// When a node disconnects, and it causes the primary route map to change,
 | 
			
		||||
			// send a full update to all nodes.
 | 
			
		||||
			// TODO(kradalby): This can likely be made more effective, but likely most
 | 
			
		||||
			// nodes has access to the same routes, so it might not be a big deal.
 | 
			
		||||
			if m.h.primaryRoutes.DeregisterRoutes(m.node.ID, m.node.SubnetRoutes()...) {
 | 
			
		||||
				ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
 | 
			
		||||
				m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		m.afterServeLongPoll()
 | 
			
		||||
@ -216,7 +224,7 @@ func (m *mapSession) serveLongPoll() {
 | 
			
		||||
	m.h.pollNetMapStreamWG.Add(1)
 | 
			
		||||
	defer m.h.pollNetMapStreamWG.Done()
 | 
			
		||||
 | 
			
		||||
	m.pollFailoverRoutes("node connected", m.node)
 | 
			
		||||
	m.h.primaryRoutes.RegisterRoutes(m.node.ID, m.node.SubnetRoutes()...)
 | 
			
		||||
 | 
			
		||||
	// Upgrade the writer to a ResponseController
 | 
			
		||||
	rc := http.NewResponseController(m.w)
 | 
			
		||||
@ -383,22 +391,6 @@ func (m *mapSession) serveLongPoll() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
 | 
			
		||||
	update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | 
			
		||||
		return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if update != nil && !update.Empty() {
 | 
			
		||||
		ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
 | 
			
		||||
		m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
 | 
			
		||||
// about change in their online/offline status.
 | 
			
		||||
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
 | 
			
		||||
@ -471,36 +463,36 @@ func (m *mapSession) handleEndpointUpdate() {
 | 
			
		||||
	// If the hostinfo has changed, but not the routes, just update
 | 
			
		||||
	// hostinfo and let the function continue.
 | 
			
		||||
	if routesChanged {
 | 
			
		||||
		var err error
 | 
			
		||||
		_, err = m.h.db.SaveNodeRoutes(m.node)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			m.errf(err, "Error processing node routes")
 | 
			
		||||
			http.Error(m.w, "", http.StatusInternalServerError)
 | 
			
		||||
			mapResponseEndpointUpdates.WithLabelValues("error").Inc()
 | 
			
		||||
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// TODO(kradalby): Only update the node that has actually changed
 | 
			
		||||
		// TODO(kradalby): I am not sure if we need this?
 | 
			
		||||
		nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
 | 
			
		||||
 | 
			
		||||
		if m.h.polMan != nil {
 | 
			
		||||
			// update routes with peer information
 | 
			
		||||
			err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				m.errf(err, "Error running auto approved routes")
 | 
			
		||||
				mapResponseEndpointUpdates.WithLabelValues("error").Inc()
 | 
			
		||||
		// Take all the routes presented to us by the node and check
 | 
			
		||||
		// if any of them should be auto approved by the policy.
 | 
			
		||||
		// If any of them are, add them to the approved routes of the node.
 | 
			
		||||
		// Keep all the old entries and compact the list to remove duplicates.
 | 
			
		||||
		var newApproved []netip.Prefix
 | 
			
		||||
		for _, route := range m.node.Hostinfo.RoutableIPs {
 | 
			
		||||
			if m.h.polMan.NodeCanApproveRoute(m.node, route) {
 | 
			
		||||
				newApproved = append(newApproved, route)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if newApproved != nil {
 | 
			
		||||
			newApproved = append(newApproved, m.node.ApprovedRoutes...)
 | 
			
		||||
			slices.SortFunc(newApproved, util.ComparePrefix)
 | 
			
		||||
			slices.Compact(newApproved)
 | 
			
		||||
			m.node.ApprovedRoutes = newApproved
 | 
			
		||||
 | 
			
		||||
			// TODO(kradalby): I am not sure if we need this?
 | 
			
		||||
			// Send an update to the node itself with to ensure it
 | 
			
		||||
			// has an updated packetfilter allowing the new route
 | 
			
		||||
			// if it is defined in the ACL.
 | 
			
		||||
			ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
 | 
			
		||||
			m.h.nodeNotifier.NotifyByNodeID(
 | 
			
		||||
				ctx,
 | 
			
		||||
				types.UpdateSelf(m.node.ID),
 | 
			
		||||
				m.node.ID)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Send an update to the node itself with to ensure it
 | 
			
		||||
		// has an updated packetfilter allowing the new route
 | 
			
		||||
		// if it is defined in the ACL.
 | 
			
		||||
		ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
 | 
			
		||||
		m.h.nodeNotifier.NotifyByNodeID(
 | 
			
		||||
			ctx,
 | 
			
		||||
			types.UpdateSelf(m.node.ID),
 | 
			
		||||
			m.node.ID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check if there has been a change to Hostname and update them
 | 
			
		||||
 | 
			
		||||
@ -99,7 +99,7 @@ type Node struct {
 | 
			
		||||
	// as a subnet router. They are not necessarily the routes that the node
 | 
			
		||||
	// announces at the moment.
 | 
			
		||||
	// See [Node.Hostinfo]
 | 
			
		||||
	ApprovedRoutes []netip.Prefix `gorm:"serializer:json"`
 | 
			
		||||
	ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"`
 | 
			
		||||
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,10 @@
 | 
			
		||||
package util
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"cmp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
 | 
			
		||||
@ -10,3 +12,20 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
 | 
			
		||||
 | 
			
		||||
	return d.DialContext(ctx, "unix", addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(kradalby): Remove when in stdlib;
 | 
			
		||||
// https://github.com/golang/go/issues/61642
 | 
			
		||||
// Compare returns an integer comparing two prefixes.
 | 
			
		||||
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
 | 
			
		||||
// Prefixes sort first by validity (invalid before valid), then
 | 
			
		||||
// address family (IPv4 before IPv6), then prefix length, then
 | 
			
		||||
// address.
 | 
			
		||||
func ComparePrefix(p, p2 netip.Prefix) int {
 | 
			
		||||
	if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
	if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
	return p.Addr().Compare(p2.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user