1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-14 13:51:01 +02:00
This commit is contained in:
mazlumtoprak 2025-07-24 20:19:17 +02:00 committed by GitHub
commit bf6cb355a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1442 additions and 16 deletions

View File

@ -1,6 +1,12 @@
# CHANGELOG
## Next
## 0.26.1 (2025-06-06)
### Changes
- Ensure nodes are matching both node key and machine key
when connecting.
[#2642](https://github.com/juanfont/headscale/pull/2642)
### Database integrity improvements

View File

@ -355,13 +355,24 @@ unix_socket_permission: "0770"
# # Note: enabling this will cause `oidc.expiry` to be ignored.
# use_expiry_from_token: false
#
# # The OIDC scopes to use, defaults to "openid", "profile" and "email".
# # Custom scopes can be configured as needed, be sure to always include the
# # required "openid" scope.
# scope: ["openid", "profile", "email"]
# # Token refresh and session management configuration
# token_refresh:
# # How often to check for tokens needing refresh (default: 15m)
# check_interval: 15m
#
# # Refresh tokens this far before expiry (default: 30m)
# expiry_threshold: 30m
#
# # Grace period for invalidating sessions of nodes that have been offline
# # Sessions for nodes offline longer than this duration will be invalidated (new SSO login required).
# # (default: 30m)
# session_invalidation_grace_period: 30m
#
# # Provide custom key/value pairs which get sent to the identity provider's
# # authorization endpoint.
# # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
# # To enable automatic token refresh, add "offline_access" to the scope list.
#
# scope: ["openid", "profile", "email", "offline_access"]
# extra_params:
# domain_hint: example.com
#

View File

@ -19,7 +19,13 @@ OpenID requires configuration in Headscale and your identity provider:
Additionally, there might be some useful hints in the [Identity provider specific
configuration](#identity-provider-specific-configuration) section below.
### Basic configuration
# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
# To enable automatic token refresh, add "offline_access" to the scope list.
scope: ["openid", "profile", "email", "custom"]
# Optional: Passed on to the browser login request used to tweak behaviour for the OIDC provider
extra_params:
domain_hint: example.com
A basic configuration connects Headscale to an identity provider and typically requires:

View File

@ -302,6 +302,41 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
}
func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthProviderOIDC) {
checkInterval := oidcProvider.cfg.TokenRefresh.CheckInterval
refreshTicker := time.NewTicker(checkInterval)
gracePeriodTicker := time.NewTicker(checkInterval)
defer refreshTicker.Stop()
defer gracePeriodTicker.Stop()
log.Info().Msgf("OIDC: Background token refresh job started (checking every %v for tokens expiring within %v)",
checkInterval, oidcProvider.cfg.TokenRefresh.ExpiryThreshold)
for {
select {
case <-ctx.Done():
log.Info().Caller().Msg("OIDC token refresh job is shutting down.")
return
// Refresh expired tokens every 15 minutes. Will be refreshed if their expiry is within the next 30 minutes.
case <-refreshTicker.C:
refreshCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
if err := oidcProvider.RefreshExpiredTokens(refreshCtx); err != nil {
log.Error().Err(err).Msg("OIDC: Failed to refresh expired tokens")
}
cancel()
// Invalidate sessions for nodes that have been offline for longer than the configured grace period
case <-gracePeriodTicker.C:
log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period")
gracePeriod := oidcProvider.cfg.TokenRefresh.SessionInvalidationGracePeriod
if err := h.state.InvalidateExpiredOIDCSessions(gracePeriod); err != nil {
log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes")
}
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
@ -617,6 +652,11 @@ func (h *Headscale) Serve() error {
defer scheduleCancel()
go h.scheduledTasks(scheduleCtx)
// Start OIDC token refresh background job if OIDC is enabled
if oidcProvider, ok := h.authProvider.(*AuthProviderOIDC); ok {
go h.oidcTokenRefreshJob(scheduleCtx, oidcProvider)
}
if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
} else {

View File

@ -927,7 +927,6 @@ AND auth_key_id NOT IN (
}
log.Info().Msg("Schema recreation completed successfully")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
@ -936,6 +935,79 @@ AND auth_key_id NOT IN (
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
// - Never write migrations that requires foreign keys to be disabled.
{
ID: "202507140001",
Migrate: func(tx *gorm.DB) error {
// Create OIDC sessions table for managing OIDC refresh tokens
// This replaces the old OIDC token columns in the users table
if !tx.Migrator().HasTable(&types.OIDCSession{}) {
// Create the table with database-specific SQL
var sql string
if cfg.Type == types.DatabasePostgres {
sql = `
CREATE TABLE oidc_sessions (
id SERIAL PRIMARY KEY,
node_id INTEGER NOT NULL,
session_id TEXT NOT NULL,
registration_id INTEGER NOT NULL,
refresh_token TEXT,
token_expiry TIMESTAMP,
last_refreshed_at TIMESTAMP,
is_active BOOLEAN DEFAULT true,
last_seen_at TIMESTAMP,
created_at TIMESTAMP,
updated_at TIMESTAMP,
deleted_at TIMESTAMP,
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
)
`
} else {
sql = `
CREATE TABLE oidc_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
node_id INTEGER NOT NULL,
session_id TEXT NOT NULL,
registration_id INTEGER NOT NULL,
refresh_token TEXT,
token_expiry DATETIME,
last_refreshed_at DATETIME,
is_active NUMERIC DEFAULT true,
last_seen_at DATETIME,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME,
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
)
`
}
if err := tx.Exec(sql).Error; err != nil {
return fmt.Errorf("creating OIDC sessions table: %w", err)
}
// Create indexes matching the struct tags
if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id)").Error; err != nil {
return fmt.Errorf("creating node_id unique index: %w", err)
}
if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id)").Error; err != nil {
return fmt.Errorf("creating session_id unique index: %w", err)
}
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry)").Error; err != nil {
return fmt.Errorf("creating token_expiry index: %w", err)
}
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active)").Error; err != nil {
return fmt.Errorf("creating is_active index: %w", err)
}
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at)").Error; err != nil {
return fmt.Errorf("creating deleted_at index: %w", err)
}
log.Debug().Msg("Created OIDC sessions table")
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View File

