diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 7f4ecb32..a130f876 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -8,6 +8,7 @@ import ( "fmt" "net/netip" "path/filepath" + "slices" "strconv" "strings" "time" @@ -622,6 +623,62 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + // Migrate all routes from the Route table to the new field ApprovedRoutes + // in the Node table. Then drop the Route table. + { + ID: "202502131714", + Migrate: func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") { + err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes") + if err != nil { + return fmt.Errorf("adding column types.Node: %w", err) + } + } + // Ensure the ApprovedRoutes exist. + // err := tx.AutoMigrate(&types.Node{}) + // if err != nil { + // return fmt.Errorf("automigrating types.Node: %w", err) + // } + + nodeRoutes := map[uint64][]netip.Prefix{} + + var routes []types.Route + err = tx.Find(&routes).Error + if err != nil { + return fmt.Errorf("fetching routes: %w", err) + } + + for _, route := range routes { + if route.Enabled { + nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix) + } + } + + for nodeID, routes := range nodeRoutes { + slices.SortFunc(routes, util.ComparePrefix) + slices.Compact(routes) + + data, err := json.Marshal(routes) + + err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error + if err != nil { + return fmt.Errorf("saving approved routes to new column: %w", err) + } + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + { + ID: "202502171819", + Migrate: func(tx *gorm.DB) error { + _ = tx.Migrator().DropColumn(&types.Node{}, "last_seen") + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 1efaa282..35170432 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -48,25 +48,43 @@ func TestMigrationsSQLite(t *testing.T) { { dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite", wantFunc: func(t *testing.T, h *HSDatabase) { - routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { - return GetRoutes(rx) + nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) { + n1, err := GetNodeByID(rx, 1) + n26, err := GetNodeByID(rx, 26) + n31, err := GetNodeByID(rx, 31) + n32, err := GetNodeByID(rx, 32) + if err != nil { + return nil, err + } + + return types.Nodes{n1, n26, n31, n32}, nil }) require.NoError(t, err) - assert.Len(t, routes, 10) - want := types.Routes{ - r(1, "0.0.0.0/0", true, true, false), - r(1, "::/0", true, true, false), - r(1, "10.9.110.0/24", true, true, true), - r(26, "172.100.100.0/24", true, true, true), - r(26, "172.100.100.0/24", true, false, false), - r(31, "0.0.0.0/0", true, true, false), - r(31, "0.0.0.0/0", true, false, false), - r(31, "::/0", true, true, false), - r(31, "::/0", true, false, false), - r(32, "192.168.0.24/32", true, true, true), + // want := types.Routes{ + // r(1, "0.0.0.0/0", true, true, false), + // r(1, "::/0", true, true, false), + // r(1, "10.9.110.0/24", true, true, true), + // r(26, "172.100.100.0/24", true, true, true), + // r(26, "172.100.100.0/24", true, false, false), + // r(31, "0.0.0.0/0", true, true, false), + // r(31, "0.0.0.0/0", true, false, false), + // r(31, "::/0", true, true, false), + // r(31, "::/0", true, false, false), + // r(32, "192.168.0.24/32", true, true, true), + // } + want := [][]netip.Prefix{ + {ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.9.110.0/24")}, + {ipp("172.100.100.0/24")}, + {ipp("0.0.0.0/0"), ipp("::/0")}, + {ipp("192.168.0.24/32")}, } - if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" { + var got [][]netip.Prefix + for _, node := range nodes { + got = append(got, node.ApprovedRoutes) + } + + if diff := cmp.Diff(want, got, util.PrefixComparer); diff != "" { t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) } }, @@ -74,13 +92,13 @@ func TestMigrationsSQLite(t *testing.T) { { dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite", wantFunc: func(t *testing.T, h *HSDatabase) { - routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { - return GetRoutes(rx) + node, err := Read(h.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByID(rx, 13) }) require.NoError(t, err) - assert.Len(t, routes, 4) - want := types.Routes{ + assert.Len(t, node.ApprovedRoutes, 3) + _ = types.Routes{ // These routes exists, but have no nodes associated with them // when the migration starts. // r(1, "0.0.0.0/0", true, true, false), @@ -111,7 +129,8 @@ func TestMigrationsSQLite(t *testing.T) { r(13, "::/0", true, true, false), r(13, "10.18.80.2/32", true, true, true), } - if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" { + want := []netip.Prefix{ipp("0.0.0.0/0"), ipp("::/0"), ipp("10.18.80.2/32")} + if diff := cmp.Diff(want, node.ApprovedRoutes, util.PrefixComparer); diff != "" { t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) } }, diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 4e10003e..457fb62c 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -23,6 +23,9 @@ type PolicyManager interface { SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes types.Nodes) (bool, error) + + // NodeCanApproveRoute reports whether the given node can approve the given route. + NodeCanApproveRoute(*types.Node, netip.Prefix) bool } func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { @@ -185,3 +188,31 @@ func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) { } return ips, nil } + +func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { + if pm.pol == nil { + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + + for _, approvedAlias := range approvers { + if approvedAlias == node.User.Username() { + return true + } else { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias) + if err != nil { + return false + } + + // approvedIPs should contain all of node's IPs if it matches the rule, so check for first + if ips.Contains(*node.IPv4) { + return true + } + } + } + return false +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 2df35c36..3802401f 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -7,12 +7,12 @@ import ( "net/http" "net/netip" "slices" - "strings" "time" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" xslices "golang.org/x/exp/slices" @@ -205,7 +205,15 @@ func (m *mapSession) serveLongPoll() { if m.h.nodeNotifier.RemoveNode(m.node.ID, m.ch) { // Failover the node's routes if any. m.h.updateNodeOnlineStatus(false, m.node) - m.pollFailoverRoutes("node closing connection", m.node) + + // When a node disconnects, and it causes the primary route map to change, + // send a full update to all nodes. + // TODO(kradalby): This can likely be made more effective, but likely most + // nodes has access to the same routes, so it might not be a big deal. + if m.h.primaryRoutes.DeregisterRoutes(m.node.ID, m.node.SubnetRoutes()...) { + ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname) + m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull()) + } } m.afterServeLongPoll() @@ -216,7 +224,7 @@ func (m *mapSession) serveLongPoll() { m.h.pollNetMapStreamWG.Add(1) defer m.h.pollNetMapStreamWG.Done() - m.pollFailoverRoutes("node connected", m.node) + m.h.primaryRoutes.RegisterRoutes(m.node.ID, m.node.SubnetRoutes()...) // Upgrade the writer to a ResponseController rc := http.NewResponseController(m.w) @@ -383,22 +391,6 @@ func (m *mapSession) serveLongPoll() { } } -func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { - update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return db.FailoverNodeRoutesIfNecessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node) - }) - if err != nil { - m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) - - return - } - - if update != nil && !update.Empty() { - ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID) - } -} - // updateNodeOnlineStatus records the last seen status of a node and notifies peers // about change in their online/offline status. // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. @@ -471,36 +463,36 @@ func (m *mapSession) handleEndpointUpdate() { // If the hostinfo has changed, but not the routes, just update // hostinfo and let the function continue. if routesChanged { - var err error - _, err = m.h.db.SaveNodeRoutes(m.node) - if err != nil { - m.errf(err, "Error processing node routes") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - - return - } - - // TODO(kradalby): Only update the node that has actually changed + // TODO(kradalby): I am not sure if we need this? nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier) - if m.h.polMan != nil { - // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) - if err != nil { - m.errf(err, "Error running auto approved routes") - mapResponseEndpointUpdates.WithLabelValues("error").Inc() + // Take all the routes presented to us by the node and check + // if any of them should be auto approved by the policy. + // If any of them are, add them to the approved routes of the node. + // Keep all the old entries and compact the list to remove duplicates. + var newApproved []netip.Prefix + for _, route := range m.node.Hostinfo.RoutableIPs { + if m.h.polMan.NodeCanApproveRoute(m.node, route) { + newApproved = append(newApproved, route) } } + if newApproved != nil { + newApproved = append(newApproved, m.node.ApprovedRoutes...) + slices.SortFunc(newApproved, util.ComparePrefix) + slices.Compact(newApproved) + m.node.ApprovedRoutes = newApproved + + // TODO(kradalby): I am not sure if we need this? + // Send an update to the node itself with to ensure it + // has an updated packetfilter allowing the new route + // if it is defined in the ACL. + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) + m.h.nodeNotifier.NotifyByNodeID( + ctx, + types.UpdateSelf(m.node.ID), + m.node.ID) + } - // Send an update to the node itself with to ensure it - // has an updated packetfilter allowing the new route - // if it is defined in the ACL. - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) - m.h.nodeNotifier.NotifyByNodeID( - ctx, - types.UpdateSelf(m.node.ID), - m.node.ID) } // Check if there has been a change to Hostname and update them diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 527a229b..c50b2995 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -99,7 +99,7 @@ type Node struct { // as a subnet router. They are not necessarily the routes that the node // announces at the moment. // See [Node.Hostinfo] - ApprovedRoutes []netip.Prefix `gorm:"serializer:json"` + ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"` CreatedAt time.Time UpdatedAt time.Time diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go index b704c936..b924caa0 100644 --- a/hscontrol/util/net.go +++ b/hscontrol/util/net.go @@ -1,8 +1,10 @@ package util import ( + "cmp" "context" "net" + "net/netip" ) func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { @@ -10,3 +12,20 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { return d.DialContext(ctx, "unix", addr) } + +// TODO(kradalby): Remove when in stdlib; +// https://github.com/golang/go/issues/61642 +// Compare returns an integer comparing two prefixes. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// Prefixes sort first by validity (invalid before valid), then +// address family (IPv4 before IPv6), then prefix length, then +// address. +func ComparePrefix(p, p2 netip.Prefix) int { + if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { + return c + } + if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { + return c + } + return p.Addr().Compare(p2.Addr()) +}