mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-05 13:49:57 +02:00
initial work on db migration
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
fe9557a729
commit
df85797954
@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -622,6 +623,62 @@ AND auth_key_id NOT IN (
|
|||||||
},
|
},
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
},
|
},
|
||||||
|
// Migrate all routes from the Route table to the new field ApprovedRoutes
|
||||||
|
// in the Node table. Then drop the Route table.
|
||||||
|
{
|
||||||
|
ID: "202502131714",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") {
|
||||||
|
err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("adding column types.Node: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Ensure the ApprovedRoutes exist.
|
||||||
|
// err := tx.AutoMigrate(&types.Node{})
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("automigrating types.Node: %w", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||||
|
|
||||||
|
var routes []types.Route
|
||||||
|
err = tx.Find(&routes).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetching routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Enabled {
|
||||||
|
nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for nodeID, routes := range nodeRoutes {
|
||||||
|
slices.SortFunc(routes, util.ComparePrefix)
|
||||||
|
slices.Compact(routes)
|
||||||
|
|
||||||
|
data, err := json.Marshal(routes)
|
||||||
|
|
||||||
|
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("saving approved routes to new column: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "202502171819",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
_ = tx.Migrator().DropColumn(&types.Node{}, "last_seen")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||||||
{
|
{
|
||||||
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
|
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
|
||||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return GetRoutes(rx)
|
n1, err := GetNodeByID(rx, 1)
|
||||||
|
n26, err := GetNodeByID(rx, 26)
|
||||||
|
n31, err := GetNodeByID(rx, 31)
|
||||||
|
n32, err := GetNodeByID(rx, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return types.Nodes{n1, n26, n31, n32}, nil
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, routes, 10)
|
// want := types.Routes{
|
||||||
want := types.Routes{
|
// r(1, "0.0.0.0/0", true, true, false),
|
||||||
r(1, "0.0.0.0/0", true, true, false),
|
// r(1, "::/0", true, true, false),
|
||||||
r(1, "::/0", true, true, false),
|
// r(1, "10.9.110.0/24", true, true, true),
|
||||||
r(1, "10.9.110.0/24", true, true, true),
|
// r(26, "172.100.100.0/24", true, true, true),
|
||||||
r(26, "172.100.100.0/24", true, true, true),
|
// r(26, "172.100.100.0/24", true, false, false),
|
||||||
r(26, "172.100.100.0/24", true, false, false),
|
// r(31, "0.0.0.0/0", true, true, false),
|
||||||
r(31, "0.0.0.0/0", true, true, false),
|
// r(31, "0.0.0.0/0", true, false, false),
|
||||||
r(31, "0.0.0.0/0", true, false, false),
|
// r(31, "::/0", true, true, false),
|
||||||
r(31, "::/0", true, true, false),
|
// r(31, "::/0", true, false, false),
|
||||||
r(31, "::/0", true, false, false),
|
// r(32, "192.168.0.24/32", true, true, true),
|
||||||
r(32, "192.168.0.24/32", true, true, true),
|
// }
|
||||||
|
want := [][]netip.Prefix{
|
||||||
|
{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.9.110.0/24")},
|
||||||
|
{ipp("172.100.100.0/24")},
|
||||||
|
{ipp("0.0.0.0/0"), ipp("::/0")},
|
||||||
|
{ipp("192.168.0.24/32")},
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
var got [][]netip.Prefix
|
||||||
|
for _, node := range nodes {
|
||||||
|
got = append(got, node.ApprovedRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" {
|
||||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||||||
{
|
{
|
||||||
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
|
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
|
||||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||||
return GetRoutes(rx)
|
return GetNodeByID(rx, 13)
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, routes, 4)
|
assert.Len(t, node.ApprovedRoutes, 3)
|
||||||
want := types.Routes{
|
_ = types.Routes{
|
||||||
// These routes exists, but have no nodes associated with them
|
// These routes exists, but have no nodes associated with them
|
||||||
// when the migration starts.
|
// when the migration starts.
|
||||||
// r(1, "0.0.0.0/0", true, true, false),
|
// r(1, "0.0.0.0/0", true, true, false),
|
||||||
@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) {
|
|||||||
r(13, "::/0", true, true, false),
|
r(13, "::/0", true, true, false),
|
||||||
r(13, "10.18.80.2/32", true, true, true),
|
r(13, "10.18.80.2/32", true, true, true),
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
|
want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.18.80.2/32")}
|
||||||
|
if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" {
|
||||||
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -23,6 +23,9 @@ type PolicyManager interface {
|
|||||||
SetPolicy([]byte) (bool, error)
|
SetPolicy([]byte) (bool, error)
|
||||||
SetUsers(users []types.User) (bool, error)
|
SetUsers(users []types.User) (bool, error)
|
||||||
SetNodes(nodes types.Nodes) (bool, error)
|
SetNodes(nodes types.Nodes) (bool, error)
|
||||||
|
|
||||||
|
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||||
|
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||||
@ -185,3 +188,31 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
|||||||
}
|
}
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||||
|
if pm.pol == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||||
|
|
||||||
|
for _, approvedAlias := range approvers {
|
||||||
|
if approvedAlias == node.User.Username() {
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||||
|
if ips.Contains(*node.IPv4) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -7,12 +7,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/sasha-s/go-deadlock"
|
"github.com/sasha-s/go-deadlock"
|
||||||
xslices "golang.org/x/exp/slices"
|
xslices "golang.org/x/exp/slices"
|
||||||
@ -205,7 +205,15 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
|
if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) {
|
||||||
// Failover the node's routes if any.
|
// Failover the node's routes if any.
|
||||||
m.h.updateNodeOnlineStatus(false, m.node)
|
m.h.updateNodeOnlineStatus(false, m.node)
|
||||||
m.pollFailoverRoutes("node closing connection", m.node)
|
|
||||||
|
// When a node disconnects, and it causes the primary route map to change,
|
||||||
|
// send a full update to all nodes.
|
||||||
|
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||||
|
// nodes has access to the same routes, so it might not be a big deal.
|
||||||
|
if m.h.primaryRoutes.DeregisterRoutes(m.node.ID, m.node.SubnetRoutes()...) {
|
||||||
|
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
|
||||||
|
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.afterServeLongPoll()
|
m.afterServeLongPoll()
|
||||||
@ -216,7 +224,7 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
m.h.pollNetMapStreamWG.Add(1)
|
m.h.pollNetMapStreamWG.Add(1)
|
||||||
defer m.h.pollNetMapStreamWG.Done()
|
defer m.h.pollNetMapStreamWG.Done()
|
||||||
|
|
||||||
m.pollFailoverRoutes("node connected", m.node)
|
m.h.primaryRoutes.RegisterRoutes(m.node.ID, m.node.SubnetRoutes()...)
|
||||||
|
|
||||||
// Upgrade the writer to a ResponseController
|
// Upgrade the writer to a ResponseController
|
||||||
rc := http.NewResponseController(m.w)
|
rc := http.NewResponseController(m.w)
|
||||||
@ -383,22 +391,6 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
|
|
||||||
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
|
||||||
return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if update != nil && !update.Empty() {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname)
|
|
||||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
|
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
|
||||||
// about change in their online/offline status.
|
// about change in their online/offline status.
|
||||||
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
|
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
|
||||||
@ -471,28 +463,26 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||||||
// If the hostinfo has changed, but not the routes, just update
|
// If the hostinfo has changed, but not the routes, just update
|
||||||
// hostinfo and let the function continue.
|
// hostinfo and let the function continue.
|
||||||
if routesChanged {
|
if routesChanged {
|
||||||
var err error
|
// TODO(kradalby): I am not sure if we need this?
|
||||||
_, err = m.h.db.SaveNodeRoutes(m.node)
|
|
||||||
if err != nil {
|
|
||||||
m.errf(err, "Error processing node routes")
|
|
||||||
http.Error(m.w, "", http.StatusInternalServerError)
|
|
||||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(kradalby): Only update the node that has actually changed
|
|
||||||
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
||||||
|
|
||||||
if m.h.polMan != nil {
|
// Take all the routes presented to us by the node and check
|
||||||
// update routes with peer information
|
// if any of them should be auto approved by the policy.
|
||||||
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
|
// If any of them are, add them to the approved routes of the node.
|
||||||
if err != nil {
|
// Keep all the old entries and compact the list to remove duplicates.
|
||||||
m.errf(err, "Error running auto approved routes")
|
var newApproved []netip.Prefix
|
||||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
for _, route := range m.node.Hostinfo.RoutableIPs {
|
||||||
|
if m.h.polMan.NodeCanApproveRoute(m.node, route) {
|
||||||
|
newApproved = append(newApproved, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if newApproved != nil {
|
||||||
|
newApproved = append(newApproved, m.node.ApprovedRoutes...)
|
||||||
|
slices.SortFunc(newApproved, util.ComparePrefix)
|
||||||
|
slices.Compact(newApproved)
|
||||||
|
m.node.ApprovedRoutes = newApproved
|
||||||
|
|
||||||
|
// TODO(kradalby): I am not sure if we need this?
|
||||||
// Send an update to the node itself with to ensure it
|
// Send an update to the node itself with to ensure it
|
||||||
// has an updated packetfilter allowing the new route
|
// has an updated packetfilter allowing the new route
|
||||||
// if it is defined in the ACL.
|
// if it is defined in the ACL.
|
||||||
@ -503,6 +493,8 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||||||
m.node.ID)
|
m.node.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Check if there has been a change to Hostname and update them
|
// Check if there has been a change to Hostname and update them
|
||||||
// in the database. Then send a Changed update
|
// in the database. Then send a Changed update
|
||||||
// (containing the whole node object) to peers to inform about
|
// (containing the whole node object) to peers to inform about
|
||||||
|
@ -99,7 +99,7 @@ type Node struct {
|
|||||||
// as a subnet router. They are not necessarily the routes that the node
|
// as a subnet router. They are not necessarily the routes that the node
|
||||||
// announces at the moment.
|
// announces at the moment.
|
||||||
// See [Node.Hostinfo]
|
// See [Node.Hostinfo]
|
||||||
ApprovedRoutes []netip.Prefix `gorm:"serializer:json"`
|
ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"`
|
||||||
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
@ -10,3 +12,20 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
|||||||
|
|
||||||
return d.DialContext(ctx, "unix", addr)
|
return d.DialContext(ctx, "unix", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Remove when in stdlib;
|
||||||
|
// https://github.com/golang/go/issues/61642
|
||||||
|
// Compare returns an integer comparing two prefixes.
|
||||||
|
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
|
||||||
|
// Prefixes sort first by validity (invalid before valid), then
|
||||||
|
// address family (IPv4 before IPv6), then prefix length, then
|
||||||
|
// address.
|
||||||
|
func ComparePrefix(p, p2 netip.Prefix) int {
|
||||||
|
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return p.Addr().Compare(p2.Addr())
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user