@ -317,6 +317,15 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node) error {
func DeleteNode(tx *gorm.DB,
node *types.Node,
) error {
// Invalidate OIDC sessions for this node before deletion
if err := InvalidateOIDCSessionsForNode(tx, node.ID); err != nil {
log.Error().Err(err).
Uint64("node_id", uint64(node.ID)).
Str("node", node.Hostname).
Msg("Failed to invalidate OIDC sessions for deleted node")
// Continue with deletion even if session invalidation fails
}
// Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
return err
@ -332,6 +341,14 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
nodeID types.NodeID,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
// Invalidate OIDC sessions for this node before deletion
if err := InvalidateOIDCSessionsForNode(tx, nodeID); err != nil {
log.Error().Err(err).
Uint64("node_id", uint64(nodeID)).
Msg("Failed to invalidate OIDC sessions for deleted ephemeral node")
// Continue with deletion even if session invalidation fails
}
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
return err
}

View File

@ -0,0 +1,102 @@
package db
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
// InvalidateOIDCSessionsForNode invalidates all active OIDC sessions for a specific node
func (hsdb *HSDatabase) InvalidateOIDCSessionsForNode(nodeID types.NodeID) error {
return hsdb.Write(func(tx *gorm.DB) error {
return InvalidateOIDCSessionsForNode(tx, nodeID)
})
}
// InvalidateOIDCSessionsForNode invalidates all active OIDC sessions for a specific node
func InvalidateOIDCSessionsForNode(tx *gorm.DB, nodeID types.NodeID) error {
log.Debug().
Uint64("node_id", uint64(nodeID)).
Msg("OIDC: Invalidating sessions for node")
result := tx.Model(&types.OIDCSession{}).
Where("node_id = ? AND is_active = ?", nodeID, true).
Updates(map[string]interface{}{
"is_active": false,
"last_seen_at": time.Now(),
"refresh_token": nil,
})
if result.Error != nil {
log.Error().Err(result.Error).
Uint64("node_id", uint64(nodeID)).
Msg("OIDC: Failed to invalidate sessions for node")
return fmt.Errorf("failed to invalidate OIDC sessions for node %d: %w", nodeID, result.Error)
}
if result.RowsAffected > 0 {
log.Info().
Uint64("node_id", uint64(nodeID)).
Int64("sessions_invalidated", result.RowsAffected).
Msg("OIDC: Invalidated sessions for disconnected node")
} else {
log.Debug().
Uint64("node_id", uint64(nodeID)).
Msg("OIDC: No active sessions found for node")
}
return nil
}
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline too long
func (hsdb *HSDatabase) InvalidateExpiredOIDCSessions(offlineGracePeriod time.Duration) error {
return hsdb.Write(func(tx *gorm.DB) error {
return InvalidateExpiredOIDCSessions(tx, offlineGracePeriod)
})
}
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline too long
func InvalidateExpiredOIDCSessions(tx *gorm.DB, offlineGracePeriod time.Duration) error {
// Find active sessions where the node has been offline for longer than the grace period
cutoff := time.Now().Add(-offlineGracePeriod)
var sessions []types.OIDCSession
err := tx.Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
Where("oidc_sessions.is_active = ? AND nodes.last_seen IS NOT NULL AND nodes.last_seen < ?", true, cutoff).
Find(&sessions).Error
if err != nil {
return fmt.Errorf("failed to find expired OIDC sessions: %w", err)
}
if len(sessions) == 0 {
return nil
}
// Invalidate these sessions
sessionIDs := make([]string, len(sessions))
for i, session := range sessions {
sessionIDs[i] = session.SessionID
}
result := tx.Model(&types.OIDCSession{}).
Where("session_id IN ?", sessionIDs).
Updates(map[string]interface{}{
"is_active": false,
"last_seen_at": time.Now(),
"refresh_token": nil,
})
if result.Error != nil {
return fmt.Errorf("failed to invalidate expired OIDC sessions: %w", result.Error)
}
log.Info().
Int("sessions_invalidated", len(sessions)).
Dur("grace_period", offlineGracePeriod).
Msg("OIDC: Invalidated sessions for nodes offline beyond grace period")
return nil
}

View File

@ -0,0 +1,305 @@
package db
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"tailscale.com/types/key"
)
func (*Suite) TestInvalidateOIDCSessionsForNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test-oidc-user"})
c.Assert(err, check.IsNil)
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
nodeExpiry := time.Now().Add(24 * time.Hour)
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: "test-node",
GivenName: "test-node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry,
}
db.DB.Save(node)
// Create an active OIDC session
sessionID := "test-session-id-1"
registrationID := types.RegistrationID("test-reg-id-1")
tokenExpiry := time.Now().Add(1 * time.Hour)
session := &types.OIDCSession{
NodeID: types.NodeID(node.ID),
SessionID: sessionID,
RegistrationID: registrationID,
RefreshToken: "test-refresh-token",
TokenExpiry: &tokenExpiry,
IsActive: true,
}
db.DB.Save(session)
// Verify session is active
var checkSession types.OIDCSession
err = db.DB.Where("node_id = ?", node.ID).First(&checkSession).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession.IsActive, check.Equals, true)
// Invalidate the session
err = db.InvalidateOIDCSessionsForNode(types.NodeID(node.ID))
c.Assert(err, check.IsNil)
// Verify session is now inactive
err = db.DB.Where("node_id = ?", node.ID).First(&checkSession).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession.IsActive, check.Equals, false)
}
func (*Suite) TestInvalidateExpiredOIDCSessions(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test-oidc-expire-user"})
c.Assert(err, check.IsNil)
// Create nodes for testing
nodeKey1 := key.NewNode()
discoKey1 := key.NewDisco()
machineKey1 := key.NewMachine()
nodeExpiry1 := time.Now().Add(24 * time.Hour)
node1 := &types.Node{
MachineKey: machineKey1.Public(),
NodeKey: nodeKey1.Public(),
DiscoKey: discoKey1.Public(),
Hostname: "test-node-1",
GivenName: "test-node-1",
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry1,
}
db.DB.Save(node1)
nodeKey2 := key.NewNode()
discoKey2 := key.NewDisco()
machineKey2 := key.NewMachine()
nodeExpiry2 := time.Now().Add(24 * time.Hour)
node2 := &types.Node{
MachineKey: machineKey2.Public(),
NodeKey: nodeKey2.Public(),
DiscoKey: discoKey2.Public(),
Hostname: "test-node-2",
GivenName: "test-node-2",
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry2,
}
db.DB.Save(node2)
// Create sessions with different expiry times
now := time.Now()
// Session 1: Expired token, last seen within grace period
expiredTime1 := now.Add(-1 * time.Hour)
lastSeen1 := now.Add(-5 * time.Minute)
node1.LastSeen = &lastSeen1
db.DB.Save(node1)
session1 := &types.OIDCSession{
NodeID: types.NodeID(node1.ID),
SessionID: "expired-session-1",
RegistrationID: types.RegistrationID("reg-1"),
RefreshToken: "refresh-1",
TokenExpiry: &expiredTime1,
IsActive: true,
}
db.DB.Save(session1)
// Session 2: Expired token, last seen outside grace period
expiredTime2 := now.Add(-2 * time.Hour)
lastSeen2 := now.Add(-20 * time.Minute)
node2.LastSeen = &lastSeen2
db.DB.Save(node2)
session2 := &types.OIDCSession{
NodeID: types.NodeID(node2.ID),
SessionID: "expired-session-2",
RegistrationID: types.RegistrationID("reg-2"),
RefreshToken: "refresh-2",
TokenExpiry: &expiredTime2,
IsActive: true,
}
db.DB.Save(session2)
// Create a third node for session 3
nodeKey3 := key.NewNode()
discoKey3 := key.NewDisco()
machineKey3 := key.NewMachine()
nodeExpiry3 := time.Now().Add(24 * time.Hour)
node3 := &types.Node{
MachineKey: machineKey3.Public(),
NodeKey: nodeKey3.Public(),
DiscoKey: discoKey3.Public(),
Hostname: "test-node-3",
GivenName: "test-node-3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry3,
}
db.DB.Save(node3)
// Session 3: Valid token
validTime := now.Add(1 * time.Hour)
session3 := &types.OIDCSession{
NodeID: types.NodeID(node3.ID),
SessionID: "valid-session",
RegistrationID: types.RegistrationID("reg-3"),
RefreshToken: "refresh-3",
TokenExpiry: &validTime,
IsActive: true,
}
db.DB.Save(session3)
// Invalidate expired sessions with 10 minute grace period
err = db.InvalidateExpiredOIDCSessions(10 * time.Minute)
c.Assert(err, check.IsNil)
// Check results
var checkSession1, checkSession2, checkSession3 types.OIDCSession
// Session 1: Should still be active (within grace period)
err = db.DB.Where("session_id = ?", "expired-session-1").First(&checkSession1).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession1.IsActive, check.Equals, true)
// Session 2: Should be inactive (outside grace period)
err = db.DB.Where("session_id = ?", "expired-session-2").First(&checkSession2).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession2.IsActive, check.Equals, false)
// Session 3: Should still be active (valid token)
err = db.DB.Where("session_id = ?", "valid-session").First(&checkSession3).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession3.IsActive, check.Equals, true)
}
func (*Suite) TestInvalidateOIDCSessionsWithNoSessions(c *check.C) {
// Test with non-existent node ID
err := db.InvalidateOIDCSessionsForNode(types.NodeID(99999))
c.Assert(err, check.IsNil) // Should not error even if no sessions exist
}
func (*Suite) TestInvalidateExpiredOIDCSessionsWithNoExpired(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test-no-expired-user"})
c.Assert(err, check.IsNil)
// Create a node for the session
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
nodeExpiry := time.Now().Add(24 * time.Hour)
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: "test-valid-node",
GivenName: "test-valid-node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry,
}
db.DB.Save(node)
// Create only valid sessions
validTime := time.Now().Add(24 * time.Hour)
session := &types.OIDCSession{
NodeID: types.NodeID(node.ID),
SessionID: "valid-only-session",
RegistrationID: types.RegistrationID("reg-valid"),
RefreshToken: "refresh-valid",
TokenExpiry: &validTime,
IsActive: true,
}
db.DB.Save(session)
// Run invalidation
err = db.InvalidateExpiredOIDCSessions(10 * time.Minute)
c.Assert(err, check.IsNil)
// Verify session is still active
var checkSession types.OIDCSession
err = db.DB.Where("session_id = ?", "valid-only-session").First(&checkSession).Error
c.Assert(err, check.IsNil)
c.Assert(checkSession.IsActive, check.Equals, true)
}
func (*Suite) TestInvalidateOIDCSessionsTransaction(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test-transaction-user"})
c.Assert(err, check.IsNil)
// Create multiple nodes and sessions (one session per node)
var nodeIDs []types.NodeID
for i := 0; i < 3; i++ {
// Create a node for each session
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
nodeExpiry := time.Now().Add(24 * time.Hour)
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: fmt.Sprintf("test-transaction-node-%d", i),
GivenName: fmt.Sprintf("test-transaction-node-%d", i),
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
AuthKeyID: nil, // No auth key for OIDC
Expiry: &nodeExpiry,
}
db.DB.Save(node)
nodeIDs = append(nodeIDs, types.NodeID(node.ID))
// Create a session for this node
tokenExpiry := time.Now().Add(1 * time.Hour)
sessionID := fmt.Sprintf("session-%d", i)
session := &types.OIDCSession{
NodeID: types.NodeID(node.ID),
SessionID: sessionID,
RegistrationID: types.RegistrationID(fmt.Sprintf("reg-%d", i)),
RefreshToken: fmt.Sprintf("refresh-%d", i),
TokenExpiry: &tokenExpiry,
IsActive: true,
}
result := db.DB.Create(session)
c.Assert(result.Error, check.IsNil)
}
// Verify all sessions are active
var count int64
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", true).Count(&count)
c.Assert(count, check.Equals, int64(3))
// Invalidate all sessions for the first node
err = db.InvalidateOIDCSessionsForNode(nodeIDs[0])
c.Assert(err, check.IsNil)
// Verify one session is now inactive, two still active
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", true).Count(&count)
c.Assert(count, check.Equals, int64(2))
db.DB.Model(&types.OIDCSession{}).Where("is_active = ?", false).Count(&count)
c.Assert(count, check.Equals, int64(1))
// Verify the correct session was invalidated
var inactiveSession types.OIDCSession
err = db.DB.Where("node_id = ? AND is_active = ?", nodeIDs[0], false).First(&inactiveSession).Error
c.Assert(err, check.IsNil)
c.Assert(inactiveSession.SessionID, check.Equals, "session-0")
}

