1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

replace ephemeral deletion logic (#2008)

* replace ephemeral deletion logic

this commit replaces the way we remove ephemeral nodes,
currently they are deleted in a loop and we look at last seen
time. This time is now only set when a node disconnects and
there was a bug (#2006) where nodes that had never disconnected
was deleted since they did not have a last seen.

The new logic will start an expiry timer when the node disconnects
and delete the node from the database when the timer is up.

If the node reconnects within the expiry, the timer is cancelled.

Fixes #2006

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use uint64 as authekyid and ptr helper in tests

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add test db helper

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add list ephemeral node func

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* schedule ephemeral nodes for removal on startup

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix gorm query for postgres

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add godoc

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-07-18 10:01:59 +02:00 committed by GitHub
parent 58bd38a609
commit 7e62031444
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 417 additions and 206 deletions

View File

@ -42,6 +42,7 @@ jobs:
- TestPingAllByIPPublicDERP - TestPingAllByIPPublicDERP
- TestAuthKeyLogoutAndRelogin - TestAuthKeyLogoutAndRelogin
- TestEphemeral - TestEphemeral
- TestEphemeral2006DeletedTooQuickly
- TestPingAllByHostname - TestPingAllByHostname
- TestTaildrop - TestTaildrop
- TestResolveMagicDNS - TestResolveMagicDNS

View File

@ -91,6 +91,7 @@ type Headscale struct {
db *db.HSDatabase db *db.HSDatabase
ipAlloc *db.IPAllocator ipAlloc *db.IPAllocator
noisePrivateKey *key.MachinePrivate noisePrivateKey *key.MachinePrivate
ephemeralGC *db.EphemeralGarbageCollector
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer DERPServer *derpServer.DERPServer
@ -153,6 +154,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, err return nil, err
} }
app.ephemeralGC = db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
if err := app.db.DeleteEphemeralNode(ni); err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node")
}
})
if cfg.OIDC.Issuer != "" { if cfg.OIDC.Issuer != "" {
err = app.initOIDC() err = app.initOIDC()
if err != nil { if err != nil {
@ -217,47 +224,6 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, target, http.StatusFound) http.Redirect(w, req, target, http.StatusFound)
} }
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) {
ticker := time.NewTicker(every)
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
var removed []types.NodeID
var changed []types.NodeID
if err := h.db.Write(func(tx *gorm.DB) error {
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
if removed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: removed,
})
}
if changed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changed,
})
}
}
}
}
// expireExpiredNodes expires nodes that have an explicit expiry set // expireExpiredNodes expires nodes that have an explicit expiry set
// after that expiry time has passed. // after that expiry time has passed.
func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) {
@ -557,9 +523,18 @@ func (h *Headscale) Serve() error {
return errEmptyInitialDERPMap return errEmptyInitialDERPMap
} }
expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background()) // Start ephemeral node garbage collector and schedule all nodes
defer expireEphemeralCancel() // that are already in the database and ephemeral. If they are still
go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval) // around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes, err := h.db.ListEphemeralNodes()
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
for _, node := range ephmNodes {
h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
}
expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background())
defer expireNodeCancel() defer expireNodeCancel()
@ -809,7 +784,7 @@ func (h *Headscale) Serve() error {
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
expireNodeCancel() expireNodeCancel()
expireEphemeralCancel() h.ephemeralGC.Close()
trace("waiting for netmap stream to close") trace("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait() h.pollNetMapStreamWG.Wait()

View File

@ -16,6 +16,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr"
) )
func logAuthFunc( func logAuthFunc(
@ -314,9 +315,8 @@ func (h *Headscale) handleAuthKey(
Msg("node was already registered before, refreshing with new auth key") Msg("node was already registered before, refreshing with new auth key")
node.NodeKey = nodeKey node.NodeKey = nodeKey
pakID := uint(pak.ID) if pak.ID != 0 {
if pakID != 0 { node.AuthKeyID = ptr.To(pak.ID)
node.AuthKeyID = &pakID
} }
node.Expiry = &registerRequest.Expiry node.Expiry = &registerRequest.Expiry
@ -394,7 +394,7 @@ func (h *Headscale) handleAuthKey(
pakID := uint(pak.ID) pakID := uint(pak.ID)
if pakID != 0 { if pakID != 0 {
nodeToRegister.AuthKeyID = &pakID nodeToRegister.AuthKeyID = ptr.To(pak.ID)
} }
node, err = h.db.RegisterNode( node, err = h.db.RegisterNode(
nodeToRegister, nodeToRegister,

View File

@ -12,6 +12,7 @@ import (
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -78,6 +79,17 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) {
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
})
}
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx. if err := tx.
@ -286,6 +298,20 @@ func DeleteNode(tx *gorm.DB,
return changed, nil return changed, nil
} }
// DeleteEphemeralNode deletes a Node from the database, note that this method
// will remove it straight, and not notify any changes or consider any routes.
// It is intended for Ephemeral nodes.
func (hsdb *HSDatabase) DeleteEphemeralNode(
nodeID types.NodeID,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
return err
}
return nil
})
}
// SetLastSeen sets a node's last seen field indicating that we // SetLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node. // have recently communicating with this node.
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
@ -660,51 +686,6 @@ func GenerateGivenName(
return givenName, nil return givenName, nil
} }
func DeleteExpiredEphemeralNodes(tx *gorm.DB,
inactivityThreshold time.Duration,
) ([]types.NodeID, []types.NodeID) {
users, err := ListUsers(tx)
if err != nil {
return nil, nil
}
var expired []types.NodeID
var changedNodes []types.NodeID
for _, user := range users {
nodes, err := ListNodesByUser(tx, user.Name)
if err != nil {
return nil, nil
}
for idx, node := range nodes {
if node.IsEphemeral() && node.LastSeen != nil &&
time.Now().
After(node.LastSeen.Add(inactivityThreshold)) {
expired = append(expired, node.ID)
log.Info().
Str("node", node.Hostname).
Msg("Ephemeral client removed from database")
// empty isConnected map as ephemeral nodes are not routes
changed, err := DeleteNode(tx, nodes[idx], nil)
if err != nil {
log.Error().
Err(err).
Str("node", node.Hostname).
Msg("🤮 Cannot delete ephemeral node from the database")
}
changedNodes = append(changedNodes, changed...)
}
}
// TODO(kradalby): needs to be moved out of transaction
}
return expired, changedNodes
}
func ExpireExpiredNodes(tx *gorm.DB, func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time, lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) { ) (time.Time, types.StateUpdate, bool) {
@ -737,3 +718,78 @@ func ExpireExpiredNodes(tx *gorm.DB,
return started, types.StateUpdate{}, false return started, types.StateUpdate{}, false
} }
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
// a certain amount of time.
// It is used to delete ephemeral nodes that have disconnected and should be
// cleaned up.
type EphemeralGarbageCollector struct {
mu deadlock.Mutex
deleteFunc func(types.NodeID)
toBeDeleted map[types.NodeID]*time.Timer
deleteCh chan types.NodeID
cancelCh chan struct{}
}
// NewEphemeralGarbageCollector creates a new EphemeralGarbageCollector, it takes
// a deleteFunc that will be called when a node is scheduled for deletion.
func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector {
return &EphemeralGarbageCollector{
toBeDeleted: make(map[types.NodeID]*time.Timer),
deleteCh: make(chan types.NodeID, 10),
cancelCh: make(chan struct{}),
deleteFunc: deleteFunc,
}
}
// Close stops the garbage collector.
func (e *EphemeralGarbageCollector) Close() {
e.cancelCh <- struct{}{}
}
// Schedule schedules a node for deletion after the expiry duration.
func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) {
e.mu.Lock()
defer e.mu.Unlock()
timer := time.NewTimer(expiry)
e.toBeDeleted[nodeID] = timer
go func() {
select {
case _, ok := <-timer.C:
if ok {
e.deleteCh <- nodeID
}
}
}()
}
// Cancel cancels the deletion of a node.
func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) {
e.mu.Lock()
defer e.mu.Unlock()
if timer, ok := e.toBeDeleted[nodeID]; ok {
timer.Stop()
delete(e.toBeDeleted, nodeID)
}
}
// Start starts the garbage collector.
func (e *EphemeralGarbageCollector) Start() {
for {
select {
case <-e.cancelCh:
return
case nodeID := <-e.deleteCh:
e.mu.Lock()
delete(e.toBeDeleted, nodeID)
e.mu.Unlock()
go e.deleteFunc(nodeID)
}
}
}

