1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-08-24 13:46:53 +02:00

few improvements: clean database table, more cfg, exact db queries by nodeID

This commit is contained in:
Mazlum Toprak 2025-07-24 11:14:39 +02:00
parent fd8bd3f6a6
commit 46816c8a1c
7 changed files with 62 additions and 60 deletions

View File

@ -355,10 +355,18 @@ unix_socket_permission: "0770"
# # Note: enabling this will cause `oidc.expiry` to be ignored. # # Note: enabling this will cause `oidc.expiry` to be ignored.
# use_expiry_from_token: false # use_expiry_from_token: false
# #
# # Grace period for invalidating sessions of nodes that have been offline # # Token refresh and session management configuration
# # Sessions for nodes offline longer than this duration will be invalidated (new SSO login required). # token_refresh:
# # Default: 30m # # How often to check for tokens needing refresh (default: 15m)
# session_invalidation_grace_period: 30m # check_interval: 15m
#
# # Refresh tokens this far before expiry (default: 30m)
# expiry_threshold: 30m
#
# # Grace period for invalidating sessions of nodes that have been offline
# # Sessions for nodes offline longer than this duration will be invalidated (new SSO login required).
# # (default: 30m)
# session_invalidation_grace_period: 30m
# #
# # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query # # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". # # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".

View File