View File

@ -108,3 +108,24 @@ CREATE TABLE policies(
deleted_at datetime
);
CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
CREATE TABLE oidc_sessions(
id INTEGER PRIMARY KEY AUTOINCREMENT,
node_id INTEGER NOT NULL,
session_id TEXT NOT NULL,
registration_id INTEGER NOT NULL,
refresh_token TEXT,
token_expiry DATETIME,
last_refreshed_at DATETIME,
is_active NUMERIC DEFAULT true,
last_seen_at DATETIME,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME,
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id);
CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id);
CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry);
CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active);
CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at);

View File

@ -331,7 +331,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
nodeID, newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
httpError(writer, err)
return
@ -341,6 +341,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb = "Authenticated"
}
// Create or update OIDC session for this node
if err := a.createOrUpdateOIDCSession(*registrationId, oauth2Token, *nodeID); err != nil {
log.Error().Err(err).Msg("Failed to create OIDC session")
// Don't fail the auth flow, just log the error
}
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
@ -360,8 +366,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
func extractCodeAndStateParamFromRequest(
@ -403,6 +407,165 @@ func (a *AuthProviderOIDC) getOauth2Token(
return oauth2Token, err
}
// createOrUpdateOIDCSession creates or updates an OIDC session for a node
func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.RegistrationID, token *oauth2.Token, nodeID types.NodeID) error {
if token.RefreshToken == "" {
log.Warn().
Str("node_id", nodeID.String()).
Str("registration_id", registrationID.String()).
Msg("OIDC: No refresh token in OAuth2 token, skipping session creation (check offline_access scope)")
return nil
}
// Generate session ID
sessionID, err := util.GenerateRandomStringURLSafe(32)
if err != nil {
return fmt.Errorf("failed to generate session ID: %w", err)
}
// Create or update session
tokenExpiryUTC := token.Expiry.UTC()
session := &types.OIDCSession{
NodeID: nodeID,
SessionID: sessionID,
RegistrationID: registrationID,
RefreshToken: token.RefreshToken,
TokenExpiry: &tokenExpiryUTC,
IsActive: true,
}
now := time.Now().UTC()
session.LastRefreshedAt = &now
session.LastSeenAt = &now
// Try to update existing session first
existingSession, err := a.state.GetOIDCSessionByNodeID(nodeID)
if err == nil {
// Update existing session
existingSession.RefreshToken = token.RefreshToken
tokenExpiryUTC := token.Expiry.UTC()
existingSession.TokenExpiry = &tokenExpiryUTC
existingSession.LastRefreshedAt = &now
existingSession.LastSeenAt = &now
existingSession.IsActive = true
err = a.state.SaveOIDCSession(existingSession)
if err != nil {
return fmt.Errorf("failed to update OIDC session: %w", err)
}
} else {
// Create new session
err = a.state.CreateOIDCSession(session)
if err != nil {
return fmt.Errorf("failed to create OIDC session: %w", err)
}
}
return nil
}
// RefreshOIDCSession refreshes an expired OIDC session using the stored refresh token
// and updates the node expiry using the existing HandleNodeFromAuthPath flow
func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *types.OIDCSession) error {
if session.RefreshToken == "" {
return fmt.Errorf("no refresh token available for session %s", session.SessionID)
}
tokenSource := a.oauth2Config.TokenSource(ctx, &oauth2.Token{
RefreshToken: session.RefreshToken,
})
newToken, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to refresh OIDC token: %w", err)
}
// Calculate new node expiry based on the access token expiry (not ID token)
// Access tokens determine when we need to refresh, so node should live as long as access token
nodeExpiry := a.determineNodeExpiry(newToken.Expiry)
// Load the node for logging
node, err := a.state.GetNodeWithUser(session.NodeID)
if err != nil {
return fmt.Errorf("failed to load node: %w", err)
}
// Update the node expiry directly for token refresh
// We don't use HandleNodeFromAuthPath for refresh as the node is already registered
err = a.state.UpdateNodeExpiry(session.NodeID, nodeExpiry)
if err != nil {
return fmt.Errorf("failed to update node expiry: %w", err)
}
// Update the local node object for consistency
node.Expiry = &nodeExpiry
// Update the session with new token information
now := time.Now().UTC()
if newToken.RefreshToken != "" {
session.RefreshToken = newToken.RefreshToken
} else {
log.Debug().
Str("session_id", session.SessionID).
Msg("OIDC: No new refresh token received, keeping existing one")
}
// Store token expiry in UTC to avoid timezone issues
utcExpiry := newToken.Expiry.UTC()
session.TokenExpiry = &utcExpiry
session.LastRefreshedAt = &now
// Save the updated session
err = a.state.SaveOIDCSession(session)
if err != nil {
log.Error().Err(err).
Str("session_id", session.SessionID).
Msg("OIDC: Failed to save updated session to database")
return fmt.Errorf("failed to save updated session: %w", err)
}
return nil
}
// RefreshExpiredTokens checks for and refreshes OIDC sessions that will expire soon
func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
var sessions []types.OIDCSession
// Find active sessions with tokens expiring in the configured threshold time
currentTime := time.Now().UTC()
threshold := currentTime.Add(a.cfg.TokenRefresh.ExpiryThreshold)
// Only refresh tokens for sessions linked to OIDC-registered nodes
sessions, err := a.state.GetOIDCSessionsNeedingRefresh(threshold, "oidc")
if err != nil {
return fmt.Errorf("failed to query sessions with expiring tokens: %w", err)
}
if len(sessions) == 0 {
return nil
}
log.Debug().Msgf("OIDC: Found %d sessions with tokens expiring soon, refreshing...", len(sessions))
for _, session := range sessions {
if err := a.RefreshOIDCSession(ctx, &session); err != nil {
log.Error().Err(err).
Str("session_id", session.SessionID).
Uint("user_id", session.Node.UserID).
Uint64("node_id", uint64(session.NodeID)).
Msg("OIDC: Failed to refresh session, deactivating")
// Deactivate the session if refresh fails
session.Deactivate()
a.state.SaveOIDCSession(&session)
continue
}
}
return nil
}
// extractIDToken extracts the ID token from the oauth2 token.
func (a *AuthProviderOIDC) extractIDToken(
ctx context.Context,
@ -525,7 +688,7 @@ func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
) (*types.NodeID, bool, error) {
node, newNode, err := a.state.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
@ -533,7 +696,7 @@ func (a *AuthProviderOIDC) handleRegistration(
util.RegisterMethodOIDC,
)
if err != nil {
return false, fmt.Errorf("could not register node: %w", err)
return nil, false, fmt.Errorf("could not register node: %w", err)
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
@ -550,7 +713,7 @@ func (a *AuthProviderOIDC) handleRegistration(
routesChanged := a.state.AutoApproveRoutes(node)
_, policyChanged, err := a.state.SaveNode(node)
if err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
return nil, false, fmt.Errorf("saving auto approved routes to node: %w", err)
}
// Send policy update notifications if needed (from SaveNode or route changes)
@ -571,7 +734,7 @@ func (a *AuthProviderOIDC) handleRegistration(
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}
return newNode, nil
return &node.ID, newNode, nil
}
// TODO(kradalby):

529
hscontrol/oidc_test.go Normal file
View File

@ -0,0 +1,529 @@
package hscontrol
import (
"context"
"fmt"
"net/netip"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"tailscale.com/types/key"
)
// createTestNode creates a test node for testing
func createTestNode(t *testing.T, st *state.State, user *types.User, hostname string) *types.Node {
t.Helper()
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
nodeExpiry := time.Now().Add(24 * time.Hour)
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: hostname,
GivenName: hostname,
UserID: user.ID,
RegisterMethod: util.RegisterMethodOIDC,
Expiry: &nodeExpiry,
}
createdNode, _, err := st.CreateNode(node)
require.NoError(t, err)
return createdNode
}
// setupTestState creates a test state with database
func setupTestState(t *testing.T) *state.State {
t.Helper()
tmpDir := t.TempDir()
prefixV4, _ := netip.ParsePrefix("100.64.0.0/10")
prefixV6, _ := netip.ParsePrefix("fd7a:115c:a1e0::/48")
cfg := &types.Config{
Database: types.DatabaseConfig{
Type: types.DatabaseSqlite,
Sqlite: types.SqliteConfig{
Path: tmpDir + "/test.db",
},
},
Policy: types.PolicyConfig{
Mode: types.PolicyModeDB,
},
BaseDomain: "test.local",
PrefixV4: &prefixV4,
PrefixV6: &prefixV6,
IPAllocation: types.IPAllocationStrategySequential,
}
st, err := state.NewState(cfg)
require.NoError(t, err)
return st
}
func TestCreateOrUpdateOIDCSession(t *testing.T) {
st := setupTestState(t)
defer st.Close()
// Create test OIDC provider
oidcProvider := &AuthProviderOIDC{
state: st,
}
// Create test user
user := &types.User{
Name: "testuser",
}
createdUser, _, err := st.CreateUser(*user)
require.NoError(t, err)
// Create test node
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
nodeExpiry := time.Now().Add(24 * time.Hour)
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: "test-node",
GivenName: "test-node",
UserID: createdUser.ID,
RegisterMethod: util.RegisterMethodOIDC,
Expiry: &nodeExpiry,
}
createdNode, _, err := st.CreateNode(node)
require.NoError(t, err)
tests := []struct {
name string
user *types.User
registrationID types.RegistrationID
token *oauth2.Token
nodeExpiry time.Time
expectError bool
expectSession bool
}{
{
name: "create new session with refresh token",
user: user,
registrationID: types.RegistrationID("reg-123"),
token: &oauth2.Token{
AccessToken: "access-token",
RefreshToken: "refresh-token",
Expiry: time.Now().Add(1 * time.Hour),
},
nodeExpiry: time.Now().Add(24 * time.Hour),
expectError: false,
expectSession: true,
},
{
name: "skip session creation without refresh token",
user: user,
registrationID: types.RegistrationID("reg-456"),
token: &oauth2.Token{
AccessToken: "access-token-only",
Expiry: time.Now().Add(1 * time.Hour),
},
nodeExpiry: time.Now().Add(24 * time.Hour),
expectError: false,
expectSession: false,
},
{
name: "update existing session",
user: user,
registrationID: types.RegistrationID("reg-789"),
token: &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: time.Now().Add(2 * time.Hour),
},
nodeExpiry: time.Now().Add(24 * time.Hour),
expectError: false,
expectSession: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := oidcProvider.createOrUpdateOIDCSession(tt.registrationID, tt.token, createdNode.ID)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
if tt.expectSession && tt.token.RefreshToken != "" {
// Verify session was created/updated
session, err := st.GetOIDCSessionByNodeID(createdNode.ID)
assert.NoError(t, err)
assert.Equal(t, tt.token.RefreshToken, session.RefreshToken)
assert.True(t, session.IsActive)
assert.NotNil(t, session.TokenExpiry)
}
})
}
}
func TestRefreshTokenValidation(t *testing.T) {
// Test the basic validation logic that's done before OAuth2 calls
tests := []struct {
name string
session *types.OIDCSession
shouldFail bool
description string
}{
{
name: "session_no_refresh_token",
session: &types.OIDCSession{
SessionID: "no-token-session",
RefreshToken: "",
IsActive: true,
},
shouldFail: true,
description: "Session without refresh token should fail validation",
},
{
name: "session_with_refresh_token",
session: &types.OIDCSession{
SessionID: "valid-session",
RefreshToken: "valid-refresh-token",
IsActive: true,
},
shouldFail: false,
description: "Session with refresh token should pass initial validation",
},
{
name: "session_inactive",
session: &types.OIDCSession{
SessionID: "inactive-session",
RefreshToken: "refresh-token",
IsActive: false,
},
shouldFail: false,
description: "Session validation only checks refresh token presence",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the basic validation logic that checks for refresh token
hasValidRefreshToken := tt.session.RefreshToken != ""
if tt.shouldFail {
assert.False(t, hasValidRefreshToken, tt.description)
} else {
assert.True(t, hasValidRefreshToken, tt.description)
}
})
}
}
func TestRefreshExpiredTokensLogic(t *testing.T) {
// Test the logic that determines which sessions need refresh without DB
now := time.Now()
tests := []struct {
name string
session *types.OIDCSession
shouldRefresh bool
description string
}{
{
name: "session_needs_refresh_expiring_soon",
session: &types.OIDCSession{
SessionID: "session1",
RefreshToken: "refresh1",
TokenExpiry: &[]time.Time{now.Add(3 * time.Minute)}[0],
IsActive: true,
},
shouldRefresh: true,
description: "Active session with token expiring in 3 minutes should need refresh (5 min threshold)",
},
{
name: "session_needs_refresh_already_expired",
session: &types.OIDCSession{
SessionID: "session2",
RefreshToken: "refresh2",
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
IsActive: true,
},
shouldRefresh: true,
description: "Active session with expired token should need refresh",
},
{
name: "session_no_refresh_valid_token",
session: &types.OIDCSession{
SessionID: "session3",
RefreshToken: "refresh3",
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0],
IsActive: true,
},
shouldRefresh: false,
description: "Active session with valid token should not need refresh",
},
{
name: "session_no_refresh_inactive",
session: &types.OIDCSession{
SessionID: "session4",
RefreshToken: "refresh4",
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
IsActive: false,
},
shouldRefresh: false,
description: "Inactive session should not be refreshed even if expired",
},
{
name: "session_no_refresh_no_token",
session: &types.OIDCSession{
SessionID: "session5",
RefreshToken: "",
TokenExpiry: &[]time.Time{now.Add(-1 * time.Hour)}[0],
IsActive: true,
},
shouldRefresh: false,
description: "Session without refresh token should not be refreshed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the logic that determines if a session needs refresh
// This mirrors the logic in RefreshExpiredTokens
threshold := now.Add(5 * time.Minute)
needsRefresh := tt.session.IsActive &&
tt.session.RefreshToken != "" &&
tt.session.TokenExpiry != nil &&
tt.session.TokenExpiry.Before(threshold)
assert.Equal(t, tt.shouldRefresh, needsRefresh, tt.description)
})
}
}
func TestDetermineNodeExpiry(t *testing.T) {
oidcProvider := &AuthProviderOIDC{
cfg: &types.OIDCConfig{
UseExpiryFromToken: true,
Expiry: 180 * 24 * time.Hour, // Default expiry
},
}
now := time.Now()
idTokenExpiry := now.Add(2 * time.Hour)
// Test with UseExpiryFromToken = true
expiry := oidcProvider.determineNodeExpiry(idTokenExpiry)
assert.Equal(t, idTokenExpiry, expiry)
// Test with UseExpiryFromToken = false
oidcProvider.cfg.UseExpiryFromToken = false
expiry = oidcProvider.determineNodeExpiry(idTokenExpiry)
// Should return current time + cfg.Expiry
expectedExpiry := now.Add(oidcProvider.cfg.Expiry)
assert.WithinDuration(t, expectedExpiry, expiry, 1*time.Second)
}
func TestRefreshExpiredTokens(t *testing.T) {
st := setupTestState(t)
defer st.Close()
// Create test OIDC provider
oidcProvider := &AuthProviderOIDC{
state: st,
cfg: &types.OIDCConfig{
Issuer: "https://test.example.com",
ClientID: "test-client-id",
TokenRefresh: types.TokenRefreshConfig{
ExpiryThreshold: 5 * time.Minute,
},
},
}
// Create test user
user := &types.User{
Name: "testuser",
}
createdUser, _, err := st.CreateUser(*user)
require.NoError(t, err)
now := time.Now().UTC()
tests := []struct {
name string
setupSession func() *types.OIDCSession
expectError bool
expectCalled int // How many refresh calls should be attempted
}{
{
name: "no sessions need refresh",
setupSession: func() *types.OIDCSession {
node := createTestNode(t, st, createdUser, "test-node-1")
return &types.OIDCSession{
NodeID: node.ID,
SessionID: "valid-session",
RegistrationID: types.RegistrationID("reg-123"),
RefreshToken: "refresh-token",
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0],
IsActive: true,
}
},
expectError: false,
expectCalled: 0,
},
{
name: "session needs refresh but no refresh token",
setupSession: func() *types.OIDCSession {
node := createTestNode(t, st, createdUser, "test-node-2")
return &types.OIDCSession{
NodeID: node.ID,
SessionID: "no-token-session",
RegistrationID: types.RegistrationID("reg-456"),
RefreshToken: "", // No refresh token
TokenExpiry: &[]time.Time{now.Add(10 * time.Minute)}[0], // Expiring soon
IsActive: true,
}
},
expectError: false,
expectCalled: 0, // Won't be called due to empty refresh token
},
{
name: "valid token should be ignored",
setupSession: func() *types.OIDCSession {
node := createTestNode(t, st, createdUser, "test-node-3")
return &types.OIDCSession{
NodeID: node.ID,
SessionID: "valid-token-session",
RegistrationID: types.RegistrationID("reg-789"),
RefreshToken: "refresh-token",
TokenExpiry: &[]time.Time{now.Add(2 * time.Hour)}[0], // Not expiring soon
IsActive: true, // Active but token not expiring
}
},
expectError: false,
expectCalled: 0, // Won't be called due to valid token
},
{
name: "no sessions at all",
setupSession: func() *types.OIDCSession {
return nil // No session
},
expectError: false,
expectCalled: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup test session if needed
if tt.setupSession != nil {
session := tt.setupSession()
if session != nil {
err := st.CreateOIDCSession(session)
require.NoError(t, err)
}
}
// Call RefreshExpiredTokens
// Note: This will fail when it tries to make OAuth2 calls, but we can test
// the initial logic (finding sessions, filtering them, etc.)
ctx := context.Background()
err = oidcProvider.RefreshExpiredTokens(ctx)
// We expect this to fail if there are sessions that need refresh
// because we don't have a real OAuth2 config, but it should succeed
// if no sessions need refresh
if tt.expectCalled == 0 {
// Should succeed when no sessions need refreshing
assert.NoError(t, err)
} else {
// Will fail due to OAuth2 config being nil, but that's expected
// The important thing is that it found the sessions that need refresh
assert.Error(t, err)
}
})
}
}
func TestRefreshOIDCSessionValidation(t *testing.T) {
st := setupTestState(t)
defer st.Close()
// Create test OIDC provider
oidcProvider := &AuthProviderOIDC{
state: st,
cfg: &types.OIDCConfig{
Issuer: "https://test.example.com",
ClientID: "test-client-id",
},
}
tests := []struct {
name string
session *types.OIDCSession
expectError bool
errorMsg string
}{
{
name: "session without refresh token should fail",
session: &types.OIDCSession{
SessionID: "no-token-session",
RefreshToken: "", // No refresh token
IsActive: true,
},
expectError: true,
errorMsg: "no refresh token available",
},
{
name: "session with refresh token should fail due to no OAuth2 config",
session: &types.OIDCSession{
SessionID: "valid-session",
RefreshToken: "valid-refresh-token",
IsActive: true,
},
expectError: true, // Will fail due to OAuth2 config being nil
errorMsg: "failed to refresh OIDC token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
// Test validation - catch panics from OAuth2 config being nil
defer func() {
if r := recover(); r != nil {
// OAuth2 config being nil causes a panic, which means we got past
// the refresh token validation - that's expected for the second test
if tt.session.RefreshToken != "" {
// This means the function got to the OAuth2 call part, which is expected
// The panic indicates we successfully passed the refresh token validation
assert.Contains(t, fmt.Sprintf("%v", r), "nil pointer")
} else {
// Shouldn't panic for missing refresh token
t.Errorf("Unexpected panic for empty refresh token: %v", r)
}
}
}()
err := oidcProvider.RefreshOIDCSession(ctx, tt.session)
// If we get here, it means no panic occurred (good for empty refresh token test)
if err != nil {
assert.Contains(t, err.Error(), tt.errorMsg)
}
})
}
}