View File

@ -1,17 +1,23 @@
package db package db
import ( import (
"crypto/rand"
"fmt" "fmt"
"math/big"
"net/netip" "net/netip"
"regexp" "regexp"
"strconv" "strconv"
"sync"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@ -30,7 +36,6 @@ func (s *Suite) TestGetNode(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
@ -39,7 +44,7 @@ func (s *Suite) TestGetNode(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(node) trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -61,7 +66,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -69,7 +73,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -93,7 +97,6 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -101,7 +104,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -145,7 +148,6 @@ func (s *Suite) TestListPeers(c *check.C) {
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
pakID := uint(pak.ID)
for index := 0; index <= 10; index++ { for index := 0; index <= 10; index++ {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
@ -157,7 +159,7 @@ func (s *Suite) TestListPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index), Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -197,7 +199,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for index := 0; index <= 10; index++ { for index := 0; index <= 10; index++ {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(stor[index%2].key.ID)
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{ node := types.Node{
@ -208,7 +209,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index), Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID, UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(stor[index%2].key.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -283,7 +284,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
@ -292,7 +292,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Expiry: &time.Time{}, Expiry: &time.Time{},
} }
db.DB.Save(node) db.DB.Save(node)
@ -328,7 +328,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
machineKey2 := key.NewMachine() machineKey2 := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -337,7 +336,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
GivenName: "hostname-1", GivenName: "hostname-1",
UserID: user1.ID, UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(node) trx := db.DB.Save(node)
@ -372,7 +371,6 @@ func (s *Suite) TestSetTags(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -380,7 +378,7 @@ func (s *Suite) TestSetTags(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(node) trx := db.DB.Save(node)
@ -566,7 +564,6 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
route2 := netip.MustParsePrefix("10.11.0.0/24") route2 := netip.MustParsePrefix("10.11.0.0/24")
v4 := netip.MustParseAddr("100.64.0.1") v4 := netip.MustParseAddr("100.64.0.1")
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -574,7 +571,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
Hostname: "test", Hostname: "test",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"}, RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
@ -600,3 +597,121 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(enabledRoutes, check.HasLen, 4) c.Assert(enabledRoutes, check.HasLen, 4)
} }
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
want := []types.NodeID{1, 3}
got := []types.NodeID{}
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
got = append(got, ni)
})
go e.Start()
e.Schedule(1, 1*time.Second)
e.Schedule(2, 2*time.Second)
e.Schedule(3, 3*time.Second)
e.Schedule(4, 4*time.Second)
e.Cancel(2)
e.Cancel(4)
time.Sleep(6 * time.Second)
e.Close()
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff)
}
}
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
var got []types.NodeID
var mu sync.Mutex
want := 1000
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
defer mu.Unlock()
mu.Lock()
time.Sleep(time.Duration(generateRandomNumber(t, 3)) * time.Millisecond)
got = append(got, ni)
})
go e.Start()
for i := 0; i < want; i++ {
go e.Schedule(types.NodeID(i), 1*time.Second)
}
time.Sleep(10 * time.Second)
e.Close()
if len(got) != want {
t.Errorf("expected %d, got %d", want, len(got))
}
}
func generateRandomNumber(t *testing.T, max int64) int64 {
t.Helper()
maxB := big.NewInt(max)
n, err := rand.Int(rand.Reader, maxB)
if err != nil {
t.Fatalf("getting random number: %s", err)
}
return n.Int64() + 1
}
func TestListEphemeralNodes(t *testing.T) {
db, err := newTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser("test")
assert.NoError(t, err)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
assert.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
assert.NoError(t, err)
node := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
nodeEph := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "ephemeral",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pakEph.ID),
}
err = db.DB.Save(&node).Error
assert.NoError(t, err)
err = db.DB.Save(&nodeEph).Error
assert.NoError(t, err)
nodes, err := db.ListNodes()
assert.NoError(t, err)
ephemeralNodes, err := db.ListEphemeralNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Len(t, ephemeralNodes, 1)
assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID)
assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID)
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/ptr"
) )
var ( var (
@ -197,10 +198,9 @@ func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
} }
nodes := types.Nodes{} nodes := types.Nodes{}
pakID := uint(pak.ID)
if err := tx. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Where(&types.Node{AuthKeyID: &pakID}). Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
Find(&nodes).Error; err != nil { Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }

