mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-07 20:04:00 +01:00
node: implement disable key expiry via CLI and API
Wire up the disable_expiry proto field through gRPC, state, database and CLI layers to allow clearing a node expiry. - Update NodeSetExpiry to accept *time.Time (nil disables expiry) - Fix SetNodeExpiry to persist via db.NodeSetExpiry directly instead of persistNodeToDB which omits the expiry field - Add mutual exclusion validation for expiry vs disable_expiry in gRPC - Add --disable flag to headscale nodes expire CLI command - Fix missing return after ErrorOutput in both expire and disable paths - Add unit test TestDisableNodeExpiry in db/node_test.go - Add integration test TestDisableNodeExpiry in general_test.go Fixes #2681 Co-authored-by: Marco Santos <me@marcopsantos.com>
This commit is contained in:
parent
544a2bd7cd
commit
82f2faaa32
@ -55,6 +55,7 @@ func init() {
|
||||
|
||||
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
|
||||
expireNodeCmd.Flags().BoolP("disable", "d", false, "Disable key expiry (node will never expire)")
|
||||
|
||||
err = expireNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
@ -260,9 +261,11 @@ var listNodeRoutesCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
var expireNodeCmd = &cobra.Command{
|
||||
Use: "expire",
|
||||
Short: "Expire (log out) a node in your network",
|
||||
Long: "Expiring a node will keep the node in the database and force it to reauthenticate.",
|
||||
Use: "expire",
|
||||
Short: "Expire (log out) a node in your network",
|
||||
Long: `Expiring a node will keep the node in the database and force it to reauthenticate.
|
||||
|
||||
Use --disable to disable key expiry (node will never expire).`,
|
||||
Aliases: []string{"logout", "exp", "e"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
@ -274,6 +277,49 @@ var expireNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
disableExpiry, err := cmd.Flags().GetBool("disable")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting disable flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
// Handle disable expiry - node will never expire.
|
||||
if disableExpiry {
|
||||
request := &v1.ExpireNodeRequest{
|
||||
NodeId: identifier,
|
||||
DisableExpiry: true,
|
||||
}
|
||||
|
||||
response, err := client.ExpireNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot disable node expiry: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node expiry disabled", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiry, err := cmd.Flags().GetString("expiry")
|
||||
@ -303,10 +349,6 @@ var expireNodeCmd = &cobra.Command{
|
||||
}
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
request := &v1.ExpireNodeRequest{
|
||||
NodeId: identifier,
|
||||
Expiry: timestamppb.New(expiryTime),
|
||||
@ -322,6 +364,8 @@ var expireNodeCmd = &cobra.Command{
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if now.Equal(expiryTime) || now.After(expiryTime) {
|
||||
|
||||
@ -212,7 +212,9 @@ func (h *Headscale) handleLogout(
|
||||
|
||||
// Update the internal state with the nodes new expiry, meaning it is
|
||||
// logged out.
|
||||
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry)
|
||||
expiry := req.Expiry
|
||||
|
||||
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), &expiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
@ -587,7 +587,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
|
||||
// Expire the node
|
||||
expiredTime := time.Now().Add(-1 * time.Hour)
|
||||
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
|
||||
_, _, err = app.state.SetNodeExpiry(node.ID(), &expiredTime)
|
||||
|
||||
return "", err
|
||||
},
|
||||
|
||||
@ -315,16 +315,15 @@ func RenameNode(tx *gorm.DB,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
|
||||
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry *time.Time) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
}
|
||||
|
||||
// NodeSetExpiry takes a Node struct and a new expiry time.
|
||||
func NodeSetExpiry(tx *gorm.DB,
|
||||
nodeID types.NodeID, expiry time.Time,
|
||||
) error {
|
||||
// NodeSetExpiry sets a new expiry time for a node.
|
||||
// If expiry is nil, the node's expiry is disabled (node will never expire).
|
||||
func NodeSetExpiry(tx *gorm.DB, nodeID types.NodeID, expiry *time.Time) error {
|
||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
||||
}
|
||||
|
||||
|
||||
@ -128,7 +128,7 @@ func TestExpireNode(t *testing.T) {
|
||||
assert.False(t, nodeFromDB.IsExpired())
|
||||
|
||||
now := time.Now()
|
||||
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
||||
err = db.NodeSetExpiry(nodeFromDB.ID, &now)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
@ -137,6 +137,48 @@ func TestExpireNode(t *testing.T) {
|
||||
assert.True(t, nodeFromDB.IsExpired())
|
||||
}
|
||||
|
||||
func TestDisableNodeExpiry(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pakID := pak.ID
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: &pakID,
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.DB.Save(node)
|
||||
|
||||
// Set an expiry first.
|
||||
past := time.Now().Add(-time.Hour)
|
||||
err = db.NodeSetExpiry(node.ID, &past)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nodeFromDB.IsExpired(), "node should be expired")
|
||||
|
||||
// Disable expiry by setting nil.
|
||||
err = db.NodeSetExpiry(node.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nodeFromDB.IsExpired(), "node should not be expired after disabling expiry")
|
||||
assert.Nil(t, nodeFromDB.Expiry, "expiry should be nil after disabling")
|
||||
}
|
||||
|
||||
func TestSetTags(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -451,12 +451,40 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
ctx context.Context,
|
||||
request *v1.ExpireNodeRequest,
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
if request.GetDisableExpiry() && request.GetExpiry() != nil {
|
||||
return nil, status.Error(
|
||||
codes.InvalidArgument,
|
||||
"cannot set both disable_expiry and expiry",
|
||||
)
|
||||
}
|
||||
|
||||
// Handle disable expiry request - node will never expire.
|
||||
if request.GetDisableExpiry() {
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(
|
||||
types.NodeID(request.GetNodeId()), nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
EmbedObject(node).
|
||||
Msg("node expiry disabled")
|
||||
|
||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
|
||||
expiry := time.Now()
|
||||
if request.GetExpiry() != nil {
|
||||
expiry = request.GetExpiry().AsTime()
|
||||
}
|
||||
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry)
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(
|
||||
types.NodeID(request.GetNodeId()), &expiry,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -467,7 +495,7 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
log.Trace().
|
||||
Caller().
|
||||
EmbedObject(node).
|
||||
Time(zf.ExpiresAt, *node.AsStruct().Expiry).
|
||||
Time(zf.ExpiresAt, expiry).
|
||||
Msg("node expired")
|
||||
|
||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||
|
||||
@ -638,22 +638,38 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] {
|
||||
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) {
|
||||
// If expiry is nil, the node's expiry is disabled (node will never expire).
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry *time.Time) (types.NodeView, change.Change, error) {
|
||||
// Update NodeStore before database to ensure consistency. The NodeStore update is
|
||||
// blocking and will be the source of truth for the batcher. The database update must
|
||||
// make the exact same change. If the database update fails, the NodeStore change will
|
||||
// remain, but since we return an error, no change notification will be sent to the
|
||||
// batcher, preventing inconsistent state propagation.
|
||||
expiryPtr := expiry
|
||||
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
|
||||
node.Expiry = &expiryPtr
|
||||
node.Expiry = expiry
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
|
||||
}
|
||||
|
||||
return s.persistNodeToDB(n)
|
||||
// Persist expiry change to database directly since persistNodeToDB omits expiry.
|
||||
err := s.db.NodeSetExpiry(nodeID, expiry)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("setting node expiry in database: %w", err)
|
||||
}
|
||||
|
||||
// Update policy manager and generate change notification.
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return n, change.Change{}, fmt.Errorf("updating policy manager after setting expiry: %w", err)
|
||||
}
|
||||
|
||||
if c.IsEmpty() {
|
||||
c = change.NodeAdded(n.ID())
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node, making it a "tagged node".
|
||||
|
||||
@ -1166,6 +1166,103 @@ func TestSetNodeExpiryInFuture(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDisableNodeExpiry tests disabling key expiry for a node.
|
||||
// First sets an expiry, then disables it and verifies the node never expires.
|
||||
func TestDisableNodeExpiry(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
Users: []string{"user1"},
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("disableexpiry"))
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
require.NoError(t, err)
|
||||
|
||||
// First set an expiry on the node.
|
||||
result, err := headscale.Execute(
|
||||
[]string{
|
||||
"headscale", "nodes", "expire",
|
||||
"--identifier", "1",
|
||||
"--output", "json",
|
||||
"--expiry", time.Now().Add(time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var node v1.Node
|
||||
err = json.Unmarshal([]byte(result), &node)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, node.GetExpiry(), "node should have an expiry set")
|
||||
|
||||
// Now disable the expiry.
|
||||
result, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale", "nodes", "expire",
|
||||
"--identifier", "1",
|
||||
"--output", "json",
|
||||
"--disable",
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var nodeDisabled v1.Node
|
||||
err = json.Unmarshal([]byte(result), &nodeDisabled)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expiry should be nil (or zero time) when disabled.
|
||||
if nodeDisabled.GetExpiry() != nil {
|
||||
require.True(t, nodeDisabled.GetExpiry().AsTime().IsZero(),
|
||||
"node expiry should be zero/nil after disabling")
|
||||
}
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
err = nodeKey.UnmarshalText([]byte(nodeDisabled.GetNodeKey()))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify peers see the node as not expired.
|
||||
for _, client := range allClients {
|
||||
if client.Hostname() == nodeDisabled.GetName() {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EventuallyWithT(
|
||||
t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(ct, err)
|
||||
|
||||
peerStatus, ok := status.Peer[nodeKey]
|
||||
assert.True(ct, ok, "node key should be present in peer list")
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Node should not be expired.
|
||||
assert.Falsef(
|
||||
ct,
|
||||
peerStatus.Expired,
|
||||
"node %q should not be marked as expired after disabling expiry",
|
||||
peerStatus.HostName,
|
||||
)
|
||||
}, 3*time.Minute, 5*time.Second, "waiting for disabled expiry to propagate",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeOnlineStatus(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user