View File

@ -835,3 +835,86 @@ func (s *State) autoApproveNodes() error {
return nil
}
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline
// for longer than the configured grace period.
func (s *State) InvalidateExpiredOIDCSessions(offlineGracePeriod time.Duration) error {
return s.db.InvalidateExpiredOIDCSessions(offlineGracePeriod)
}
// GetOIDCSessionByNodeID retrieves an OIDC session by node ID.
func (s *State) GetOIDCSessionByNodeID(nodeID types.NodeID) (*types.OIDCSession, error) {
var session types.OIDCSession
err := s.db.Read(func(tx *gorm.DB) error {
return tx.Where("node_id = ?", nodeID).First(&session).Error
})
if err != nil {
return nil, err
}
return &session, nil
}
// SaveOIDCSession saves or updates an OIDC session.
func (s *State) SaveOIDCSession(session *types.OIDCSession) error {
return s.db.Write(func(tx *gorm.DB) error {
return tx.Save(session).Error
})
}
// CreateOIDCSession creates a new OIDC session.
func (s *State) CreateOIDCSession(session *types.OIDCSession) error {
return s.db.Write(func(tx *gorm.DB) error {
return tx.Create(session).Error
})
}
// UpdateNodeExpiry updates a node's expiry time.
func (s *State) UpdateNodeExpiry(nodeID types.NodeID, expiry time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.db.Write(func(tx *gorm.DB) error {
result := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry)
if result.Error != nil {
return result.Error
}
// Update in-memory state
for i := range s.nodes {
if s.nodes[i].ID == nodeID {
s.nodes[i].Expiry = &expiry
break
}
}
return nil
})
}
// GetNodeWithUser retrieves a node with its associated user.
func (s *State) GetNodeWithUser(nodeID types.NodeID) (*types.Node, error) {
var node types.Node
err := s.db.Read(func(tx *gorm.DB) error {
return tx.Preload("User").First(&node, nodeID).Error
})
if err != nil {
return nil, err
}
return &node, nil
}
// GetOIDCSessionsNeedingRefresh retrieves OIDC sessions that need token refresh.
func (s *State) GetOIDCSessionsNeedingRefresh(threshold time.Time, registerMethod string) ([]types.OIDCSession, error) {
var sessions []types.OIDCSession
err := s.db.Read(func(tx *gorm.DB) error {
return tx.Preload("Node").Preload("Node.User").
Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?",
true, threshold, registerMethod).
Find(&sessions).Error
})
if err != nil {
return nil, err
}
return sessions, nil
}