View File

@ -6,7 +6,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "tailscale.com/types/ptr"
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
@ -76,13 +76,12 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -99,13 +98,12 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 1, ID: 1,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -127,77 +125,6 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
c.Assert(key.ID, check.Equals, pak.ID) c.Assert(key.ID, check.Equals, pak.ID)
} }
func (*Suite) TestEphemeralKeyReusable(c *check.C) {
user, err := db.CreateUser("test7")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, true, true, nil, nil)
c.Assert(err, check.IsNil)
now := time.Now().Add(-time.Second * 30)
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now,
AuthKeyID: &pakID,
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
_, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil)
db.Write(func(tx *gorm.DB) error {
DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil
})
// The machine record should have been deleted
_, err = db.getNode("test7", "testest")
c.Assert(err, check.NotNil)
}
func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
user, err := db.CreateUser("test7")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
c.Assert(err, check.IsNil)
now := time.Now().Add(-time.Second * 30)
pakId := uint(pak.ID)
node := types.Node{
ID: 0,
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now,
AuthKeyID: &pakId,
}
db.DB.Save(&node)
_, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.NotNil)
_, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil)
db.Write(func(tx *gorm.DB) error {
DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil
})
// The machine record should have been deleted
_, err = db.getNode("test7", "testest")
c.Assert(err, check.NotNil)
}
func (*Suite) TestExpirePreauthKey(c *check.C) { func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View File

@ -14,6 +14,7 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/ptr"
) )
var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] { var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] {
@ -43,13 +44,12 @@ func (s *Suite) TestGetRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route}, RoutableIPs: []netip.Prefix{route},
} }
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "test_get_route_node", Hostname: "test_get_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
@ -95,13 +95,12 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
@ -169,13 +168,12 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
hostInfo1 := tailcfg.Hostinfo{ hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
pakID := uint(pak.ID)
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
} }
trx := db.DB.Save(&node1) trx := db.DB.Save(&node1)
@ -199,7 +197,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
db.DB.Save(&node2) db.DB.Save(&node2)
@ -253,13 +251,12 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
} }
now := time.Now() now := time.Now()
pakID := uint(pak.ID)
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
LastSeen: &now, LastSeen: &now,
} }

