1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-01 13:46:49 +02:00

initial work on db migration

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-14 08:51:24 +01:00
parent fe9557a729
commit df85797954
No known key found for this signature in database
6 changed files with 183 additions and 65 deletions

View File

@ -8,6 +8,7 @@ import (
"fmt"
"net/netip"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
@ -622,6 +623,62 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
// Migrate all routes from the Route table to the new field ApprovedRoutes
// in the Node table. Then drop the Route table.
{
ID: "202502131714",
Migrate: func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") {
err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes")
if err != nil {
return fmt.Errorf("adding column types.Node: %w", err)
}
}
// Ensure the ApprovedRoutes exist.
// err := tx.AutoMigrate(&types.Node{})
// if err != nil {
// return fmt.Errorf("automigrating types.Node: %w", err)
// }
nodeRoutes := map[uint64][]netip.Prefix{}
var routes []types.Route
err = tx.Find(&routes).Error
if err != nil {
return fmt.Errorf("fetching routes: %w", err)
}
for _, route := range routes {
if route.Enabled {
nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix)
}
}
for nodeID, routes := range nodeRoutes {
slices.SortFunc(routes, util.ComparePrefix)
slices.Compact(routes)
data, err := json.Marshal(routes)
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
if err != nil {
return fmt.Errorf("saving approved routes to new column: %w", err)
}
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202502171819",
Migrate: func(tx *gorm.DB) error {
_ = tx.Migrator().DropColumn(&types.Node{}, "last_seen")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View File

@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) {
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
n1, err := GetNodeByID(rx, 1)
n26, err := GetNodeByID(rx, 26)
n31, err := GetNodeByID(rx, 31)
n32, err := GetNodeByID(rx, 32)
if err != nil {
return nil, err
}
return types.Nodes{n1, n26, n31, n32}, nil
})
require.NoError(t, err)
assert.Len(t, routes, 10)
want := types.Routes{
r(1, "0.0.0.0/0", true, true, false),
r(1, "::/0", true, true, false),
r(1, "10.9.110.0/24", true, true, true),
r(26, "172.100.100.0/24", true, true, true),
r(26, "172.100.100.0/24", true, false, false),
r(31, "0.0.0.0/0", true, true, false),
r(31, "0.0.0.0/0", true, false, false),
r(31, "::/0", true, true, false),
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
// want := types.Routes{
// r(1, "0.0.0.0/0", true, true, false),
// r(1, "::/0", true, true, false),
// r(1, "10.9.110.0/24", true, true, true),
// r(26, "172.100.100.0/24", true, true, true),
// r(26, "172.100.100.0/24", true, false, false),
// r(31, "0.0.0.0/0", true, true, false),
// r(31, "0.0.0.0/0", true, false, false),
// r(31, "::/0", true, true, false),
// r(31, "::/0", true, false, false),
// r(32, "192.168.0.24/32", true, true, true),
// }
want := [][]netip.Prefix{
{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.9.110.0/24")},
{ipp("172.100.100.0/24")},
{ipp("0.0.0.0/0"), ipp("::/0")},
{ipp("192.168.0.24/32")},
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
var got [][]netip.Prefix
for _, node := range nodes {
got = append(got, node.ApprovedRoutes)
}
if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) {
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, 13)
})
require.NoError(t, err)
assert.Len(t, routes, 4)
want := types.Routes{
assert.Len(t, node.ApprovedRoutes, 3)
_ = types.Routes{
// These routes exists, but have no nodes associated with them
// when the migration starts.
// r(1, "0.0.0.0/0", true, true, false),
@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) {
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.18.80.2/32")}
if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},

View File

@ -23,6 +23,9 @@ type PolicyManager interface {
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
@ -185,3 +188,31 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
}
return ips, nil
}
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
if pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
for _, approvedAlias := range approvers {
if approvedAlias == node.User.Username() {
return true
} else {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
if err != nil {
return false
}
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if ips.Contains(*node.IPv4) {
return true
}
}
}
return false
}

View File

@ -7,12 +7,12 @@ import (
"net/http"
"net/netip"
"slices"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
@ -205,7 +205,15 @@ func (m *mapSession) serveLongPoll() {
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
// Failover the node's routes if any.
m.h.updateNodeOnlineStatus(false, m.node)
m.pollFailoverRoutes("node closing connection", m.node)
// When a node disconnects, and it causes the primary route map to change,
// send a full update to all nodes.
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
if m.h.primaryRoutes.DeregisterRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
m.afterServeLongPoll()
@ -216,7 +224,7 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
m.pollFailoverRoutes("node connected", m.node)
m.h.primaryRoutes.RegisterRoutes(m.node.ID, m.node.SubnetRoutes()...)
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
@ -383,22 +391,6 @@ func (m *mapSession) serveLongPoll() {
}
}
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
})
if err != nil {
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
return
}
if update != nil && !update.Empty() {
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID)
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@ -471,36 +463,36 @@ func (m *mapSession) handleEndpointUpdate() {
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
var err error
_, err = m.h.db.SaveNodeRoutes(m.node)
if err != nil {
m.errf(err, "Error processing node routes")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
// TODO(kradalby): Only update the node that has actually changed
// TODO(kradalby): I am not sure if we need this?
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
if m.h.polMan != nil {
// update routes with peer information
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
if err != nil {
m.errf(err, "Error running auto approved routes")
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
// Take all the routes presented to us by the node and check
// if any of them should be auto approved by the policy.
// If any of them are, add them to the approved routes of the node.
// Keep all the old entries and compact the list to remove duplicates.
var newApproved []netip.Prefix
for _, route := range m.node.Hostinfo.RoutableIPs {
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, m.node.ApprovedRoutes...)
slices.SortFunc(newApproved, util.ComparePrefix)
slices.Compact(newApproved)
m.node.ApprovedRoutes = newApproved
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
// Check if there has been a change to Hostname and update them

View File

@ -99,7 +99,7 @@ type Node struct {
// as a subnet router. They are not necessarily the routes that the node
// announces at the moment.
// See [Node.Hostinfo]
ApprovedRoutes []netip.Prefix `gorm:"serializer:json"`
ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"`
CreatedAt time.Time
UpdatedAt time.Time

View File

@ -1,8 +1,10 @@
package util
import (
"cmp"
"context"
"net"
"net/netip"
)
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
@ -10,3 +12,20 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
return d.DialContext(ctx, "unix", addr)
}
// TODO(kradalby): Remove when in stdlib;
// https://github.com/golang/go/issues/61642
// Compare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// Prefixes sort first by validity (invalid before valid), then
// address family (IPv4 before IPv6), then prefix length, then
// address.
func ComparePrefix(p, p2 netip.Prefix) int {
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
return c
}
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
return c
}
return p.Addr().Compare(p2.Addr())
}