View File

@ -172,6 +172,12 @@ type PKCEConfig struct {
Method string
}
type TokenRefreshConfig struct {
CheckInterval time.Duration
ExpiryThreshold time.Duration
SessionInvalidationGracePeriod time.Duration
}
type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool
Issuer string
@ -184,6 +190,7 @@ type OIDCConfig struct {
AllowedGroups []string
Expiry time.Duration
UseExpiryFromToken bool
TokenRefresh TokenRefreshConfig
PKCE PKCEConfig
}
@ -320,6 +327,9 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.token_refresh.check_interval", "15m")
viper.SetDefault("oidc.token_refresh.expiry_threshold", "30m")
viper.SetDefault("oidc.token_refresh.session_invalidation_grace_period", "30m")
viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256")
@ -964,6 +974,11 @@ func LoadServerConfig() (*Config, error) {
}
}(),
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
TokenRefresh: TokenRefreshConfig{
CheckInterval: viper.GetDuration("oidc.token_refresh.check_interval"),
ExpiryThreshold: viper.GetDuration("oidc.token_refresh.expiry_threshold"),
SessionInvalidationGracePeriod: viper.GetDuration("oidc.token_refresh.session_invalidation_grace_period"),
},
PKCE: PKCEConfig{
Enabled: viper.GetBool("oidc.pkce.enabled"),
Method: viper.GetString("oidc.pkce.method"),

View File

@ -0,0 +1,56 @@
package types
import (
"time"
"gorm.io/gorm"
)
// OIDCSession represents an OIDC authentication session linked to a specific node
type OIDCSession struct {
gorm.Model
// Core relationships
NodeID NodeID `gorm:"not null;uniqueIndex"`
Node Node `gorm:"constraint:OnDelete:CASCADE;"`
// Session identification
SessionID string `gorm:"uniqueIndex;not null"`
RegistrationID RegistrationID `gorm:"not null"` // For reusing HandleNodeFromAuthPath
// Token data
RefreshToken string `gorm:"type:text"` // TODO: Encrypt?
// Token lifecycle
TokenExpiry *time.Time `gorm:"index"`
LastRefreshedAt *time.Time
// Session state
IsActive bool `gorm:"default:true;index"`
LastSeenAt *time.Time
}
func (s *OIDCSession) TableName() string {
return "oidc_sessions"
}
// IsExpired checks if the session's token has expired
func (s *OIDCSession) IsExpired() bool {
return s.TokenExpiry != nil && s.TokenExpiry.Before(time.Now())
}
// IsExpiringSoon checks if the session's token will expire within the given duration
func (s *OIDCSession) IsExpiringSoon(duration time.Duration) bool {
return s.TokenExpiry != nil && s.TokenExpiry.Before(time.Now().Add(duration))
}
// Deactivate marks the session as inactive
func (s *OIDCSession) Deactivate() {
s.IsActive = false
}
// UpdateLastSeen updates the last seen timestamp
func (s *OIDCSession) UpdateLastSeen() {
now := time.Now()
s.LastSeenAt = &now
}