View File

@ -36,10 +36,18 @@ func (s *Suite) ResetDB(c *check.C) {
// } // }
var err error var err error
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") db, err = newTestDB()
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
}
func newTestDB() (*HSDatabase, error) {
var err error
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
return nil, err
}
log.Printf("database path: %s", tmpDir+"/headscale_test.db") log.Printf("database path: %s", tmpDir+"/headscale_test.db")
@ -53,6 +61,8 @@ func (s *Suite) ResetDB(c *check.C) {
"", "",
) )
if err != nil { if err != nil {
c.Fatal(err) return nil, err
} }
return db, nil
} }

View File

@ -5,6 +5,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/ptr"
) )
func (s *Suite) TestCreateAndDestroyUser(c *check.C) { func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
@ -46,13 +47,12 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
@ -100,13 +100,12 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testnode", Hostname: "testnode",
UserID: oldUser.ID, UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID, AuthKeyID: ptr.To(pak.ID),
} }
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)

View File

@ -135,6 +135,18 @@ func (m *mapSession) resetKeepAlive() {
m.keepAliveTicker.Reset(m.keepAlive) m.keepAliveTicker.Reset(m.keepAlive)
} }
func (m *mapSession) beforeServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Cancel(m.node.ID)
}
}
func (m *mapSession) afterServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
}
}
// serve handles non-streaming requests. // serve handles non-streaming requests.
func (m *mapSession) serve() { func (m *mapSession) serve() {
// TODO(kradalby): A set todos to harden: // TODO(kradalby): A set todos to harden:
@ -180,6 +192,8 @@ func (m *mapSession) serve() {
// //
//nolint:gocyclo //nolint:gocyclo
func (m *mapSession) serveLongPoll() { func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()
// Clean up the session when the client disconnects // Clean up the session when the client disconnects
defer func() { defer func() {
m.cancelChMu.Lock() m.cancelChMu.Lock()
@ -197,6 +211,7 @@ func (m *mapSession) serveLongPoll() {
m.pollFailoverRoutes("node closing connection", m.node) m.pollFailoverRoutes("node closing connection", m.node)
} }
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}() }()

View File

@ -119,7 +119,7 @@ type Node struct {
ForcedTags StringList ForcedTags StringList
// TODO(kradalby): This seems like irrelevant information? // TODO(kradalby): This seems like irrelevant information?
AuthKeyID *uint `sql:"DEFAULT:NULL"` AuthKeyID *uint64 `sql:"DEFAULT:NULL"`
AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"` AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"`
LastSeen *time.Time LastSeen *time.Time

View File

@ -297,6 +297,122 @@ func TestEphemeral(t *testing.T) {
} }
} }
// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not
// deleted by accident if they are still online and active.
func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscale, err := scenario.Headscale(
hsic.WithTestName("ephemeral2006"),
hsic.WithConfigEnv(map[string]string{
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s",
}),
)
assertNoErrHeadscaleEnv(t, err)
for userName, clientCount := range spec {
err = scenario.CreateUser(userName)
if err != nil {
t.Fatalf("failed to create user %s: %s", userName, err)
}
err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...)
if err != nil {
t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err)
}
key, err := scenario.CreatePreAuthKey(userName, true, true)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
// All ephemeral nodes should be online and reachable.
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Take down all clients, this should start an expiry timer for each.
for _, client := range allClients {
err := client.Down()
if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
}
}
// Wait a bit and bring up the clients again before the expiry
// time of the ephemeral nodes.
// Nodes should be able to reconnect and work fine.
time.Sleep(30 * time.Second)
for _, client := range allClients {
err := client.Up()
if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success = pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Take down all clients, this should start an expiry timer for each.
for _, client := range allClients {
err := client.Down()
if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
}
}
// This time wait for all of the nodes to expire and check that they are no longer
// registered.
time.Sleep(3 * time.Minute)
for userName := range spec {
nodes, err := headscale.ListNodesInUser(userName)
if err != nil {
log.Error().
Err(err).
Str("user", userName).
Msg("Error listing nodes in user")
return
}
if len(nodes) != 0 {
t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName)
}
}
}
func TestPingAllByHostname(t *testing.T) { func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()