diff --git a/src/main/java/stirling/software/SPDF/config/security/UserAuthenticationFilter.java b/src/main/java/stirling/software/SPDF/config/security/UserAuthenticationFilter.java index 714096c61..494d1d848 100644 --- a/src/main/java/stirling/software/SPDF/config/security/UserAuthenticationFilter.java +++ b/src/main/java/stirling/software/SPDF/config/security/UserAuthenticationFilter.java @@ -22,6 +22,7 @@ import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpSession; import lombok.extern.slf4j.Slf4j; @@ -32,6 +33,7 @@ import stirling.software.SPDF.model.ApplicationProperties; import stirling.software.SPDF.model.ApplicationProperties.Security; import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2; import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2; +import stirling.software.SPDF.model.SessionEntity; import stirling.software.SPDF.model.User; @Slf4j @@ -59,26 +61,66 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication != null && authentication.isAuthenticated() && loginEnabledValue) { + Object principalTest = authentication.getPrincipal(); + String username = UserUtils.getUsernameFromPrincipal(principalTest); + + log.info("Principal: {}", username); + List allSessions = + sessionPersistentRegistry.getAllSessions(username, false); + + HttpSession session = request.getSession(false); + if (session == null) { + session = request.getSession(true); + } + + String sessionId = request.getSession(false).getId(); + + log.info("allSessions: {} username: {}", allSessions.size(), username); + + for (SessionInformation sessionInformation : allSessions) { + if (sessionId.equals(sessionInformation.getSessionId())) { + log.info("Session found: {}", sessionId); + log.info("lastRequest: {}", sessionInformation.getLastRequest()); + sessionPersistentRegistry.refreshLastRequest(sessionId); + SessionInformation sessionInfo = + sessionPersistentRegistry.getSessionInformation(sessionId); + log.info("new lastRequest: {}", sessionInfo.getLastRequest()); + } else if (allSessions.size() > 2) { + sessionPersistentRegistry.expireSession(sessionId); + sessionInformation.expireNow(); + authentication.setAuthenticated(false); + SecurityContextHolder.clearContext(); + request.getSession().invalidate(); + log.info( + "Expired session: {} Date: {}", + sessionInformation.getSessionId(), + sessionInformation.getLastRequest()); + response.sendRedirect(request.getContextPath() + "/login?error=expiredSession"); + return; + } + } + allSessions = sessionPersistentRegistry.getAllSessions(username, false); + + SessionEntity sessionEntity = sessionPersistentRegistry.getSessionEntity(sessionId); + + if (allSessions.isEmpty() || sessionEntity.isExpired()) { + log.info("No sessions found for user: {}", username); + sessionPersistentRegistry.expireSession(sessionId); + authentication.setAuthenticated(false); + SecurityContextHolder.clearContext(); + response.sendRedirect(request.getContextPath() + "/login?error=expiredSession"); + return; + } + } if (!loginEnabledValue) { // If login is not enabled, just pass all requests without authentication filterChain.doFilter(request, response); return; } String requestURI = request.getRequestURI(); - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - - // Check for session expiration (unsure if needed) - // if (authentication != null && authentication.isAuthenticated()) { - // String sessionId = request.getSession().getId(); - // SessionInformation sessionInfo = - // sessionPersistentRegistry.getSessionInformation(sessionId); - // - // if (sessionInfo != null && sessionInfo.isExpired()) { - // SecurityContextHolder.clearContext(); - // response.sendRedirect(request.getContextPath() + "/login?expired=true"); - // return; - // } - // } + // authentication = SecurityContextHolder.getContext().getAuthentication(); } // Check for API key in the request headers if no authentication exists if (authentication == null || !authentication.isAuthenticated()) { diff --git a/src/main/java/stirling/software/SPDF/config/security/UserService.java b/src/main/java/stirling/software/SPDF/config/security/UserService.java index 61b7c40af..49cd7dfdb 100644 --- a/src/main/java/stirling/software/SPDF/config/security/UserService.java +++ b/src/main/java/stirling/software/SPDF/config/security/UserService.java @@ -12,7 +12,6 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.crypto.password.PasswordEncoder; @@ -380,23 +379,10 @@ public class UserService implements UserServiceInterface { } public void invalidateUserSessions(String username) { - String usernameP = ""; - for (Object principal : sessionRegistry.getAllPrincipals()) { - for (SessionInformation sessionsInformation : - sessionRegistry.getAllSessions(principal, false)) { - if (principal instanceof UserDetails detailsUser) { - usernameP = detailsUser.getUsername(); - } else if (principal instanceof OAuth2User oAuth2User) { - usernameP = oAuth2User.getName(); - } else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) { - usernameP = saml2User.name(); - } else if (principal instanceof String stringUser) { - usernameP = stringUser; - } - if (usernameP.equalsIgnoreCase(username)) { - sessionRegistry.expireSession(sessionsInformation.getSessionId()); - } + String usernameP = UserUtils.getUsernameFromPrincipal(principal); + if (usernameP.equalsIgnoreCase(username)) { + sessionRegistry.expireAllSessionsByPrincipalName(usernameP); } } } diff --git a/src/main/java/stirling/software/SPDF/config/security/UserUtils.java b/src/main/java/stirling/software/SPDF/config/security/UserUtils.java new file mode 100644 index 000000000..a2c03ac14 --- /dev/null +++ b/src/main/java/stirling/software/SPDF/config/security/UserUtils.java @@ -0,0 +1,22 @@ +package stirling.software.SPDF.config.security; + +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.oauth2.core.user.OAuth2User; + +import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal; + +public class UserUtils { + public static String getUsernameFromPrincipal(Object principal) { + if (principal instanceof UserDetails detailsUser) { + return detailsUser.getUsername(); + } else if (principal instanceof OAuth2User oAuth2User) { + return oAuth2User.getName(); + } else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) { + return saml2User.name(); + } else if (principal instanceof String stringUser) { + return stringUser; + } else { + return null; + } + } +} diff --git a/src/main/java/stirling/software/SPDF/config/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java b/src/main/java/stirling/software/SPDF/config/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java index 4ee49aed4..1b4d84e53 100644 --- a/src/main/java/stirling/software/SPDF/config/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java +++ b/src/main/java/stirling/software/SPDF/config/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java @@ -5,7 +5,6 @@ import java.sql.SQLException; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.Authentication; -import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.SavedRequest; @@ -17,6 +16,7 @@ import jakarta.servlet.http.HttpSession; import stirling.software.SPDF.config.security.LoginAttemptService; import stirling.software.SPDF.config.security.UserService; +import stirling.software.SPDF.config.security.UserUtils; import stirling.software.SPDF.model.ApplicationProperties; import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2; import stirling.software.SPDF.model.AuthenticationType; @@ -45,13 +45,7 @@ public class CustomOAuth2AuthenticationSuccessHandler throws ServletException, IOException { Object principal = authentication.getPrincipal(); - String username = ""; - - if (principal instanceof OAuth2User oAuth2User) { - username = oAuth2User.getName(); - } else if (principal instanceof UserDetails detailsUser) { - username = detailsUser.getUsername(); - } + String username = UserUtils.getUsernameFromPrincipal(principal); // Get the saved request HttpSession session = request.getSession(false); diff --git a/src/main/java/stirling/software/SPDF/config/security/session/CustomHttpSessionListener.java b/src/main/java/stirling/software/SPDF/config/security/session/CustomHttpSessionListener.java index 3d97181ab..dd0b01025 100644 --- a/src/main/java/stirling/software/SPDF/config/security/session/CustomHttpSessionListener.java +++ b/src/main/java/stirling/software/SPDF/config/security/session/CustomHttpSessionListener.java @@ -1,6 +1,8 @@ package stirling.software.SPDF.config.security.session; -import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Component; import jakarta.servlet.http.HttpSessionEvent; @@ -8,20 +10,43 @@ import jakarta.servlet.http.HttpSessionListener; import lombok.extern.slf4j.Slf4j; +import stirling.software.SPDF.config.security.UserUtils; + @Component @Slf4j public class CustomHttpSessionListener implements HttpSessionListener { - private SessionPersistentRegistry sessionPersistentRegistry; + private final SessionPersistentRegistry sessionPersistentRegistry; - @Autowired public CustomHttpSessionListener(SessionPersistentRegistry sessionPersistentRegistry) { super(); this.sessionPersistentRegistry = sessionPersistentRegistry; } @Override - public void sessionCreated(HttpSessionEvent se) {} + public void sessionCreated(HttpSessionEvent se) { + SecurityContext securityContext = SecurityContextHolder.getContext(); + if (securityContext == null) { + log.debug("Security context is null"); + return; + } + Authentication authentication = securityContext.getAuthentication(); + if (authentication == null) { + log.info("Authentication is null"); + return; + } + Object principal = authentication.getPrincipal(); + if (principal == null) { + log.info("Principal is null"); + return; + } + String principalName = UserUtils.getUsernameFromPrincipal(principal); + if (principalName == null || "anonymousUser".equals(principalName)) { + return; + } + log.info("Session created: {}", principalName); + sessionPersistentRegistry.registerNewSession(se.getSession().getId(), principalName); + } @Override public void sessionDestroyed(HttpSessionEvent se) { diff --git a/src/main/java/stirling/software/SPDF/config/security/session/SessionPersistentRegistry.java b/src/main/java/stirling/software/SPDF/config/security/session/SessionPersistentRegistry.java index 18b037164..65a9e94e3 100644 --- a/src/main/java/stirling/software/SPDF/config/security/session/SessionPersistentRegistry.java +++ b/src/main/java/stirling/software/SPDF/config/security/session/SessionPersistentRegistry.java @@ -1,26 +1,31 @@ package stirling.software.SPDF.config.security.session; import java.time.Duration; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Optional; import org.springframework.beans.factory.annotation.Value; import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.session.SessionRegistry; -import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.stereotype.Component; import jakarta.transaction.Transactional; -import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.SPDF.config.security.UserUtils; import stirling.software.SPDF.model.SessionEntity; @Component +@Slf4j public class SessionPersistentRegistry implements SessionRegistry { private final SessionRepository sessionRepository; - @Value("${server.servlet.session.timeout:30m}") + @Value("${server.servlet.session.timeout:120s}") // TODO: Change to 30m private Duration defaultMaxInactiveInterval; public SessionPersistentRegistry(SessionRepository sessionRepository) { @@ -41,17 +46,7 @@ public class SessionPersistentRegistry implements SessionRegistry { public List getAllSessions( Object principal, boolean includeExpiredSessions) { List sessionInformations = new ArrayList<>(); - String principalName = null; - - if (principal instanceof UserDetails detailsUser) { - principalName = detailsUser.getUsername(); - } else if (principal instanceof OAuth2User oAuth2User) { - principalName = oAuth2User.getName(); - } else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) { - principalName = saml2User.name(); - } else if (principal instanceof String stringUser) { - principalName = stringUser; - } + String principalName = UserUtils.getUsernameFromPrincipal(principal); if (principalName != null) { List sessionEntities = @@ -72,29 +67,15 @@ public class SessionPersistentRegistry implements SessionRegistry { @Override @Transactional public void registerNewSession(String sessionId, Object principal) { - String principalName = null; - - if (principal instanceof UserDetails detailsUser) { - principalName = detailsUser.getUsername(); - } else if (principal instanceof OAuth2User oAuth2User) { - principalName = oAuth2User.getName(); - } else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) { - principalName = saml2User.name(); - } else if (principal instanceof String stringUser) { - principalName = stringUser; - } + String principalName = UserUtils.getUsernameFromPrincipal(principal); if (principalName != null) { - // Clear old sessions for the principal (unsure if needed) - // List existingSessions = - // sessionRepository.findByPrincipalName(principalName); - // for (SessionEntity session : existingSessions) { - // session.setExpired(true); - // sessionRepository.save(session); - // } - - SessionEntity sessionEntity = new SessionEntity(); - sessionEntity.setSessionId(sessionId); + SessionEntity sessionEntity = sessionRepository.findBySessionId(sessionId); + if (sessionEntity == null) { + sessionEntity = new SessionEntity(); + sessionEntity.setSessionId(sessionId); + log.info("Registering new session for principal: {}", principalName); + } sessionEntity.setPrincipalName(principalName); sessionEntity.setLastRequest(new Date()); // Set lastRequest to the current date sessionEntity.setExpired(false); @@ -111,11 +92,12 @@ public class SessionPersistentRegistry implements SessionRegistry { @Override @Transactional public void refreshLastRequest(String sessionId) { - Optional sessionEntityOpt = sessionRepository.findById(sessionId); - if (sessionEntityOpt.isPresent()) { - SessionEntity sessionEntity = sessionEntityOpt.get(); + SessionEntity sessionEntity = sessionRepository.findBySessionId(sessionId); + if (sessionEntity != null) { sessionEntity.setLastRequest(new Date()); // Update lastRequest to the current date sessionRepository.save(sessionEntity); + } else { + log.error("Session not found for session ID: {}", sessionId); } } @@ -152,6 +134,15 @@ public class SessionPersistentRegistry implements SessionRegistry { } } + // Mark all sessions as expired for a given principal name + public void expireAllSessionsByPrincipalName(String principalName) { + List sessionEntities = sessionRepository.findByPrincipalName(principalName); + for (SessionEntity sessionEntity : sessionEntities) { + sessionEntity.setExpired(true); // Set expired to true + sessionRepository.save(sessionEntity); + } + } + // Get the maximum inactive interval for sessions public int getMaxInactiveInterval() { return (int) defaultMaxInactiveInterval.getSeconds(); @@ -168,6 +159,15 @@ public class SessionPersistentRegistry implements SessionRegistry { sessionRepository.saveByPrincipalName(expired, lastRequest, principalName); } + // Update session details by session ID + public void updateSessionBySessionId(String sessionId) { + SessionEntity sessionEntity = getSessionEntity(sessionId); + if (sessionEntity != null) { + sessionEntity.setLastRequest(new Date()); + sessionRepository.save(sessionEntity); + } + } + // Find the latest session for a given principal name public Optional findLatestSession(String principalName) { List allSessions = sessionRepository.findByPrincipalName(principalName); @@ -178,13 +178,8 @@ public class SessionPersistentRegistry implements SessionRegistry { // Sort sessions by lastRequest in descending order Collections.sort( allSessions, - new Comparator() { - @Override - public int compare(SessionEntity s1, SessionEntity s2) { - // Sort by lastRequest in descending order - return s2.getLastRequest().compareTo(s1.getLastRequest()); - } - }); + (SessionEntity s1, SessionEntity s2) -> + s2.getLastRequest().compareTo(s1.getLastRequest())); // The first session in the list is the latest session for the given principal name return Optional.of(allSessions.get(0)); diff --git a/src/main/java/stirling/software/SPDF/config/security/session/SessionRepository.java b/src/main/java/stirling/software/SPDF/config/security/session/SessionRepository.java index b7f0133f3..933c20129 100644 --- a/src/main/java/stirling/software/SPDF/config/security/session/SessionRepository.java +++ b/src/main/java/stirling/software/SPDF/config/security/session/SessionRepository.java @@ -24,7 +24,8 @@ public interface SessionRepository extends JpaRepository @Modifying @Transactional @Query( - "UPDATE SessionEntity s SET s.expired = :expired, s.lastRequest = :lastRequest WHERE s.principalName = :principalName") + "UPDATE SessionEntity s SET s.expired = :expired, s.lastRequest = :lastRequest WHERE" + + " s.principalName = :principalName") void saveByPrincipalName( @Param("expired") boolean expired, @Param("lastRequest") Date lastRequest, diff --git a/src/main/java/stirling/software/SPDF/config/security/session/SessionScheduled.java b/src/main/java/stirling/software/SPDF/config/security/session/SessionScheduled.java index 9710316e3..e6d52421b 100644 --- a/src/main/java/stirling/software/SPDF/config/security/session/SessionScheduled.java +++ b/src/main/java/stirling/software/SPDF/config/security/session/SessionScheduled.java @@ -6,10 +6,15 @@ import java.util.Date; import java.util.List; import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Component +@Slf4j public class SessionScheduled { private final SessionPersistentRegistry sessionPersistentRegistry; @@ -18,10 +23,18 @@ public class SessionScheduled { this.sessionPersistentRegistry = sessionPersistentRegistry; } - @Scheduled(cron = "0 0/5 * * * ?") + @Scheduled(cron = "0 0/1 * * * ?") // TODO: Change to 5m public void expireSessions() { Instant now = Instant.now(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); for (Object principal : sessionPersistentRegistry.getAllPrincipals()) { + if (principal == null) { + continue; + } else if (principal instanceof String stringPrincipal) { + if ("anonymousUser".equals(stringPrincipal)) { + continue; + } + } List sessionInformations = sessionPersistentRegistry.getAllSessions(principal, false); for (SessionInformation sessionInformation : sessionInformations) { @@ -30,7 +43,21 @@ public class SessionScheduled { Instant expirationTime = lastRequest.toInstant().plus(maxInactiveInterval, ChronoUnit.SECONDS); if (now.isAfter(expirationTime)) { + log.info( + "SessionID: {} expiration time: {} Current time: {}", + sessionInformation.getSessionId(), + expirationTime, + now); sessionPersistentRegistry.expireSession(sessionInformation.getSessionId()); + sessionInformation.expireNow(); + if (authentication != null && principal.equals(authentication.getPrincipal())) { + authentication.setAuthenticated(false); + } + SecurityContextHolder.clearContext(); + log.info( + "Session expired for principal: {} SessionID: {}", + principal, + sessionInformation.getSessionId()); } } } diff --git a/src/main/java/stirling/software/SPDF/controller/web/AccountWebController.java b/src/main/java/stirling/software/SPDF/controller/web/AccountWebController.java index 65e1d055a..1396772fb 100644 --- a/src/main/java/stirling/software/SPDF/controller/web/AccountWebController.java +++ b/src/main/java/stirling/software/SPDF/controller/web/AccountWebController.java @@ -147,6 +147,7 @@ public class AccountWebController { case "badCredentials" -> error = "login.invalid"; case "locked" -> error = "login.locked"; case "oauth2AuthenticationError" -> error = "userAlreadyExistsOAuthMessage"; + case "expiredSession" -> error = "expiredSessionMessage"; } model.addAttribute("error", error);