@ -303,12 +303,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
} }
func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthProviderOIDC) { func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthProviderOIDC) {
refreshTicker := time.NewTicker(15 * time.Minute) checkInterval := oidcProvider.cfg.TokenRefresh.CheckInterval
gracePeriodTicker := time.NewTicker(15 * time.Minute) refreshTicker := time.NewTicker(checkInterval)
gracePeriodTicker := time.NewTicker(checkInterval)
defer refreshTicker.Stop() defer refreshTicker.Stop()
defer gracePeriodTicker.Stop() defer gracePeriodTicker.Stop()
log.Info().Msg("OIDC: Background token refresh job started (checking every 15 minute for tokens expiring within 30 minutes)") log.Info().Msgf("OIDC: Background token refresh job started (checking every %v for tokens expiring within %v)",
checkInterval, oidcProvider.cfg.TokenRefresh.ExpiryThreshold)
for { for {
select { select {
@ -327,7 +329,7 @@ func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthP
// Invalidate sessions for nodes that have been offline for longer than the configured grace period // Invalidate sessions for nodes that have been offline for longer than the configured grace period
case <-gracePeriodTicker.C: case <-gracePeriodTicker.C:
log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period") log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period")
gracePeriod := oidcProvider.cfg.SessionInvalidationGracePeriod gracePeriod := oidcProvider.cfg.TokenRefresh.SessionInvalidationGracePeriod
if err := h.db.InvalidateExpiredOIDCSessions(gracePeriod); err != nil { if err := h.db.InvalidateExpiredOIDCSessions(gracePeriod); err != nil {
log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes") log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes")
} }

View File

@ -38,7 +38,6 @@ func (*Suite) TestInvalidateOIDCSessionsForNode(c *check.C) {
registrationID := types.RegistrationID("test-reg-id-1") registrationID := types.RegistrationID("test-reg-id-1")
tokenExpiry := time.Now().Add(1 * time.Hour) tokenExpiry := time.Now().Add(1 * time.Hour)
session := &types.OIDCSession{ session := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: sessionID, SessionID: sessionID,
RegistrationID: registrationID, RegistrationID: registrationID,
@ -113,7 +112,6 @@ func (*Suite) TestInvalidateExpiredOIDCSessions(c *check.C) {
node1.LastSeen = &lastSeen1 node1.LastSeen = &lastSeen1
db.DB.Save(node1) db.DB.Save(node1)
session1 := &types.OIDCSession{ session1 := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node1.ID), NodeID: types.NodeID(node1.ID),
SessionID: "expired-session-1", SessionID: "expired-session-1",
RegistrationID: types.RegistrationID("reg-1"), RegistrationID: types.RegistrationID("reg-1"),
@ -129,7 +127,6 @@ func (*Suite) TestInvalidateExpiredOIDCSessions(c *check.C) {
node2.LastSeen = &lastSeen2 node2.LastSeen = &lastSeen2
db.DB.Save(node2) db.DB.Save(node2)
session2 := &types.OIDCSession{ session2 := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node2.ID), NodeID: types.NodeID(node2.ID),
SessionID: "expired-session-2", SessionID: "expired-session-2",
RegistrationID: types.RegistrationID("reg-2"), RegistrationID: types.RegistrationID("reg-2"),
@ -160,7 +157,6 @@ func (*Suite) TestInvalidateExpiredOIDCSessions(c *check.C) {
// Session 3: Valid token // Session 3: Valid token
validTime := now.Add(1 * time.Hour) validTime := now.Add(1 * time.Hour)
session3 := &types.OIDCSession{ session3 := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node3.ID), NodeID: types.NodeID(node3.ID),
SessionID: "valid-session", SessionID: "valid-session",
RegistrationID: types.RegistrationID("reg-3"), RegistrationID: types.RegistrationID("reg-3"),
@ -224,7 +220,6 @@ func (*Suite) TestInvalidateExpiredOIDCSessionsWithNoExpired(c *check.C) {
// Create only valid sessions // Create only valid sessions
validTime := time.Now().Add(24 * time.Hour) validTime := time.Now().Add(24 * time.Hour)
session := &types.OIDCSession{ session := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: "valid-only-session", SessionID: "valid-only-session",
RegistrationID: types.RegistrationID("reg-valid"), RegistrationID: types.RegistrationID("reg-valid"),
@ -275,7 +270,6 @@ func (*Suite) TestInvalidateOIDCSessionsTransaction(c *check.C) {
tokenExpiry := time.Now().Add(1 * time.Hour) tokenExpiry := time.Now().Add(1 * time.Hour)
sessionID := fmt.Sprintf("session-%d", i) sessionID := fmt.Sprintf("session-%d", i)
session := &types.OIDCSession{ session := &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: sessionID, SessionID: sessionID,
RegistrationID: types.RegistrationID(fmt.Sprintf("reg-%d", i)), RegistrationID: types.RegistrationID(fmt.Sprintf("reg-%d", i)),

View File

@ -331,7 +331,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist. // Register the node if it does not exist.
if registrationId != nil { if registrationId != nil {
verb := "Reauthenticated" verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) nodeID, newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil { if err != nil {
httpError(writer, err) httpError(writer, err)
return return
@ -342,7 +342,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
} }
// Create or update OIDC session for this node // Create or update OIDC session for this node
if err := a.createOrUpdateOIDCSession(user, *registrationId, oauth2Token, nodeExpiry); err != nil { if err := a.createOrUpdateOIDCSession(*registrationId, oauth2Token, *nodeID); err != nil {
log.Error().Err(err).Msg("Failed to create OIDC session") log.Error().Err(err).Msg("Failed to create OIDC session")
// Don't fail the auth flow, just log the error // Don't fail the auth flow, just log the error
} }
@ -366,8 +366,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning // Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node. // that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
} }
func extractCodeAndStateParamFromRequest( func extractCodeAndStateParamFromRequest(
@ -410,23 +408,16 @@ func (a *AuthProviderOIDC) getOauth2Token(
} }
// createOrUpdateOIDCSession creates or updates an OIDC session for a node // createOrUpdateOIDCSession creates or updates an OIDC session for a node
func (a *AuthProviderOIDC) createOrUpdateOIDCSession(user *types.User, registrationID types.RegistrationID, token *oauth2.Token, nodeExpiry time.Time) error { func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.RegistrationID, token *oauth2.Token, nodeID types.NodeID) error {
if token.RefreshToken == "" { if token.RefreshToken == "" {
log.Warn(). log.Warn().
Str("user", user.Username()). Str("node_id", nodeID.String()).
Str("registration_id", registrationID.String()). Str("registration_id", registrationID.String()).
Msg("OIDC: No refresh token in OAuth2 token, skipping session creation (check offline_access scope)") Msg("OIDC: No refresh token in OAuth2 token, skipping session creation (check offline_access scope)")
return nil return nil
} }
// Find the node in the database - we'll get the most recent OIDC node for this user
var node types.Node
err := a.db.DB.Where("user_id = ? AND register_method = ?", user.ID, util.RegisterMethodOIDC).Order("created_at DESC").First(&node).Error
if err != nil {
return fmt.Errorf("failed to find node: %w", err)
}
// Generate session ID // Generate session ID
sessionID, err := util.GenerateRandomStringURLSafe(32) sessionID, err := util.GenerateRandomStringURLSafe(32)
if err != nil { if err != nil {
@ -436,8 +427,7 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(user *types.User, registrat
// Create or update session // Create or update session
tokenExpiryUTC := token.Expiry.UTC() tokenExpiryUTC := token.Expiry.UTC()
session := &types.OIDCSession{ session := &types.OIDCSession{
UserID: user.ID, NodeID: nodeID,
NodeID: node.ID,
SessionID: sessionID, SessionID: sessionID,
RegistrationID: registrationID, RegistrationID: registrationID,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
@ -451,7 +441,7 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(user *types.User, registrat
// Try to update existing session first // Try to update existing session first
var existingSession types.OIDCSession var existingSession types.OIDCSession
err = a.db.DB.Where("user_id = ? AND node_id = ?", user.ID, node.ID).First(&existingSession).Error err = a.db.DB.Where("node_id = ?", nodeID).First(&existingSession).Error
if err == nil { if err == nil {
// Update existing session // Update existing session
existingSession.RefreshToken = token.RefreshToken existingSession.RefreshToken = token.RefreshToken
@ -549,12 +539,13 @@ func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *type
func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error { func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
var sessions []types.OIDCSession var sessions []types.OIDCSession
// Find active sessions with tokens expiring in the next 30 minutes // Find active sessions with tokens expiring in the configured threshold time
currentTime := time.Now().UTC() currentTime := time.Now().UTC()
threshold := currentTime.Add(30 * time.Minute) threshold := currentTime.Add(a.cfg.TokenRefresh.ExpiryThreshold)
// Only refresh tokens for sessions linked to OIDC-registered nodes // Only refresh tokens for sessions linked to OIDC-registered nodes
err := a.db.DB.Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id"). err := a.db.DB.Preload("Node").Preload("Node.User").
Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?", Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?",
true, threshold, "oidc"). true, threshold, "oidc").
Find(&sessions).Error Find(&sessions).Error
@ -572,7 +563,7 @@ func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
if err := a.RefreshOIDCSession(ctx, &session); err != nil { if err := a.RefreshOIDCSession(ctx, &session); err != nil {
log.Error().Err(err). log.Error().Err(err).
Str("session_id", session.SessionID). Str("session_id", session.SessionID).
Uint("user_id", session.UserID). Uint("user_id", session.Node.UserID).
Uint64("node_id", uint64(session.NodeID)). Uint64("node_id", uint64(session.NodeID)).
Msg("OIDC: Failed to refresh session, deactivating") Msg("OIDC: Failed to refresh session, deactivating")
// Deactivate the session if refresh fails // Deactivate the session if refresh fails
@ -707,7 +698,7 @@ func (a *AuthProviderOIDC) handleRegistration(
user *types.User, user *types.User,
registrationID types.RegistrationID, registrationID types.RegistrationID,
expiry time.Time, expiry time.Time,
) (bool, error) { ) (*types.NodeID, bool, error) {
node, newNode, err := a.state.HandleNodeFromAuthPath( node, newNode, err := a.state.HandleNodeFromAuthPath(
registrationID, registrationID,
types.UserID(user.ID), types.UserID(user.ID),
@ -715,7 +706,7 @@ func (a *AuthProviderOIDC) handleRegistration(
util.RegisterMethodOIDC, util.RegisterMethodOIDC,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("could not register node: %w", err) return nil, false, fmt.Errorf("could not register node: %w", err)
} }
// This is a bit of a back and forth, but we have a bit of a chicken and egg // This is a bit of a back and forth, but we have a bit of a chicken and egg
@ -732,7 +723,7 @@ func (a *AuthProviderOIDC) handleRegistration(
routesChanged := a.state.AutoApproveRoutes(node) routesChanged := a.state.AutoApproveRoutes(node)
_, policyChanged, err := a.state.SaveNode(node) _, policyChanged, err := a.state.SaveNode(node)
if err != nil { if err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err) return nil, false, fmt.Errorf("saving auto approved routes to node: %w", err)
} }
// Send policy update notifications if needed (from SaveNode or route changes) // Send policy update notifications if needed (from SaveNode or route changes)
@ -753,7 +744,7 @@ func (a *AuthProviderOIDC) handleRegistration(
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
} }
return newNode, nil return &node.ID, newNode, nil
} }
// TODO(kradalby): // TODO(kradalby):

View File

@ -153,7 +153,7 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := oidcProvider.createOrUpdateOIDCSession(tt.user, tt.registrationID, tt.token, tt.nodeExpiry) err := oidcProvider.createOrUpdateOIDCSession(tt.user, tt.registrationID, tt.token, node.ID)
if tt.expectError { if tt.expectError {
assert.Error(t, err) assert.Error(t, err)
@ -366,7 +366,6 @@ func TestRefreshExpiredTokens(t *testing.T) {
setupSession: func() *types.OIDCSession { setupSession: func() *types.OIDCSession {
node := createTestNode(t, hsdb, user, "test-node-1") node := createTestNode(t, hsdb, user, "test-node-1")
return &types.OIDCSession{ return &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: "valid-session", SessionID: "valid-session",
RegistrationID: types.RegistrationID("reg-123"), RegistrationID: types.RegistrationID("reg-123"),
@ -383,7 +382,6 @@ func TestRefreshExpiredTokens(t *testing.T) {
setupSession: func() *types.OIDCSession { setupSession: func() *types.OIDCSession {
node := createTestNode(t, hsdb, user, "test-node-2") node := createTestNode(t, hsdb, user, "test-node-2")
return &types.OIDCSession{ return &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: "no-token-session", SessionID: "no-token-session",
RegistrationID: types.RegistrationID("reg-456"), RegistrationID: types.RegistrationID("reg-456"),
@ -400,7 +398,6 @@ func TestRefreshExpiredTokens(t *testing.T) {
setupSession: func() *types.OIDCSession { setupSession: func() *types.OIDCSession {
node := createTestNode(t, hsdb, user, "test-node-3") node := createTestNode(t, hsdb, user, "test-node-3")
return &types.OIDCSession{ return &types.OIDCSession{
UserID: user.ID,
NodeID: types.NodeID(node.ID), NodeID: types.NodeID(node.ID),
SessionID: "valid-token-session", SessionID: "valid-token-session",
RegistrationID: types.RegistrationID("reg-789"), RegistrationID: types.RegistrationID("reg-789"),

View File

@ -172,20 +172,26 @@ type PKCEConfig struct {
Method string Method string
} }
type OIDCConfig struct { type TokenRefreshConfig struct {
OnlyStartIfOIDCIsAvailable bool CheckInterval time.Duration
Issuer string ExpiryThreshold time.Duration
ClientID string
ClientSecret string
Scope []string
ExtraParams map[string]string
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
Expiry time.Duration
UseExpiryFromToken bool
SessionInvalidationGracePeriod time.Duration SessionInvalidationGracePeriod time.Duration
PKCE PKCEConfig }
type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool
Issuer string
ClientID string
ClientSecret string
Scope []string
ExtraParams map[string]string
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
Expiry time.Duration
UseExpiryFromToken bool
TokenRefresh TokenRefreshConfig
PKCE PKCEConfig
} }
type DERPConfig struct { type DERPConfig struct {
@ -321,7 +327,9 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.session_invalidation_grace_period", "30m") viper.SetDefault("oidc.token_refresh.check_interval", "15m")
viper.SetDefault("oidc.token_refresh.expiry_threshold", "30m")
viper.SetDefault("oidc.token_refresh.session_invalidation_grace_period", "30m")
viper.SetDefault("oidc.pkce.enabled", false) viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256") viper.SetDefault("oidc.pkce.method", "S256")
@ -965,8 +973,12 @@ func LoadServerConfig() (*Config, error) {
return time.Duration(expiry) return time.Duration(expiry)
} }
}(), }(),
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
SessionInvalidationGracePeriod: viper.GetDuration("oidc.session_invalidation_grace_period"), TokenRefresh: TokenRefreshConfig{
CheckInterval: viper.GetDuration("oidc.token_refresh.check_interval"),
ExpiryThreshold: viper.GetDuration("oidc.token_refresh.expiry_threshold"),
SessionInvalidationGracePeriod: viper.GetDuration("oidc.token_refresh.session_invalidation_grace_period"),
},
PKCE: PKCEConfig{ PKCE: PKCEConfig{
Enabled: viper.GetBool("oidc.pkce.enabled"), Enabled: viper.GetBool("oidc.pkce.enabled"),
Method: viper.GetString("oidc.pkce.method"), Method: viper.GetString("oidc.pkce.method"),

View File

@ -11,8 +11,6 @@ type OIDCSession struct {
gorm.Model gorm.Model
// Core relationships // Core relationships
UserID uint `gorm:"not null;index"`
User User `gorm:"constraint:OnDelete:CASCADE;"`
NodeID NodeID `gorm:"not null;uniqueIndex"` NodeID NodeID `gorm:"not null;uniqueIndex"`
Node Node `gorm:"constraint:OnDelete:CASCADE;"` Node Node `gorm:"constraint:OnDelete:CASCADE;"`