diff --git a/hscontrol/app.go b/hscontrol/app.go index 6fab0fae..ae9aea99 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -330,7 +330,7 @@ func (h *Headscale) oidcTokenRefreshJob(ctx context.Context, oidcProvider *AuthP case <-gracePeriodTicker.C: log.Debug().Msg("OIDC: Checking for nodes offline beyond grace period") gracePeriod := oidcProvider.cfg.TokenRefresh.SessionInvalidationGracePeriod - if err := h.db.InvalidateExpiredOIDCSessions(gracePeriod); err != nil { + if err := h.state.InvalidateExpiredOIDCSessions(gracePeriod); err != nil { log.Error().Err(err).Msg("OIDC: Failed to invalidate sessions for offline nodes") } } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index ba045d4d..3a357b7b 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -941,10 +941,66 @@ AND auth_key_id NOT IN ( // Create OIDC sessions table for managing OIDC refresh tokens // This replaces the old OIDC token columns in the users table if !tx.Migrator().HasTable(&types.OIDCSession{}) { - err := tx.AutoMigrate(&types.OIDCSession{}) - if err != nil { + // Create the table with database-specific SQL + var sql string + if cfg.Type == types.DatabasePostgres { + sql = ` + CREATE TABLE oidc_sessions ( + id SERIAL PRIMARY KEY, + node_id INTEGER NOT NULL, + session_id TEXT NOT NULL, + registration_id INTEGER NOT NULL, + refresh_token TEXT, + token_expiry TIMESTAMP, + last_refreshed_at TIMESTAMP, + is_active BOOLEAN DEFAULT true, + last_seen_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP, + deleted_at TIMESTAMP, + CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE + ) + ` + } else { + sql = ` + CREATE TABLE oidc_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + session_id TEXT NOT NULL, + registration_id INTEGER NOT NULL, + refresh_token TEXT, + token_expiry DATETIME, + last_refreshed_at DATETIME, + is_active NUMERIC DEFAULT true, + last_seen_at DATETIME, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME, + CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE + ) + ` + } + if err := tx.Exec(sql).Error; err != nil { return fmt.Errorf("creating OIDC sessions table: %w", err) } + + // Create indexes matching the struct tags + if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id)").Error; err != nil { + return fmt.Errorf("creating node_id unique index: %w", err) + } + if err := tx.Exec("CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id)").Error; err != nil { + return fmt.Errorf("creating session_id unique index: %w", err) + } + if err := tx.Exec("CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry)").Error; err != nil { + return fmt.Errorf("creating token_expiry index: %w", err) + } + if err := tx.Exec("CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active)").Error; err != nil { + return fmt.Errorf("creating is_active index: %w", err) + } + if err := tx.Exec("CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at)").Error; err != nil { + return fmt.Errorf("creating deleted_at index: %w", err) + } + log.Debug().Msg("Created OIDC sessions table") } diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index 175e2aff..c72952fd 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -108,3 +108,24 @@ CREATE TABLE policies( deleted_at datetime ); CREATE INDEX idx_policies_deleted_at ON policies(deleted_at); + +CREATE TABLE oidc_sessions( + id INTEGER PRIMARY KEY AUTOINCREMENT, + node_id INTEGER NOT NULL, + session_id TEXT NOT NULL, + registration_id INTEGER NOT NULL, + refresh_token TEXT, + token_expiry DATETIME, + last_refreshed_at DATETIME, + is_active NUMERIC DEFAULT true, + last_seen_at DATETIME, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME, + CONSTRAINT fk_oidc_sessions_node FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX idx_oidc_sessions_node_id ON oidc_sessions(node_id); +CREATE UNIQUE INDEX idx_oidc_sessions_session_id ON oidc_sessions(session_id); +CREATE INDEX idx_oidc_sessions_token_expiry ON oidc_sessions(token_expiry); +CREATE INDEX idx_oidc_sessions_is_active ON oidc_sessions(is_active); +CREATE INDEX idx_oidc_sessions_deleted_at ON oidc_sessions(deleted_at); diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index a3d179eb..d6436172 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -439,8 +439,7 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.Regist session.LastSeenAt = &now // Try to update existing session first - var existingSession types.OIDCSession - err = a.db.DB.Where("node_id = ?", nodeID).First(&existingSession).Error + existingSession, err := a.state.GetOIDCSessionByNodeID(nodeID) if err == nil { // Update existing session existingSession.RefreshToken = token.RefreshToken @@ -449,16 +448,15 @@ func (a *AuthProviderOIDC) createOrUpdateOIDCSession(registrationID types.Regist existingSession.LastRefreshedAt = &now existingSession.LastSeenAt = &now existingSession.IsActive = true - existingSession.RefreshCount = existingSession.RefreshCount + 1 - err = a.db.DB.Save(&existingSession).Error + err = a.state.SaveOIDCSession(existingSession) if err != nil { return fmt.Errorf("failed to update OIDC session: %w", err) } } else { // Create new session - err = a.db.DB.Create(session).Error + err = a.state.CreateOIDCSession(session) if err != nil { return fmt.Errorf("failed to create OIDC session: %w", err) } @@ -488,15 +486,14 @@ func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *type nodeExpiry := a.determineNodeExpiry(newToken.Expiry) // Load the node for logging - var node types.Node - err = a.db.DB.Preload("User").First(&node, session.NodeID).Error + node, err := a.state.GetNodeWithUser(session.NodeID) if err != nil { return fmt.Errorf("failed to load node: %w", err) } // Update the node expiry directly for token refresh // We don't use HandleNodeFromAuthPath for refresh as the node is already registered - err = a.db.DB.Model(&node).Update("expiry", nodeExpiry).Error + err = a.state.UpdateNodeExpiry(session.NodeID, nodeExpiry) if err != nil { return fmt.Errorf("failed to update node expiry: %w", err) } @@ -519,10 +516,9 @@ func (a *AuthProviderOIDC) RefreshOIDCSession(ctx context.Context, session *type utcExpiry := newToken.Expiry.UTC() session.TokenExpiry = &utcExpiry session.LastRefreshedAt = &now - session.RefreshCount = session.RefreshCount + 1 // Save the updated session - err = a.db.DB.Save(session).Error + err = a.state.SaveOIDCSession(session) if err != nil { log.Error().Err(err). Str("session_id", session.SessionID). @@ -542,11 +538,7 @@ func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error { threshold := currentTime.Add(a.cfg.TokenRefresh.ExpiryThreshold) // Only refresh tokens for sessions linked to OIDC-registered nodes - err := a.db.DB.Preload("Node").Preload("Node.User"). - Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id"). - Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?", - true, threshold, "oidc"). - Find(&sessions).Error + sessions, err := a.state.GetOIDCSessionsNeedingRefresh(threshold, "oidc") if err != nil { return fmt.Errorf("failed to query sessions with expiring tokens: %w", err) } @@ -566,7 +558,7 @@ func (a *AuthProviderOIDC) RefreshExpiredTokens(ctx context.Context) error { Msg("OIDC: Failed to refresh session, deactivating") // Deactivate the session if refresh fails session.Deactivate() - a.db.DB.Save(&session) + a.state.SaveOIDCSession(&session) continue } } diff --git a/hscontrol/oidc_test.go b/hscontrol/oidc_test.go index 106da434..0fbbccd2 100644 --- a/hscontrol/oidc_test.go +++ b/hscontrol/oidc_test.go @@ -3,26 +3,21 @@ package hscontrol import ( "context" "fmt" + "net/netip" "testing" "time" - "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "gorm.io/gorm" "tailscale.com/types/key" - zcache "zgo.at/zcache/v2" ) -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) -} - // createTestNode creates a test node for testing -func createTestNode(t *testing.T, hsdb *db.HSDatabase, user *types.User, hostname string) *types.Node { +func createTestNode(t *testing.T, st *state.State, user *types.User, hostname string) *types.Node { t.Helper() nodeKey := key.NewNode() @@ -41,47 +36,57 @@ func createTestNode(t *testing.T, hsdb *db.HSDatabase, user *types.User, hostnam Expiry: &nodeExpiry, } - err := hsdb.DB.Create(node).Error + createdNode, _, err := st.CreateNode(node) require.NoError(t, err) - return node + return createdNode } -// setupTestDB creates a test database -func setupTestDB(t *testing.T) *db.HSDatabase { +// setupTestState creates a test state with database +func setupTestState(t *testing.T) *state.State { t.Helper() tmpDir := t.TempDir() - hsdb, err := db.NewHeadscaleDatabase( - types.DatabaseConfig{ + prefixV4, _ := netip.ParsePrefix("100.64.0.0/10") + prefixV6, _ := netip.ParsePrefix("fd7a:115c:a1e0::/48") + + cfg := &types.Config{ + Database: types.DatabaseConfig{ Type: types.DatabaseSqlite, Sqlite: types.SqliteConfig{ Path: tmpDir + "/test.db", }, }, - "", - emptyCache(), - ) + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + BaseDomain: "test.local", + PrefixV4: &prefixV4, + PrefixV6: &prefixV6, + IPAllocation: types.IPAllocationStrategySequential, + } + + st, err := state.NewState(cfg) require.NoError(t, err) - return hsdb + return st } func TestCreateOrUpdateOIDCSession(t *testing.T) { - hsdb := setupTestDB(t) + st := setupTestState(t) + defer st.Close() // Create test OIDC provider oidcProvider := &AuthProviderOIDC{ - db: hsdb, + state: st, } // Create test user user := &types.User{ - Model: gorm.Model{ID: 1}, - Name: "testuser", + Name: "testuser", } - err := hsdb.DB.Create(user).Error + createdUser, _, err := st.CreateUser(*user) require.NoError(t, err) // Create test node @@ -95,11 +100,11 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) { DiscoKey: discoKey.Public(), Hostname: "test-node", GivenName: "test-node", - UserID: user.ID, + UserID: createdUser.ID, RegisterMethod: util.RegisterMethodOIDC, Expiry: &nodeExpiry, } - err = hsdb.DB.Create(node).Error + createdNode, _, err := st.CreateNode(node) require.NoError(t, err) tests := []struct { @@ -153,7 +158,7 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := oidcProvider.createOrUpdateOIDCSession(tt.registrationID, tt.token, node.ID) + err := oidcProvider.createOrUpdateOIDCSession(tt.registrationID, tt.token, createdNode.ID) if tt.expectError { assert.Error(t, err) @@ -163,8 +168,7 @@ func TestCreateOrUpdateOIDCSession(t *testing.T) { if tt.expectSession && tt.token.RefreshToken != "" { // Verify session was created/updated - var session types.OIDCSession - err = hsdb.DB.Where("node_id = ?", node.ID).First(&session).Error + session, err := st.GetOIDCSessionByNodeID(createdNode.ID) assert.NoError(t, err) assert.Equal(t, tt.token.RefreshToken, session.RefreshToken) assert.True(t, session.IsActive) @@ -334,23 +338,26 @@ func TestDetermineNodeExpiry(t *testing.T) { } func TestRefreshExpiredTokens(t *testing.T) { - hsdb := setupTestDB(t) + st := setupTestState(t) + defer st.Close() // Create test OIDC provider oidcProvider := &AuthProviderOIDC{ - db: hsdb, + state: st, cfg: &types.OIDCConfig{ Issuer: "https://test.example.com", ClientID: "test-client-id", + TokenRefresh: types.TokenRefreshConfig{ + ExpiryThreshold: 5 * time.Minute, + }, }, } // Create test user user := &types.User{ - Model: gorm.Model{ID: 1}, - Name: "testuser", + Name: "testuser", } - err := hsdb.DB.Create(user).Error + createdUser, _, err := st.CreateUser(*user) require.NoError(t, err) now := time.Now().UTC() @@ -364,9 +371,9 @@ func TestRefreshExpiredTokens(t *testing.T) { { name: "no sessions need refresh", setupSession: func() *types.OIDCSession { - node := createTestNode(t, hsdb, user, "test-node-1") + node := createTestNode(t, st, createdUser, "test-node-1") return &types.OIDCSession{ - NodeID: types.NodeID(node.ID), + NodeID: node.ID, SessionID: "valid-session", RegistrationID: types.RegistrationID("reg-123"), RefreshToken: "refresh-token", @@ -380,9 +387,9 @@ func TestRefreshExpiredTokens(t *testing.T) { { name: "session needs refresh but no refresh token", setupSession: func() *types.OIDCSession { - node := createTestNode(t, hsdb, user, "test-node-2") + node := createTestNode(t, st, createdUser, "test-node-2") return &types.OIDCSession{ - NodeID: types.NodeID(node.ID), + NodeID: node.ID, SessionID: "no-token-session", RegistrationID: types.RegistrationID("reg-456"), RefreshToken: "", // No refresh token @@ -396,9 +403,9 @@ func TestRefreshExpiredTokens(t *testing.T) { { name: "valid token should be ignored", setupSession: func() *types.OIDCSession { - node := createTestNode(t, hsdb, user, "test-node-3") + node := createTestNode(t, st, createdUser, "test-node-3") return &types.OIDCSession{ - NodeID: types.NodeID(node.ID), + NodeID: node.ID, SessionID: "valid-token-session", RegistrationID: types.RegistrationID("reg-789"), RefreshToken: "refresh-token", @@ -421,15 +428,11 @@ func TestRefreshExpiredTokens(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Clean up sessions from previous tests - err := hsdb.DB.Where("1 = 1").Delete(&types.OIDCSession{}).Error - require.NoError(t, err) - // Setup test session if needed if tt.setupSession != nil { session := tt.setupSession() if session != nil { - err := hsdb.DB.Create(session).Error + err := st.CreateOIDCSession(session) require.NoError(t, err) } } @@ -456,11 +459,12 @@ func TestRefreshExpiredTokens(t *testing.T) { } func TestRefreshOIDCSessionValidation(t *testing.T) { - hsdb := setupTestDB(t) + st := setupTestState(t) + defer st.Close() // Create test OIDC provider oidcProvider := &AuthProviderOIDC{ - db: hsdb, + state: st, cfg: &types.OIDCConfig{ Issuer: "https://test.example.com", ClientID: "test-client-id", diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b754e594..3fdb2197 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -835,3 +835,86 @@ func (s *State) autoApproveNodes() error { return nil } + +// InvalidateExpiredOIDCSessions invalidates sessions for nodes that have been offline +// for longer than the configured grace period. +func (s *State) InvalidateExpiredOIDCSessions(offlineGracePeriod time.Duration) error { + return s.db.InvalidateExpiredOIDCSessions(offlineGracePeriod) +} + +// GetOIDCSessionByNodeID retrieves an OIDC session by node ID. +func (s *State) GetOIDCSessionByNodeID(nodeID types.NodeID) (*types.OIDCSession, error) { + var session types.OIDCSession + err := s.db.Read(func(tx *gorm.DB) error { + return tx.Where("node_id = ?", nodeID).First(&session).Error + }) + if err != nil { + return nil, err + } + return &session, nil +} + +// SaveOIDCSession saves or updates an OIDC session. +func (s *State) SaveOIDCSession(session *types.OIDCSession) error { + return s.db.Write(func(tx *gorm.DB) error { + return tx.Save(session).Error + }) +} + +// CreateOIDCSession creates a new OIDC session. +func (s *State) CreateOIDCSession(session *types.OIDCSession) error { + return s.db.Write(func(tx *gorm.DB) error { + return tx.Create(session).Error + }) +} + +// UpdateNodeExpiry updates a node's expiry time. +func (s *State) UpdateNodeExpiry(nodeID types.NodeID, expiry time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.db.Write(func(tx *gorm.DB) error { + result := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry) + if result.Error != nil { + return result.Error + } + + // Update in-memory state + for i := range s.nodes { + if s.nodes[i].ID == nodeID { + s.nodes[i].Expiry = &expiry + break + } + } + + return nil + }) +} + +// GetNodeWithUser retrieves a node with its associated user. +func (s *State) GetNodeWithUser(nodeID types.NodeID) (*types.Node, error) { + var node types.Node + err := s.db.Read(func(tx *gorm.DB) error { + return tx.Preload("User").First(&node, nodeID).Error + }) + if err != nil { + return nil, err + } + return &node, nil +} + +// GetOIDCSessionsNeedingRefresh retrieves OIDC sessions that need token refresh. +func (s *State) GetOIDCSessionsNeedingRefresh(threshold time.Time, registerMethod string) ([]types.OIDCSession, error) { + var sessions []types.OIDCSession + err := s.db.Read(func(tx *gorm.DB) error { + return tx.Preload("Node").Preload("Node.User"). + Joins("JOIN nodes ON nodes.id = oidc_sessions.node_id"). + Where("oidc_sessions.is_active = ? AND oidc_sessions.token_expiry IS NOT NULL AND oidc_sessions.token_expiry < ? AND oidc_sessions.refresh_token != '' AND nodes.register_method = ?", + true, threshold, registerMethod). + Find(&sessions).Error + }) + if err != nil { + return nil, err + } + return sessions, nil +} diff --git a/hscontrol/types/oidc_session.go b/hscontrol/types/oidc_session.go index 57cf3050..54a01c6f 100644 --- a/hscontrol/types/oidc_session.go +++ b/hscontrol/types/oidc_session.go @@ -24,7 +24,6 @@ type OIDCSession struct { // Token lifecycle TokenExpiry *time.Time `gorm:"index"` LastRefreshedAt *time.Time - RefreshCount int `gorm:"default:0"` // Session state IsActive bool `gorm:"default:true;index"`