mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-14 13:51:01 +02:00
implement oauth2 refresh tokens with background refreshing
This commit is contained in:
parent
7d3e7a28e2
commit
fd8bd3f6a6
@ -355,13 +355,16 @@ unix_socket_permission: "0770"
|
|||||||
# # Note: enabling this will cause `oidc.expiry` to be ignored.
|
# # Note: enabling this will cause `oidc.expiry` to be ignored.
|
||||||
# use_expiry_from_token: false
|
# use_expiry_from_token: false
|
||||||
#
|
#
|
||||||
# # The OIDC scopes to use, defaults to "openid", "profile" and "email".
|
# # Grace period for invalidating sessions of nodes that have been offline
|
||||||
# # Custom scopes can be configured as needed, be sure to always include the
|
# # Sessions for nodes offline longer than this duration will be invalidated (new SSO login required).
|
||||||
# # required "openid" scope.
|
# # Default: 30m
|
||||||
# scope: ["openid", "profile", "email"]
|
# session_invalidation_grace_period: 30m
|
||||||
#
|
#
|
||||||
# # Provide custom key/value pairs which get sent to the identity provider's
|
# # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
||||||
# # authorization endpoint.
|
# # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
||||||
|
# # To enable automatic token refresh, add "offline_access" to the scope list.
|
||||||
|
#
|
||||||
|
# scope: ["openid", "profile", "email", "offline_access"]
|
||||||
# extra_params:
|
# extra_params:
|
||||||
# domain_hint: example.com
|
# domain_hint: example.com
|
||||||
#
|
#
|
||||||
|
@ -19,7 +19,13 @@ OpenID requires configuration in Headscale and your identity provider:
|
|||||||
Additionally, there might be some useful hints in the [Identity provider specific
|
Additionally, there might be some useful hints in the [Identity provider specific
|
||||||
configuration](#identity-provider-specific-configuration) section below.
|
configuration](#identity-provider-specific-configuration) section below.
|
||||||
|
|
||||||
### Basic configuration
|
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
|
||||||
|
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
|
||||||
|
# To enable automatic token refresh, add "offline_access" to the scope list.
|
||||||
|
scope: ["openid", "profile", "email", "custom"]
|
||||||
|
# Optional: Passed on to the browser login request – used to tweak behaviour for the OIDC provider
|
||||||
|
extra_params:
|
||||||
|
domain_hint: example.com
|
||||||
|
|
||||||
A basic configuration connects Headscale to an identity provider and typically requires:
|
A basic configuration connects Headscale to an identity provider and typically requires:
|
||||||
|
|
||||||
|
@ -302,6 +302,39 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthProviderOIDC) {
|
||||||
|
refreshTicker := time.NewTicker(15 * time.Minute)
|
||||||
|
gracePeriodTicker := time.NewTicker(15 * time.Minute)
|
||||||
|
defer refreshTicker.Stop()
|
||||||
|
defer gracePeriodTicker.Stop()
|
||||||
|
|
||||||
|
log.Info().Msg("OIDC: Background token refresh job started (checking every 15 minute for tokens expiring within 30 minutes)")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info().Caller().Msg("OIDC token refresh job is shutting down.")
|
||||||
|
return
|
||||||
|
|
||||||
|
// Refresh expired tokens every 15 minutes. Will be refreshed if their expiry is within the next 30 minutes.
|
||||||
|
case <-refreshTicker.C:
|
||||||
|
refreshCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
if err := oidcProvider.RefreshExpiredTokens(refreshCtx); err != nil {
|
||||||
|
log.Error().Err(err).Msg("OIDC: Failed to refresh expired tokens")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Invalidate sessions for nodes that have been offline for longer than the configured grace period
|
||||||
|
case <-gracePeriodTicker.C:
|
||||||
|
log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period")
|
||||||
|
gracePeriod := oidcProvider.cfg.SessionInvalidationGracePeriod
|
||||||
|
if err := h.db.InvalidateExpiredOIDCSessions(gracePeriod); err != nil {
|
||||||
|
log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
req interface{},
|
req interface{},
|
||||||
info *grpc.UnaryServerInfo,
|
info *grpc.UnaryServerInfo,
|
||||||
@ -617,6 +650,11 @@ func (h *Headscale) Serve() error {
|
|||||||
defer scheduleCancel()
|
defer scheduleCancel()
|
||||||
go h.scheduledTasks(scheduleCtx)
|
go h.scheduledTasks(scheduleCtx)
|
||||||
|
|
||||||
|
// Start OIDC token refresh background job if OIDC is enabled
|
||||||
|
if oidcProvider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||||
|
go h.oidcTokenRefreshJob(scheduleCtx, oidcProvider)
|
||||||
|
}
|
||||||
|
|
||||||
if zl.GlobalLevel() == zl.TraceLevel {
|
if zl.GlobalLevel() == zl.TraceLevel {
|
||||||
zerolog.RespLog = true
|
zerolog.RespLog = true
|
||||||
} else {
|
} else {
|
||||||
|
@ -927,7 +927,6 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msg("Schema recreation completed successfully")
|
log.Info().Msg("Schema recreation completed successfully")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
@ -936,6 +935,23 @@ AND auth_key_id NOT IN (
|
|||||||
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
|
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
|
||||||
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
|
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
|
||||||
// - Never write migrations that requires foreign keys to be disabled.
|
// - Never write migrations that requires foreign keys to be disabled.
|
||||||
|
{
|
||||||
|
ID: "202507140001",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
// Create OIDC sessions table for managing OIDC refresh tokens
|
||||||
|
// This replaces the old OIDC token columns in the users table
|
||||||
|
if !tx.Migrator().HasTable(&types.OIDCSession{}) {
|
||||||
|
err := tx.AutoMigrate(&types.OIDCSession{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating OIDC sessions table: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug().Msg("Created OIDC sessions table")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -317,6 +317,15 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
|
|||||||
func DeleteNode(tx *gorm.DB,
|
func DeleteNode(tx *gorm.DB,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) error {
|
) error {
|
||||||
|
// Invalidate OIDC sessions for this node before deletion
|
||||||
|
if err := InvalidateOIDCSessionsForNode(tx, node.ID); err != nil {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Uint64("node_id", uint64(node.ID)).
|
||||||
|
Str("node", node.Hostname).
|
||||||
|
Msg("Failed to invalidate OIDC sessions for deleted node")
|
||||||
|
// Continue with deletion even if session invalidation fails
|
||||||
|
}
|
||||||
|
|
||||||
// Unscoped causes the node to be fully removed from the database.
|
// Unscoped causes the node to be fully removed from the database.
|
||||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
@ -332,6 +341,14 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
|||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
) error {
|
) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
// Invalidate OIDC sessions for this node before deletion
|
||||||
|
if err := InvalidateOIDCSessionsForNode(tx, nodeID); err != nil {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Uint64("node_id", uint64(nodeID)).
|
||||||
|
Msg("Failed to invalidate OIDC sessions for deleted ephemeral node")
|
||||||
|
// Continue with deletion even if session invalidation fails
|
||||||
|
}
|
||||||
|
|
||||||
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
103
hscontrol/db/oidc_session.go
Normal file
103
hscontrol/db/oidc_session.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InvalidateOIDCSessionsForNode invalidates all active OIDC sessions for a specific node
|
||||||
|
func (hsdb *HSDatabase) InvalidateOIDCSessionsForNode(nodeID types.NodeID) error {
|
||||||
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
return InvalidateOIDCSessionsForNode(tx, nodeID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateOIDCSessionsForNode invalidates all active OIDC sessions for a specific node
|
||||||
|
func InvalidateOIDCSessionsForNode(tx *gorm.DB, nodeID types.NodeID) error {
|
||||||
|
log.Debug().
|
||||||
|
Uint64("node_id", uint64(nodeID)).
|
||||||
|
Msg("OIDC: Invalidating sessions for node")
|
||||||
|
|
||||||
|
result := tx.Model(&types.OIDCSession{}).
|
||||||
|
Where("node_id = ? AND is_active = ?", nodeID, true).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"is_active": false,
|
||||||
|
"last_seen_at": time.Now(),
|
||||||
|
"refresh_token": nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).
|
||||||
|
Uint64("node_id", uint64(nodeID)).
|
||||||
|
Msg("OIDC: Failed to invalidate sessions for node")
|
||||||
|
return fmt.Errorf("failed to invalidate OIDC sessions for node %d: %w", nodeID, result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected > 0 {
|
||||||
|
log.Info().
|
||||||
|
Uint64("node_id", uint64(nodeID)).
|
||||||
|
Int64("sessions_invalidated", result.RowsAffected).
|
||||||
|
Msg("OIDC: Invalidated sessions for disconnected node")
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Uint64("node_id", uint64(nodeID)).
|
||||||
|
Msg("OIDC: No active sessions found for node")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline too long
|
||||||
|
func (hsdb *HSDatabase) InvalidateExpiredOIDCSessions(offlineGracePeriod time.Duration) error {
|
||||||
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
|
return InvalidateExpiredOIDCSessions(tx, offlineGracePeriod)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline too long
|
||||||
|
func InvalidateExpiredOIDCSessions(tx *gorm.DB, offlineGracePeriod time.Duration) error {
|
||||||
|
// Find active sessions where the node has been offline for longer than the grace period
|
||||||
|
cutoff := time.Now().Add(-offlineGracePeriod)
|
||||||
|
|
||||||
|
var sessions []types.OIDCSession
|
||||||
|
err := tx.Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
|
||||||
|
Where("oidc_sessions.is_active = ? AND nodes.last_seen IS NOT NULL AND nodes.last_seen < ?", true, cutoff).
|
||||||
|
Find(&sessions).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to find expired OIDC sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sessions) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate these sessions
|
||||||
|
sessionIDs := make([]string, len(sessions))
|
||||||
|
for i, session := range sessions {
|
||||||
|
sessionIDs[i] = session.SessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Model(&types.OIDCSession{}).
|
||||||
|
Where("session_id IN ?", sessionIDs).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"is_active": false,
|
||||||
|
"last_seen_at": time.Now(),
|
||||||
|
"refresh_token": nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to invalidate expired OIDC sessions: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Int("sessions_invalidated", len(sessions)).
|
||||||
|
Dur("grace_period", offlineGracePeriod).
|
||||||
|
Msg("OIDC: Invalidated sessions for nodes offline beyond grace period")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
311
hscontrol/db/oidc_session_test.go
Normal file
311
hscontrol/db/oidc_session_test.go
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*Suite) TestInvalidateOIDCSessionsForNode(c *check.C) {
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test-oidc-user"})
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
|
nodeExpiry := time.Now().Add(24 * time.Hour)
|
||||||
|
node := &types.Node{
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: "test-node",
|
||||||
|
GivenName: "test-node",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
db.DB.Save(node)
|
||||||
|
|
||||||
|
// Create an active OIDC session
|
||||||
|
sessionID := "test-session-id-1"
|
||||||
|
registrationID := types.RegistrationID("test-reg-id-1")
|
||||||
|
tokenExpiry := time.Now().Add(1 * time.Hour)
|
||||||
|
session := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: sessionID,
|
||||||
|
RegistrationID: registrationID,
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
TokenExpiry: &tokenExpiry,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db.DB.Save(session)
|
||||||
|
|
||||||
|
// Verify session is active
|
||||||
|
var checkSession types.OIDCSession
|
||||||
|
err = db.DB.Where("node_id = ?", node.ID).First(&checkSession).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession.IsActive, check.Equals, true)
|
||||||
|
|
||||||
|
// Invalidate the session
|
||||||
|
err = db.InvalidateOIDCSessionsForNode(types.NodeID(node.ID))
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Verify session is now inactive
|
||||||
|
err = db.DB.Where("node_id = ?", node.ID).First(&checkSession).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession.IsActive, check.Equals, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestInvalidateExpiredOIDCSessions(c *check.C) {
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test-oidc-expire-user"})
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Create nodes for testing
|
||||||
|
nodeKey1 := key.NewNode()
|
||||||
|
discoKey1 := key.NewDisco()
|
||||||
|
machineKey1 := key.NewMachine()
|
||||||
|
nodeExpiry1 := time.Now().Add(24 * time.Hour)
|
||||||
|
node1 := &types.Node{
|
||||||
|
MachineKey: machineKey1.Public(),
|
||||||
|
NodeKey: nodeKey1.Public(),
|
||||||
|
DiscoKey: discoKey1.Public(),
|
||||||
|
Hostname: "test-node-1",
|
||||||
|
GivenName: "test-node-1",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry1,
|
||||||
|
}
|
||||||
|
db.DB.Save(node1)
|
||||||
|
|
||||||
|
nodeKey2 := key.NewNode()
|
||||||
|
discoKey2 := key.NewDisco()
|
||||||
|
machineKey2 := key.NewMachine()
|
||||||
|
nodeExpiry2 := time.Now().Add(24 * time.Hour)
|
||||||
|
node2 := &types.Node{
|
||||||
|
MachineKey: machineKey2.Public(),
|
||||||
|
NodeKey: nodeKey2.Public(),
|
||||||
|
DiscoKey: discoKey2.Public(),
|
||||||
|
Hostname: "test-node-2",
|
||||||
|
GivenName: "test-node-2",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry2,
|
||||||
|
}
|
||||||
|
db.DB.Save(node2)
|
||||||
|
|
||||||
|
// Create sessions with different expiry times
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Session 1: Expired token, last seen within grace period
|
||||||
|
expiredTime1 := now.Add(-1 * time.Hour)
|
||||||
|
lastSeen1 := now.Add(-5 * time.Minute)
|
||||||
|
node1.LastSeen = &lastSeen1
|
||||||
|
db.DB.Save(node1)
|
||||||
|
session1 := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node1.ID),
|
||||||
|
SessionID: "expired-session-1",
|
||||||
|
RegistrationID: types.RegistrationID("reg-1"),
|
||||||
|
RefreshToken: "refresh-1",
|
||||||
|
TokenExpiry: &expiredTime1,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
db.DB.Save(session1)
|
||||||
|
|
||||||
|
// Session 2: Expired token, last seen outside grace period
|
||||||
|
expiredTime2 := now.Add(-2 * time.Hour)
|
||||||
|
lastSeen2 := now.Add(-20 * time.Minute)
|
||||||
|
node2.LastSeen = &lastSeen2
|
||||||
|
db.DB.Save(node2)
|
||||||
|
session2 := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node2.ID),
|
||||||
|
SessionID: "expired-session-2",
|
||||||
|
RegistrationID: types.RegistrationID("reg-2"),
|
||||||
|
RefreshToken: "refresh-2",
|
||||||
|
TokenExpiry: &expiredTime2,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
db.DB.Save(session2)
|
||||||
|
|
||||||
|
// Create a third node for session 3
|
||||||
|
nodeKey3 := key.NewNode()
|
||||||
|
discoKey3 := key.NewDisco()
|
||||||
|
machineKey3 := key.NewMachine()
|
||||||
|
nodeExpiry3 := time.Now().Add(24 * time.Hour)
|
||||||
|
node3 := &types.Node{
|
||||||
|
MachineKey: machineKey3.Public(),
|
||||||
|
NodeKey: nodeKey3.Public(),
|
||||||
|
DiscoKey: discoKey3.Public(),
|
||||||
|
Hostname: "test-node-3",
|
||||||
|
GivenName: "test-node-3",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry3,
|
||||||
|
}
|
||||||
|
db.DB.Save(node3)
|
||||||
|
|
||||||
|
// Session 3: Valid token
|
||||||
|
validTime := now.Add(1 * time.Hour)
|
||||||
|
session3 := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node3.ID),
|
||||||
|
SessionID: "valid-session",
|
||||||
|
RegistrationID: types.RegistrationID("reg-3"),
|
||||||
|
RefreshToken: "refresh-3",
|
||||||
|
TokenExpiry: &validTime,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
db.DB.Save(session3)
|
||||||
|
|
||||||
|
// Invalidate expired sessions with 10 minute grace period
|
||||||
|
err = db.InvalidateExpiredOIDCSessions(10 * time.Minute)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Check results
|
||||||
|
var checkSession1, checkSession2, checkSession3 types.OIDCSession
|
||||||
|
|
||||||
|
// Session 1: Should still be active (within grace period)
|
||||||
|
err = db.DB.Where("session_id = ?", "expired-session-1").First(&checkSession1).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession1.IsActive, check.Equals, true)
|
||||||
|
|
||||||
|
// Session 2: Should be inactive (outside grace period)
|
||||||
|
err = db.DB.Where("session_id = ?", "expired-session-2").First(&checkSession2).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession2.IsActive, check.Equals, false)
|
||||||
|
|
||||||
|
// Session 3: Should still be active (valid token)
|
||||||
|
err = db.DB.Where("session_id = ?", "valid-session").First(&checkSession3).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession3.IsActive, check.Equals, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestInvalidateOIDCSessionsWithNoSessions(c *check.C) {
|
||||||
|
// Test with non-existent node ID
|
||||||
|
err := db.InvalidateOIDCSessionsForNode(types.NodeID(99999))
|
||||||
|
c.Assert(err, check.IsNil) // Should not error even if no sessions exist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestInvalidateExpiredOIDCSessionsWithNoExpired(c *check.C) {
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test-no-expired-user"})
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Create a node for the session
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
nodeExpiry := time.Now().Add(24 * time.Hour)
|
||||||
|
node := &types.Node{
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: "test-valid-node",
|
||||||
|
GivenName: "test-valid-node",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry,
|
||||||
|
}
|
||||||
|
db.DB.Save(node)
|
||||||
|
|
||||||
|
// Create only valid sessions
|
||||||
|
validTime := time.Now().Add(24 * time.Hour)
|
||||||
|
session := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: "valid-only-session",
|
||||||
|
RegistrationID: types.RegistrationID("reg-valid"),
|
||||||
|
RefreshToken: "refresh-valid",
|
||||||
|
TokenExpiry: &validTime,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
db.DB.Save(session)
|
||||||
|
|
||||||
|
// Run invalidation
|
||||||
|
err = db.InvalidateExpiredOIDCSessions(10 * time.Minute)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Verify session is still active
|
||||||
|
var checkSession types.OIDCSession
|
||||||
|
err = db.DB.Where("session_id = ?", "valid-only-session").First(&checkSession).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(checkSession.IsActive, check.Equals, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestInvalidateOIDCSessionsTransaction(c *check.C) {
|
||||||
|
user, err := db.CreateUser(types.User{Name: "test-transaction-user"})
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Create multiple nodes and sessions (one session per node)
|
||||||
|
var nodeIDs []types.NodeID
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
// Create a node for each session
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
nodeExpiry := time.Now().Add(24 * time.Hour)
|
||||||
|
node := &types.Node{
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: fmt.Sprintf("test-transaction-node-%d", i),
|
||||||
|
GivenName: fmt.Sprintf("test-transaction-node-%d", i),
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
AuthKeyID: nil, // No auth key for OIDC
|
||||||
|
Expiry: &nodeExpiry,
|
||||||
|
}
|
||||||
|
db.DB.Save(node)
|
||||||
|
nodeIDs = append(nodeIDs, types.NodeID(node.ID))
|
||||||
|
|
||||||
|
// Create a session for this node
|
||||||
|
tokenExpiry := time.Now().Add(1 * time.Hour)
|
||||||
|
sessionID := fmt.Sprintf("session-%d", i)
|
||||||
|
session := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: sessionID,
|
||||||
|
RegistrationID: types.RegistrationID(fmt.Sprintf("reg-%d", i)),
|
||||||
|
RefreshToken: fmt.Sprintf("refresh-%d", i),
|
||||||
|
TokenExpiry: &tokenExpiry,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
result := db.DB.Create(session)
|
||||||
|
c.Assert(result.Error, check.IsNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all sessions are active
|
||||||
|
var count int64
|
||||||
|
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", true).Count(&count)
|
||||||
|
c.Assert(count, check.Equals, int64(3))
|
||||||
|
|
||||||
|
// Invalidate all sessions for the first node
|
||||||
|
err = db.InvalidateOIDCSessionsForNode(nodeIDs[0])
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
// Verify one session is now inactive, two still active
|
||||||
|
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", true).Count(&count)
|
||||||
|
c.Assert(count, check.Equals, int64(2))
|
||||||
|
|
||||||
|
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", false).Count(&count)
|
||||||
|
c.Assert(count, check.Equals, int64(1))
|
||||||
|
|
||||||
|
// Verify the correct session was invalidated
|
||||||
|
var inactiveSession types.OIDCSession
|
||||||
|
err = db.DB.Where("node_id = ? AND is_active = ?", nodeIDs[0], false).First(&inactiveSession).Error
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(inactiveSession.SessionID, check.Equals, "session-0")
|
||||||
|
}
|
@ -341,6 +341,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
verb = "Authenticated"
|
verb = "Authenticated"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create or update OIDC session for this node
|
||||||
|
if err := a.createOrUpdateOIDCSession(user, *registrationId, oauth2Token, nodeExpiry); err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to create OIDC session")
|
||||||
|
// Don't fail the auth flow, just log the error
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): replace with go-elem
|
// TODO(kradalby): replace with go-elem
|
||||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
content, err := renderOIDCCallbackTemplate(user, verb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -403,6 +409,182 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
return oauth2Token, err
|
return oauth2Token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createOrUpdateOIDCSession creates or updates an OIDC session for a node
|
||||||
|
func (a *AuthProviderOIDC) createOrUpdateOIDCSession(user *types.User, registrationID types.RegistrationID, token *oauth2.Token, nodeExpiry time.Time) error {
|
||||||
|
|
||||||
|
if token.RefreshToken == "" {
|
||||||
|
log.Warn().
|
||||||
|
Str("user", user.Username()).
|
||||||
|
Str("registration_id", registrationID.String()).
|
||||||
|
Msg("OIDC: No refresh token in OAuth2 token, skipping session creation (check offline_access scope)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the node in the database - we'll get the most recent OIDC node for this user
|
||||||
|
var node types.Node
|
||||||
|
err := a.db.DB.Where("user_id = ? AND register_method = ?", user.ID, util.RegisterMethodOIDC).Order("created_at DESC").First(&node).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to find node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate session ID
|
||||||
|
sessionID, err := util.GenerateRandomStringURLSafe(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate session ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create or update session
|
||||||
|
tokenExpiryUTC := token.Expiry.UTC()
|
||||||
|
session := &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: node.ID,
|
||||||
|
SessionID: sessionID,
|
||||||
|
RegistrationID: registrationID,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
TokenExpiry: &tokenExpiryUTC,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
session.LastRefreshedAt = &now
|
||||||
|
session.LastSeenAt = &now
|
||||||
|
|
||||||
|
// Try to update existing session first
|
||||||
|
var existingSession types.OIDCSession
|
||||||
|
err = a.db.DB.Where("user_id = ? AND node_id = ?", user.ID, node.ID).First(&existingSession).Error
|
||||||
|
if err == nil {
|
||||||
|
// Update existing session
|
||||||
|
existingSession.RefreshToken = token.RefreshToken
|
||||||
|
tokenExpiryUTC := token.Expiry.UTC()
|
||||||
|
existingSession.TokenExpiry = &tokenExpiryUTC
|
||||||
|
existingSession.LastRefreshedAt = &now
|
||||||
|
existingSession.LastSeenAt = &now
|
||||||
|
existingSession.IsActive = true
|
||||||
|
existingSession.RefreshCount = existingSession.RefreshCount + 1
|
||||||
|
|
||||||
|
err = a.db.DB.Save(&existingSession).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update OIDC session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Create new session
|
||||||
|
err = a.db.DB.Create(session).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create OIDC session: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshOIDCSession refreshes an expired OIDC session using the stored refresh token
|
||||||
|
// and updates the node expiry using the existing HandleNodeFromAuthPath flow
|
||||||
|
func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *types.OIDCSession) error {
|
||||||
|
|
||||||
|
if session.RefreshToken == "" {
|
||||||
|
return fmt.Errorf("no refresh token available for session %s", session.SessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenSource := a.oauth2Config.TokenSource(ctx, &oauth2.Token{
|
||||||
|
RefreshToken: session.RefreshToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
newToken, err := tokenSource.Token()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to refresh OIDC token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate new node expiry based on the access token expiry (not ID token)
|
||||||
|
// Access tokens determine when we need to refresh, so node should live as long as access token
|
||||||
|
nodeExpiry := a.determineNodeExpiry(newToken.Expiry)
|
||||||
|
|
||||||
|
// Load the node for logging
|
||||||
|
var node types.Node
|
||||||
|
err = a.db.DB.Preload("User").First(&node, session.NodeID).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the node expiry directly for token refresh
|
||||||
|
// We don't use HandleNodeFromAuthPath for refresh as the node is already registered
|
||||||
|
err = a.db.DB.Model(&node).Update("expiry", nodeExpiry).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update node expiry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the local node object for consistency
|
||||||
|
node.Expiry = &nodeExpiry
|
||||||
|
|
||||||
|
// Update the session with new token information
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
if newToken.RefreshToken != "" {
|
||||||
|
session.RefreshToken = newToken.RefreshToken
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Str("session_id", session.SessionID).
|
||||||
|
Msg("OIDC: No new refresh token received, keeping existing one")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store token expiry in UTC to avoid timezone issues
|
||||||
|
utcExpiry := newToken.Expiry.UTC()
|
||||||
|
session.TokenExpiry = &utcExpiry
|
||||||
|
session.LastRefreshedAt = &now
|
||||||
|
session.RefreshCount = session.RefreshCount + 1
|
||||||
|
|
||||||
|
// Save the updated session
|
||||||
|
err = a.db.DB.Save(session).Error
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("session_id", session.SessionID).
|
||||||
|
Msg("OIDC: Failed to save updated session to database")
|
||||||
|
return fmt.Errorf("failed to save updated session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshExpiredTokens checks for and refreshes OIDC sessions that will expire soon
|
||||||
|
func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
|
||||||
|
var sessions []types.OIDCSession
|
||||||
|
|
||||||
|
// Find active sessions with tokens expiring in the next 30 minutes
|
||||||
|
currentTime := time.Now().UTC()
|
||||||
|
threshold := currentTime.Add(30 * time.Minute)
|
||||||
|
|
||||||
|
// Only refresh tokens for sessions linked to OIDC-registered nodes
|
||||||
|
err := a.db.DB.Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
|
||||||
|
Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?",
|
||||||
|
true, threshold, "oidc").
|
||||||
|
Find(&sessions).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query sessions with expiring tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sessions) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("OIDC: Found %d sessions with tokens expiring soon, refreshing...", len(sessions))
|
||||||
|
|
||||||
|
for _, session := range sessions {
|
||||||
|
if err := a.RefreshOIDCSession(ctx, &session); err != nil {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("session_id", session.SessionID).
|
||||||
|
Uint("user_id", session.UserID).
|
||||||
|
Uint64("node_id", uint64(session.NodeID)).
|
||||||
|
Msg("OIDC: Failed to refresh session, deactivating")
|
||||||
|
// Deactivate the session if refresh fails
|
||||||
|
session.Deactivate()
|
||||||
|
a.db.DB.Save(&session)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// extractIDToken extracts the ID token from the oauth2 token.
|
// extractIDToken extracts the ID token from the oauth2 token.
|
||||||
func (a *AuthProviderOIDC) extractIDToken(
|
func (a *AuthProviderOIDC) extractIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
529
hscontrol/oidc_test.go
Normal file
529
hscontrol/oidc_test.go
Normal file
@ -0,0 +1,529 @@
|
|||||||
|
package hscontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
zcache "zgo.at/zcache/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
||||||
|
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTestNode creates a test node for testing
|
||||||
|
func createTestNode(t *testing.T, hsdb *db.HSDatabase, user *types.User, hostname string) *types.Node {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
nodeExpiry := time.Now().Add(24 * time.Hour)
|
||||||
|
|
||||||
|
node := &types.Node{
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: hostname,
|
||||||
|
GivenName: hostname,
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
Expiry: &nodeExpiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := hsdb.DB.Create(node).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTestDB creates a test database
|
||||||
|
func setupTestDB(t *testing.T) *db.HSDatabase {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
hsdb, err := db.NewHeadscaleDatabase(
|
||||||
|
types.DatabaseConfig{
|
||||||
|
Type: types.DatabaseSqlite,
|
||||||
|
Sqlite: types.SqliteConfig{
|
||||||
|
Path: tmpDir + "/test.db",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
emptyCache(),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return hsdb
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
||||||
|
hsdb := setupTestDB(t)
|
||||||
|
|
||||||
|
// Create test OIDC provider
|
||||||
|
oidcProvider := &AuthProviderOIDC{
|
||||||
|
db: hsdb,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test user
|
||||||
|
user := &types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "testuser",
|
||||||
|
}
|
||||||
|
err := hsdb.DB.Create(user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test node
|
||||||
|
nodeKey := key.NewNode()
|
||||||
|
discoKey := key.NewDisco()
|
||||||
|
machineKey := key.NewMachine()
|
||||||
|
nodeExpiry := time.Now().Add(24 * time.Hour)
|
||||||
|
node := &types.Node{
|
||||||
|
MachineKey: machineKey.Public(),
|
||||||
|
NodeKey: nodeKey.Public(),
|
||||||
|
DiscoKey: discoKey.Public(),
|
||||||
|
Hostname: "test-node",
|
||||||
|
GivenName: "test-node",
|
||||||
|
UserID: user.ID,
|
||||||
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
|
Expiry: &nodeExpiry,
|
||||||
|
}
|
||||||
|
err = hsdb.DB.Create(node).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user *types.User
|
||||||
|
registrationID types.RegistrationID
|
||||||
|
token *oauth2.Token
|
||||||
|
nodeExpiry time.Time
|
||||||
|
expectError bool
|
||||||
|
expectSession bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create new session with refresh token",
|
||||||
|
user: user,
|
||||||
|
registrationID: types.RegistrationID("reg-123"),
|
||||||
|
token: &oauth2.Token{
|
||||||
|
AccessToken: "access-token",
|
||||||
|
RefreshToken: "refresh-token",
|
||||||
|
Expiry: time.Now().Add(1 * time.Hour),
|
||||||
|
},
|
||||||
|
nodeExpiry: time.Now().Add(24 * time.Hour),
|
||||||
|
expectError: false,
|
||||||
|
expectSession: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip session creation without refresh token",
|
||||||
|
user: user,
|
||||||
|
registrationID: types.RegistrationID("reg-456"),
|
||||||
|
token: &oauth2.Token{
|
||||||
|
AccessToken: "access-token-only",
|
||||||
|
Expiry: time.Now().Add(1 * time.Hour),
|
||||||
|
},
|
||||||
|
nodeExpiry: time.Now().Add(24 * time.Hour),
|
||||||
|
expectError: false,
|
||||||
|
expectSession: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update existing session",
|
||||||
|
user: user,
|
||||||
|
registrationID: types.RegistrationID("reg-789"),
|
||||||
|
token: &oauth2.Token{
|
||||||
|
AccessToken: "new-access-token",
|
||||||
|
RefreshToken: "new-refresh-token",
|
||||||
|
Expiry: time.Now().Add(2 * time.Hour),
|
||||||
|
},
|
||||||
|
nodeExpiry: time.Now().Add(24 * time.Hour),
|
||||||
|
expectError: false,
|
||||||
|
expectSession: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := oidcProvider.createOrUpdateOIDCSession(tt.user, tt.registrationID, tt.token, tt.nodeExpiry)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectSession && tt.token.RefreshToken != "" {
|
||||||
|
// Verify session was created/updated
|
||||||
|
var session types.OIDCSession
|
||||||
|
err = hsdb.DB.Where("user_id = ? AND node_id = ?", tt.user.ID, node.ID).First(&session).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.token.RefreshToken, session.RefreshToken)
|
||||||
|
assert.True(t, session.IsActive)
|
||||||
|
assert.NotNil(t, session.TokenExpiry)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenValidation(t *testing.T) {
|
||||||
|
// Test the basic validation logic that's done before OAuth2 calls
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session *types.OIDCSession
|
||||||
|
shouldFail bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session_no_refresh_token",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "no-token-session",
|
||||||
|
RefreshToken: "",
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
description: "Session without refresh token should fail validation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_with_refresh_token",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "valid-session",
|
||||||
|
RefreshToken: "valid-refresh-token",
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldFail: false,
|
||||||
|
description: "Session with refresh token should pass initial validation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_inactive",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "inactive-session",
|
||||||
|
RefreshToken: "refresh-token",
|
||||||
|
IsActive: false,
|
||||||
|
},
|
||||||
|
shouldFail: false,
|
||||||
|
description: "Session validation only checks refresh token presence",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the basic validation logic that checks for refresh token
|
||||||
|
hasValidRefreshToken := tt.session.RefreshToken != ""
|
||||||
|
|
||||||
|
if tt.shouldFail {
|
||||||
|
assert.False(t, hasValidRefreshToken, tt.description)
|
||||||
|
} else {
|
||||||
|
assert.True(t, hasValidRefreshToken, tt.description)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshExpiredTokensLogic(t *testing.T) {
|
||||||
|
// Test the logic that determines which sessions need refresh without DB
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session *types.OIDCSession
|
||||||
|
shouldRefresh bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session_needs_refresh_expiring_soon",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "session1",
|
||||||
|
RefreshToken: "refresh1",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(3 * time.Minute)}[0],
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldRefresh: true,
|
||||||
|
description: "Active session with token expiring in 3 minutes should need refresh (5 min threshold)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_needs_refresh_already_expired",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "session2",
|
||||||
|
RefreshToken: "refresh2",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldRefresh: true,
|
||||||
|
description: "Active session with expired token should need refresh",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_no_refresh_valid_token",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "session3",
|
||||||
|
RefreshToken: "refresh3",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0],
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldRefresh: false,
|
||||||
|
description: "Active session with valid token should not need refresh",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_no_refresh_inactive",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "session4",
|
||||||
|
RefreshToken: "refresh4",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
|
||||||
|
IsActive: false,
|
||||||
|
},
|
||||||
|
shouldRefresh: false,
|
||||||
|
description: "Inactive session should not be refreshed even if expired",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session_no_refresh_no_token",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "session5",
|
||||||
|
RefreshToken: "",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
shouldRefresh: false,
|
||||||
|
description: "Session without refresh token should not be refreshed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the logic that determines if a session needs refresh
|
||||||
|
// This mirrors the logic in RefreshExpiredTokens
|
||||||
|
threshold := now.Add(5 * time.Minute)
|
||||||
|
needsRefresh := tt.session.IsActive &&
|
||||||
|
tt.session.RefreshToken != "" &&
|
||||||
|
tt.session.TokenExpiry != nil &&
|
||||||
|
tt.session.TokenExpiry.Before(threshold)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.shouldRefresh, needsRefresh, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetermineNodeExpiry(t *testing.T) {
|
||||||
|
oidcProvider := &AuthProviderOIDC{
|
||||||
|
cfg: &types.OIDCConfig{
|
||||||
|
UseExpiryFromToken: true,
|
||||||
|
Expiry: 180 * 24 * time.Hour, // Default expiry
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
idTokenExpiry := now.Add(2 * time.Hour)
|
||||||
|
|
||||||
|
// Test with UseExpiryFromToken = true
|
||||||
|
expiry := oidcProvider.determineNodeExpiry(idTokenExpiry)
|
||||||
|
assert.Equal(t, idTokenExpiry, expiry)
|
||||||
|
|
||||||
|
// Test with UseExpiryFromToken = false
|
||||||
|
oidcProvider.cfg.UseExpiryFromToken = false
|
||||||
|
expiry = oidcProvider.determineNodeExpiry(idTokenExpiry)
|
||||||
|
// Should return current time + cfg.Expiry
|
||||||
|
expectedExpiry := now.Add(oidcProvider.cfg.Expiry)
|
||||||
|
assert.WithinDuration(t, expectedExpiry, expiry, 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshExpiredTokens(t *testing.T) {
|
||||||
|
hsdb := setupTestDB(t)
|
||||||
|
|
||||||
|
// Create test OIDC provider
|
||||||
|
oidcProvider := &AuthProviderOIDC{
|
||||||
|
db: hsdb,
|
||||||
|
cfg: &types.OIDCConfig{
|
||||||
|
Issuer: "https://test.example.com",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test user
|
||||||
|
user := &types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "testuser",
|
||||||
|
}
|
||||||
|
err := hsdb.DB.Create(user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupSession func() *types.OIDCSession
|
||||||
|
expectError bool
|
||||||
|
expectCalled int // How many refresh calls should be attempted
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no sessions need refresh",
|
||||||
|
setupSession: func() *types.OIDCSession {
|
||||||
|
node := createTestNode(t, hsdb, user, "test-node-1")
|
||||||
|
return &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: "valid-session",
|
||||||
|
RegistrationID: types.RegistrationID("reg-123"),
|
||||||
|
RefreshToken: "refresh-token",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0],
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
expectCalled: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session needs refresh but no refresh token",
|
||||||
|
setupSession: func() *types.OIDCSession {
|
||||||
|
node := createTestNode(t, hsdb, user, "test-node-2")
|
||||||
|
return &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: "no-token-session",
|
||||||
|
RegistrationID: types.RegistrationID("reg-456"),
|
||||||
|
RefreshToken: "", // No refresh token
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(10 * time.Minute)}[0], // Expiring soon
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
expectCalled: 0, // Won't be called due to empty refresh token
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid token should be ignored",
|
||||||
|
setupSession: func() *types.OIDCSession {
|
||||||
|
node := createTestNode(t, hsdb, user, "test-node-3")
|
||||||
|
return &types.OIDCSession{
|
||||||
|
UserID: user.ID,
|
||||||
|
NodeID: types.NodeID(node.ID),
|
||||||
|
SessionID: "valid-token-session",
|
||||||
|
RegistrationID: types.RegistrationID("reg-789"),
|
||||||
|
RefreshToken: "refresh-token",
|
||||||
|
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0], // Not expiring soon
|
||||||
|
IsActive: true, // Active but token not expiring
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
expectCalled: 0, // Won't be called due to valid token
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no sessions at all",
|
||||||
|
setupSession: func() *types.OIDCSession {
|
||||||
|
return nil // No session
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
expectCalled: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Clean up sessions from previous tests
|
||||||
|
err := hsdb.DB.Where("1 = 1").Delete(&types.OIDCSession{}).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Setup test session if needed
|
||||||
|
if tt.setupSession != nil {
|
||||||
|
session := tt.setupSession()
|
||||||
|
if session != nil {
|
||||||
|
err := hsdb.DB.Create(session).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call RefreshExpiredTokens
|
||||||
|
// Note: This will fail when it tries to make OAuth2 calls, but we can test
|
||||||
|
// the initial logic (finding sessions, filtering them, etc.)
|
||||||
|
ctx := context.Background()
|
||||||
|
err = oidcProvider.RefreshExpiredTokens(ctx)
|
||||||
|
|
||||||
|
// We expect this to fail if there are sessions that need refresh
|
||||||
|
// because we don't have a real OAuth2 config, but it should succeed
|
||||||
|
// if no sessions need refresh
|
||||||
|
if tt.expectCalled == 0 {
|
||||||
|
// Should succeed when no sessions need refreshing
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
// Will fail due to OAuth2 config being nil, but that's expected
|
||||||
|
// The important thing is that it found the sessions that need refresh
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshOIDCSessionValidation(t *testing.T) {
|
||||||
|
hsdb := setupTestDB(t)
|
||||||
|
|
||||||
|
// Create test OIDC provider
|
||||||
|
oidcProvider := &AuthProviderOIDC{
|
||||||
|
db: hsdb,
|
||||||
|
cfg: &types.OIDCConfig{
|
||||||
|
Issuer: "https://test.example.com",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session *types.OIDCSession
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session without refresh token should fail",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "no-token-session",
|
||||||
|
RefreshToken: "", // No refresh token
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "no refresh token available",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session with refresh token should fail due to no OAuth2 config",
|
||||||
|
session: &types.OIDCSession{
|
||||||
|
SessionID: "valid-session",
|
||||||
|
RefreshToken: "valid-refresh-token",
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
expectError: true, // Will fail due to OAuth2 config being nil
|
||||||
|
errorMsg: "failed to refresh OIDC token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test validation - catch panics from OAuth2 config being nil
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// OAuth2 config being nil causes a panic, which means we got past
|
||||||
|
// the refresh token validation - that's expected for the second test
|
||||||
|
if tt.session.RefreshToken != "" {
|
||||||
|
// This means the function got to the OAuth2 call part, which is expected
|
||||||
|
// The panic indicates we successfully passed the refresh token validation
|
||||||
|
assert.Contains(t, fmt.Sprintf("%v", r), "nil pointer")
|
||||||
|
} else {
|
||||||
|
// Shouldn't panic for missing refresh token
|
||||||
|
t.Errorf("Unexpected panic for empty refresh token: %v", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := oidcProvider.RefreshOIDCSession(ctx, tt.session)
|
||||||
|
|
||||||
|
// If we get here, it means no panic occurred (good for empty refresh token test)
|
||||||
|
if err != nil {
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -173,18 +173,19 @@ type PKCEConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
type OIDCConfig struct {
|
||||||
OnlyStartIfOIDCIsAvailable bool
|
OnlyStartIfOIDCIsAvailable bool
|
||||||
Issuer string
|
Issuer string
|
||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
Scope []string
|
Scope []string
|
||||||
ExtraParams map[string]string
|
ExtraParams map[string]string
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
UseExpiryFromToken bool
|
UseExpiryFromToken bool
|
||||||
PKCE PKCEConfig
|
SessionInvalidationGracePeriod time.Duration
|
||||||
|
PKCE PKCEConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type DERPConfig struct {
|
type DERPConfig struct {
|
||||||
@ -320,6 +321,7 @@ func LoadConfig(path string, isFile bool) error {
|
|||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
viper.SetDefault("oidc.session_invalidation_grace_period", "30m")
|
||||||
viper.SetDefault("oidc.pkce.enabled", false)
|
viper.SetDefault("oidc.pkce.enabled", false)
|
||||||
viper.SetDefault("oidc.pkce.method", "S256")
|
viper.SetDefault("oidc.pkce.method", "S256")
|
||||||
|
|
||||||
@ -963,7 +965,8 @@ func LoadServerConfig() (*Config, error) {
|
|||||||
return time.Duration(expiry)
|
return time.Duration(expiry)
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
||||||
|
SessionInvalidationGracePeriod: viper.GetDuration("oidc.session_invalidation_grace_period"),
|
||||||
PKCE: PKCEConfig{
|
PKCE: PKCEConfig{
|
||||||
Enabled: viper.GetBool("oidc.pkce.enabled"),
|
Enabled: viper.GetBool("oidc.pkce.enabled"),
|
||||||
Method: viper.GetString("oidc.pkce.method"),
|
Method: viper.GetString("oidc.pkce.method"),
|
||||||
|
59
hscontrol/types/oidc_session.go
Normal file
59
hscontrol/types/oidc_session.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OIDCSession represents an OIDC authentication session linked to a specific node
|
||||||
|
type OIDCSession struct {
|
||||||
|
gorm.Model
|
||||||
|
|
||||||
|
// Core relationships
|
||||||
|
UserID uint `gorm:"not null;index"`
|
||||||
|
User User `gorm:"constraint:OnDelete:CASCADE;"`
|
||||||
|
NodeID NodeID `gorm:"not null;uniqueIndex"`
|
||||||
|
Node Node `gorm:"constraint:OnDelete:CASCADE;"`
|
||||||
|
|
||||||
|
// Session identification
|
||||||
|
SessionID string `gorm:"uniqueIndex;not null"`
|
||||||
|
RegistrationID RegistrationID `gorm:"not null"` // For reusing HandleNodeFromAuthPath
|
||||||
|
|
||||||
|
// Token data
|
||||||
|
RefreshToken string `gorm:"type:text"` //TODO: Encrypt?
|
||||||
|
|
||||||
|
// Token lifecycle
|
||||||
|
TokenExpiry *time.Time `gorm:"index"`
|
||||||
|
LastRefreshedAt *time.Time
|
||||||
|
RefreshCount int `gorm:"default:0"`
|
||||||
|
|
||||||
|
// Session state
|
||||||
|
IsActive bool `gorm:"default:true;index"`
|
||||||
|
LastSeenAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OIDCSession) TableName() string {
|
||||||
|
return "oidc_sessions"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired checks if the session's token has expired
|
||||||
|
func (s *OIDCSession) IsExpired() bool {
|
||||||
|
return s.TokenExpiry != nil && s.TokenExpiry.Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpiringSoon checks if the session's token will expire within the given duration
|
||||||
|
func (s *OIDCSession) IsExpiringSoon(duration time.Duration) bool {
|
||||||
|
return s.TokenExpiry != nil && s.TokenExpiry.Before(time.Now().Add(duration))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deactivate marks the session as inactive
|
||||||
|
func (s *OIDCSession) Deactivate() {
|
||||||
|
s.IsActive = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastSeen updates the last seen timestamp
|
||||||
|
func (s *OIDCSession) UpdateLastSeen() {
|
||||||
|
now := time.Now()
|
||||||
|
s.LastSeenAt = &now
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user