mirror of
https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
rebase and refactor to new state mgmt
This commit is contained in:
parent
da57cf4987
commit
a0007a79b4
@ -330,7 +330,7 @@ func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthP
|
|||||||
case <-gracePeriodTicker.C:
|
case <-gracePeriodTicker.C:
|
||||||
log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period")
|
log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period")
|
||||||
gracePeriod := oidcProvider.cfg.TokenRefresh.SessionInvalidationGracePeriod
|
gracePeriod := oidcProvider.cfg.TokenRefresh.SessionInvalidationGracePeriod
|
||||||
if err := h.db.InvalidateExpiredOIDCSessions(gracePeriod); err != nil {
|
if err := h.state.InvalidateExpiredOIDCSessions(gracePeriod); err != nil {
|
||||||
log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes")
|
log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -941,10 +941,66 @@ AND auth_key_id NOT IN (
|
|||||||
// Create OIDC sessions table for managing OIDC refresh tokens
|
// Create OIDC sessions table for managing OIDC refresh tokens
|
||||||
// This replaces the old OIDC token columns in the users table
|
// This replaces the old OIDC token columns in the users table
|
||||||
if !tx.Migrator().HasTable(&types.OIDCSession{}) {
|
if !tx.Migrator().HasTable(&types.OIDCSession{}) {
|
||||||
err := tx.AutoMigrate(&types.OIDCSession{})
|
// Create the table with database-specific SQL
|
||||||
if err != nil {
|
var sql string
|
||||||
|
if cfg.Type == types.DatabasePostgres {
|
||||||
|
sql = `
|
||||||
|
CREATE TABLE oidc_sessions (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
node_id INTEGER NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
registration_id INTEGER NOT NULL,
|
||||||
|
refresh_token TEXT,
|
||||||
|
token_expiry TIMESTAMP,
|
||||||
|
last_refreshed_at TIMESTAMP,
|
||||||
|
is_active BOOLEAN DEFAULT true,
|
||||||
|
last_seen_at TIMESTAMP,
|
||||||
|
created_at TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP,
|
||||||
|
deleted_at TIMESTAMP,
|
||||||
|
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
`
|
||||||
|
} else {
|
||||||
|
sql = `
|
||||||
|
CREATE TABLE oidc_sessions (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
node_id INTEGER NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
registration_id INTEGER NOT NULL,
|
||||||
|
refresh_token TEXT,
|
||||||
|
token_expiry DATETIME,
|
||||||
|
last_refreshed_at DATETIME,
|
||||||
|
is_active NUMERIC DEFAULT true,
|
||||||
|
last_seen_at DATETIME,
|
||||||
|
created_at DATETIME,
|
||||||
|
updated_at DATETIME,
|
||||||
|
deleted_at DATETIME,
|
||||||
|
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
`
|
||||||
|
}
|
||||||
|
if err := tx.Exec(sql).Error; err != nil {
|
||||||
return fmt.Errorf("creating OIDC sessions table: %w", err)
|
return fmt.Errorf("creating OIDC sessions table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create indexes matching the struct tags
|
||||||
|
if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id)").Error; err != nil {
|
||||||
|
return fmt.Errorf("creating node_id unique index: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id)").Error; err != nil {
|
||||||
|
return fmt.Errorf("creating session_id unique index: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry)").Error; err != nil {
|
||||||
|
return fmt.Errorf("creating token_expiry index: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active)").Error; err != nil {
|
||||||
|
return fmt.Errorf("creating is_active index: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Exec("CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at)").Error; err != nil {
|
||||||
|
return fmt.Errorf("creating deleted_at index: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Created OIDC sessions table")
|
log.Debug().Msg("Created OIDC sessions table")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,3 +108,24 @@ CREATE TABLE policies(
|
|||||||
deleted_at datetime
|
deleted_at datetime
|
||||||
);
|
);
|
||||||
CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
|
CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
|
||||||
|
|
||||||
|
CREATE TABLE oidc_sessions(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
node_id INTEGER NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
registration_id INTEGER NOT NULL,
|
||||||
|
refresh_token TEXT,
|
||||||
|
token_expiry DATETIME,
|
||||||
|
last_refreshed_at DATETIME,
|
||||||
|
is_active NUMERIC DEFAULT true,
|
||||||
|
last_seen_at DATETIME,
|
||||||
|
created_at DATETIME,
|
||||||
|
updated_at DATETIME,
|
||||||
|
deleted_at DATETIME,
|
||||||
|
CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id);
|
||||||
|
CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id);
|
||||||
|
CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry);
|
||||||
|
CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active);
|
||||||
|
CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at);
|
||||||
|
@ -439,8 +439,7 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.Regist
|
|||||||
session.LastSeenAt = &now
|
session.LastSeenAt = &now
|
||||||
|
|
||||||
// Try to update existing session first
|
// Try to update existing session first
|
||||||
var existingSession types.OIDCSession
|
existingSession, err := a.state.GetOIDCSessionByNodeID(nodeID)
|
||||||
err = a.db.DB.Where("node_id = ?", nodeID).First(&existingSession).Error
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Update existing session
|
// Update existing session
|
||||||
existingSession.RefreshToken = token.RefreshToken
|
existingSession.RefreshToken = token.RefreshToken
|
||||||
@ -449,16 +448,15 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.Regist
|
|||||||
existingSession.LastRefreshedAt = &now
|
existingSession.LastRefreshedAt = &now
|
||||||
existingSession.LastSeenAt = &now
|
existingSession.LastSeenAt = &now
|
||||||
existingSession.IsActive = true
|
existingSession.IsActive = true
|
||||||
existingSession.RefreshCount = existingSession.RefreshCount + 1
|
|
||||||
|
|
||||||
err = a.db.DB.Save(&existingSession).Error
|
err = a.state.SaveOIDCSession(existingSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update OIDC session: %w", err)
|
return fmt.Errorf("failed to update OIDC session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Create new session
|
// Create new session
|
||||||
err = a.db.DB.Create(session).Error
|
err = a.state.CreateOIDCSession(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create OIDC session: %w", err)
|
return fmt.Errorf("failed to create OIDC session: %w", err)
|
||||||
}
|
}
|
||||||
@ -488,15 +486,14 @@ func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *type
|
|||||||
nodeExpiry := a.determineNodeExpiry(newToken.Expiry)
|
nodeExpiry := a.determineNodeExpiry(newToken.Expiry)
|
||||||
|
|
||||||
// Load the node for logging
|
// Load the node for logging
|
||||||
var node types.Node
|
node, err := a.state.GetNodeWithUser(session.NodeID)
|
||||||
err = a.db.DB.Preload("User").First(&node, session.NodeID).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load node: %w", err)
|
return fmt.Errorf("failed to load node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the node expiry directly for token refresh
|
// Update the node expiry directly for token refresh
|
||||||
// We don't use HandleNodeFromAuthPath for refresh as the node is already registered
|
// We don't use HandleNodeFromAuthPath for refresh as the node is already registered
|
||||||
err = a.db.DB.Model(&node).Update("expiry", nodeExpiry).Error
|
err = a.state.UpdateNodeExpiry(session.NodeID, nodeExpiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update node expiry: %w", err)
|
return fmt.Errorf("failed to update node expiry: %w", err)
|
||||||
}
|
}
|
||||||
@ -519,10 +516,9 @@ func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *type
|
|||||||
utcExpiry := newToken.Expiry.UTC()
|
utcExpiry := newToken.Expiry.UTC()
|
||||||
session.TokenExpiry = &utcExpiry
|
session.TokenExpiry = &utcExpiry
|
||||||
session.LastRefreshedAt = &now
|
session.LastRefreshedAt = &now
|
||||||
session.RefreshCount = session.RefreshCount + 1
|
|
||||||
|
|
||||||
// Save the updated session
|
// Save the updated session
|
||||||
err = a.db.DB.Save(session).Error
|
err = a.state.SaveOIDCSession(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).
|
log.Error().Err(err).
|
||||||
Str("session_id", session.SessionID).
|
Str("session_id", session.SessionID).
|
||||||
@ -542,11 +538,7 @@ func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
|
|||||||
threshold := currentTime.Add(a.cfg.TokenRefresh.ExpiryThreshold)
|
threshold := currentTime.Add(a.cfg.TokenRefresh.ExpiryThreshold)
|
||||||
|
|
||||||
// Only refresh tokens for sessions linked to OIDC-registered nodes
|
// Only refresh tokens for sessions linked to OIDC-registered nodes
|
||||||
err := a.db.DB.Preload("Node").Preload("Node.User").
|
sessions, err := a.state.GetOIDCSessionsNeedingRefresh(threshold, "oidc")
|
||||||
Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
|
|
||||||
Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?",
|
|
||||||
true, threshold, "oidc").
|
|
||||||
Find(&sessions).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query sessions with expiring tokens: %w", err)
|
return fmt.Errorf("failed to query sessions with expiring tokens: %w", err)
|
||||||
}
|
}
|
||||||
@ -566,7 +558,7 @@ func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error {
|
|||||||
Msg("OIDC: Failed to refresh session, deactivating")
|
Msg("OIDC: Failed to refresh session, deactivating")
|
||||||
// Deactivate the session if refresh fails
|
// Deactivate the session if refresh fails
|
||||||
session.Deactivate()
|
session.Deactivate()
|
||||||
a.db.DB.Save(&session)
|
a.state.SaveOIDCSession(&session)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,26 +3,21 @@ package hscontrol
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
zcache "zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
|
||||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
|
||||||
}
|
|
||||||
|
|
||||||
// createTestNode creates a test node for testing
|
// createTestNode creates a test node for testing
|
||||||
func createTestNode(t *testing.T, hsdb *db.HSDatabase, user *types.User, hostname string) *types.Node {
|
func createTestNode(t *testing.T, st *state.State, user *types.User, hostname string) *types.Node {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
@ -41,47 +36,57 @@ func createTestNode(t *testing.T, hsdb *db.HSDatabase, user *types.User, hostnam
|
|||||||
Expiry: &nodeExpiry,
|
Expiry: &nodeExpiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := hsdb.DB.Create(node).Error
|
createdNode, _, err := st.CreateNode(node)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return node
|
return createdNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupTestDB creates a test database
|
// setupTestState creates a test state with database
|
||||||
func setupTestDB(t *testing.T) *db.HSDatabase {
|
func setupTestState(t *testing.T) *state.State {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
hsdb, err := db.NewHeadscaleDatabase(
|
prefixV4, _ := netip.ParsePrefix("100.64.0.0/10")
|
||||||
types.DatabaseConfig{
|
prefixV6, _ := netip.ParsePrefix("fd7a:115c:a1e0::/48")
|
||||||
|
|
||||||
|
cfg := &types.Config{
|
||||||
|
Database: types.DatabaseConfig{
|
||||||
Type: types.DatabaseSqlite,
|
Type: types.DatabaseSqlite,
|
||||||
Sqlite: types.SqliteConfig{
|
Sqlite: types.SqliteConfig{
|
||||||
Path: tmpDir + "/test.db",
|
Path: tmpDir + "/test.db",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"",
|
Policy: types.PolicyConfig{
|
||||||
emptyCache(),
|
Mode: types.PolicyModeDB,
|
||||||
)
|
},
|
||||||
|
BaseDomain: "test.local",
|
||||||
|
PrefixV4: &prefixV4,
|
||||||
|
PrefixV6: &prefixV6,
|
||||||
|
IPAllocation: types.IPAllocationStrategySequential,
|
||||||
|
}
|
||||||
|
|
||||||
|
st, err := state.NewState(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return hsdb
|
return st
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
||||||
hsdb := setupTestDB(t)
|
st := setupTestState(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
// Create test OIDC provider
|
// Create test OIDC provider
|
||||||
oidcProvider := &AuthProviderOIDC{
|
oidcProvider := &AuthProviderOIDC{
|
||||||
db: hsdb,
|
state: st,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test user
|
// Create test user
|
||||||
user := &types.User{
|
user := &types.User{
|
||||||
Model: gorm.Model{ID: 1},
|
Name: "testuser",
|
||||||
Name: "testuser",
|
|
||||||
}
|
}
|
||||||
err := hsdb.DB.Create(user).Error
|
createdUser, _, err := st.CreateUser(*user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create test node
|
// Create test node
|
||||||
@ -95,11 +100,11 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
|||||||
DiscoKey: discoKey.Public(),
|
DiscoKey: discoKey.Public(),
|
||||||
Hostname: "test-node",
|
Hostname: "test-node",
|
||||||
GivenName: "test-node",
|
GivenName: "test-node",
|
||||||
UserID: user.ID,
|
UserID: createdUser.ID,
|
||||||
RegisterMethod: util.RegisterMethodOIDC,
|
RegisterMethod: util.RegisterMethodOIDC,
|
||||||
Expiry: &nodeExpiry,
|
Expiry: &nodeExpiry,
|
||||||
}
|
}
|
||||||
err = hsdb.DB.Create(node).Error
|
createdNode, _, err := st.CreateNode(node)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -153,7 +158,7 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
err := oidcProvider.createOrUpdateOIDCSession(tt.registrationID, tt.token, node.ID)
|
err := oidcProvider.createOrUpdateOIDCSession(tt.registrationID, tt.token, createdNode.ID)
|
||||||
|
|
||||||
if tt.expectError {
|
if tt.expectError {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@ -163,8 +168,7 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) {
|
|||||||
|
|
||||||
if tt.expectSession && tt.token.RefreshToken != "" {
|
if tt.expectSession && tt.token.RefreshToken != "" {
|
||||||
// Verify session was created/updated
|
// Verify session was created/updated
|
||||||
var session types.OIDCSession
|
session, err := st.GetOIDCSessionByNodeID(createdNode.ID)
|
||||||
err = hsdb.DB.Where("node_id = ?", node.ID).First(&session).Error
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tt.token.RefreshToken, session.RefreshToken)
|
assert.Equal(t, tt.token.RefreshToken, session.RefreshToken)
|
||||||
assert.True(t, session.IsActive)
|
assert.True(t, session.IsActive)
|
||||||
@ -334,23 +338,26 @@ func TestDetermineNodeExpiry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRefreshExpiredTokens(t *testing.T) {
|
func TestRefreshExpiredTokens(t *testing.T) {
|
||||||
hsdb := setupTestDB(t)
|
st := setupTestState(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
// Create test OIDC provider
|
// Create test OIDC provider
|
||||||
oidcProvider := &AuthProviderOIDC{
|
oidcProvider := &AuthProviderOIDC{
|
||||||
db: hsdb,
|
state: st,
|
||||||
cfg: &types.OIDCConfig{
|
cfg: &types.OIDCConfig{
|
||||||
Issuer: "https://test.example.com",
|
Issuer: "https://test.example.com",
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
|
TokenRefresh: types.TokenRefreshConfig{
|
||||||
|
ExpiryThreshold: 5 * time.Minute,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test user
|
// Create test user
|
||||||
user := &types.User{
|
user := &types.User{
|
||||||
Model: gorm.Model{ID: 1},
|
Name: "testuser",
|
||||||
Name: "testuser",
|
|
||||||
}
|
}
|
||||||
err := hsdb.DB.Create(user).Error
|
createdUser, _, err := st.CreateUser(*user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
@ -364,9 +371,9 @@ func TestRefreshExpiredTokens(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "no sessions need refresh",
|
name: "no sessions need refresh",
|
||||||
setupSession: func() *types.OIDCSession {
|
setupSession: func() *types.OIDCSession {
|
||||||
node := createTestNode(t, hsdb, user, "test-node-1")
|
node := createTestNode(t, st, createdUser, "test-node-1")
|
||||||
return &types.OIDCSession{
|
return &types.OIDCSession{
|
||||||
NodeID: types.NodeID(node.ID),
|
NodeID: node.ID,
|
||||||
SessionID: "valid-session",
|
SessionID: "valid-session",
|
||||||
RegistrationID: types.RegistrationID("reg-123"),
|
RegistrationID: types.RegistrationID("reg-123"),
|
||||||
RefreshToken: "refresh-token",
|
RefreshToken: "refresh-token",
|
||||||
@ -380,9 +387,9 @@ func TestRefreshExpiredTokens(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "session needs refresh but no refresh token",
|
name: "session needs refresh but no refresh token",
|
||||||
setupSession: func() *types.OIDCSession {
|
setupSession: func() *types.OIDCSession {
|
||||||
node := createTestNode(t, hsdb, user, "test-node-2")
|
node := createTestNode(t, st, createdUser, "test-node-2")
|
||||||
return &types.OIDCSession{
|
return &types.OIDCSession{
|
||||||
NodeID: types.NodeID(node.ID),
|
NodeID: node.ID,
|
||||||
SessionID: "no-token-session",
|
SessionID: "no-token-session",
|
||||||
RegistrationID: types.RegistrationID("reg-456"),
|
RegistrationID: types.RegistrationID("reg-456"),
|
||||||
RefreshToken: "", // No refresh token
|
RefreshToken: "", // No refresh token
|
||||||
@ -396,9 +403,9 @@ func TestRefreshExpiredTokens(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "valid token should be ignored",
|
name: "valid token should be ignored",
|
||||||
setupSession: func() *types.OIDCSession {
|
setupSession: func() *types.OIDCSession {
|
||||||
node := createTestNode(t, hsdb, user, "test-node-3")
|
node := createTestNode(t, st, createdUser, "test-node-3")
|
||||||
return &types.OIDCSession{
|
return &types.OIDCSession{
|
||||||
NodeID: types.NodeID(node.ID),
|
NodeID: node.ID,
|
||||||
SessionID: "valid-token-session",
|
SessionID: "valid-token-session",
|
||||||
RegistrationID: types.RegistrationID("reg-789"),
|
RegistrationID: types.RegistrationID("reg-789"),
|
||||||
RefreshToken: "refresh-token",
|
RefreshToken: "refresh-token",
|
||||||
@ -421,15 +428,11 @@ func TestRefreshExpiredTokens(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Clean up sessions from previous tests
|
|
||||||
err := hsdb.DB.Where("1 = 1").Delete(&types.OIDCSession{}).Error
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Setup test session if needed
|
// Setup test session if needed
|
||||||
if tt.setupSession != nil {
|
if tt.setupSession != nil {
|
||||||
session := tt.setupSession()
|
session := tt.setupSession()
|
||||||
if session != nil {
|
if session != nil {
|
||||||
err := hsdb.DB.Create(session).Error
|
err := st.CreateOIDCSession(session)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -456,11 +459,12 @@ func TestRefreshExpiredTokens(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRefreshOIDCSessionValidation(t *testing.T) {
|
func TestRefreshOIDCSessionValidation(t *testing.T) {
|
||||||
hsdb := setupTestDB(t)
|
st := setupTestState(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
// Create test OIDC provider
|
// Create test OIDC provider
|
||||||
oidcProvider := &AuthProviderOIDC{
|
oidcProvider := &AuthProviderOIDC{
|
||||||
db: hsdb,
|
state: st,
|
||||||
cfg: &types.OIDCConfig{
|
cfg: &types.OIDCConfig{
|
||||||
Issuer: "https://test.example.com",
|
Issuer: "https://test.example.com",
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
|
@ -835,3 +835,86 @@ func (s *State) autoApproveNodes() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline
|
||||||
|
// for longer than the configured grace period.
|
||||||
|
func (s *State) InvalidateExpiredOIDCSessions(offlineGracePeriod time.Duration) error {
|
||||||
|
return s.db.InvalidateExpiredOIDCSessions(offlineGracePeriod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCSessionByNodeID retrieves an OIDC session by node ID.
|
||||||
|
func (s *State) GetOIDCSessionByNodeID(nodeID types.NodeID) (*types.OIDCSession, error) {
|
||||||
|
var session types.OIDCSession
|
||||||
|
err := s.db.Read(func(tx *gorm.DB) error {
|
||||||
|
return tx.Where("node_id = ?", nodeID).First(&session).Error
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveOIDCSession saves or updates an OIDC session.
|
||||||
|
func (s *State) SaveOIDCSession(session *types.OIDCSession) error {
|
||||||
|
return s.db.Write(func(tx *gorm.DB) error {
|
||||||
|
return tx.Save(session).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOIDCSession creates a new OIDC session.
|
||||||
|
func (s *State) CreateOIDCSession(session *types.OIDCSession) error {
|
||||||
|
return s.db.Write(func(tx *gorm.DB) error {
|
||||||
|
return tx.Create(session).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateNodeExpiry updates a node's expiry time.
|
||||||
|
func (s *State) UpdateNodeExpiry(nodeID types.NodeID, expiry time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.db.Write(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update in-memory state
|
||||||
|
for i := range s.nodes {
|
||||||
|
if s.nodes[i].ID == nodeID {
|
||||||
|
s.nodes[i].Expiry = &expiry
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodeWithUser retrieves a node with its associated user.
|
||||||
|
func (s *State) GetNodeWithUser(nodeID types.NodeID) (*types.Node, error) {
|
||||||
|
var node types.Node
|
||||||
|
err := s.db.Read(func(tx *gorm.DB) error {
|
||||||
|
return tx.Preload("User").First(&node, nodeID).Error
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCSessionsNeedingRefresh retrieves OIDC sessions that need token refresh.
|
||||||
|
func (s *State) GetOIDCSessionsNeedingRefresh(threshold time.Time, registerMethod string) ([]types.OIDCSession, error) {
|
||||||
|
var sessions []types.OIDCSession
|
||||||
|
err := s.db.Read(func(tx *gorm.DB) error {
|
||||||
|
return tx.Preload("Node").Preload("Node.User").
|
||||||
|
Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id").
|
||||||
|
Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?",
|
||||||
|
true, threshold, registerMethod).
|
||||||
|
Find(&sessions).Error
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return sessions, nil
|
||||||
|
}
|
||||||
|
@ -24,7 +24,6 @@ type OIDCSession struct {
|
|||||||
// Token lifecycle
|
// Token lifecycle
|
||||||
TokenExpiry *time.Time `gorm:"index"`
|
TokenExpiry *time.Time `gorm:"index"`
|
||||||
LastRefreshedAt *time.Time
|
LastRefreshedAt *time.Time
|
||||||
RefreshCount int `gorm:"default:0"`
|
|
||||||
|
|
||||||
// Session state
|
// Session state
|
||||||
IsActive bool `gorm:"default:true;index"`
|
IsActive bool `gorm:"default:true;index"`
|
||||||
|
Loading…
Reference in New Issue
Block a user