diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java index 3f03dcbaf..06efcf3a1 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java @@ -8,6 +8,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.DependsOn; import org.springframework.context.annotation.Lazy; +import org.springframework.core.annotation.Order; import org.springframework.security.authentication.ProviderManager; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; @@ -185,11 +186,39 @@ public class SecurityConfiguration { } @Bean + @Order(1) + public SecurityFilterChain samlFilterChain( + HttpSecurity http, + @Lazy IPRateLimitingFilter rateLimitingFilter, + @Lazy JwtAuthenticationFilter jwtAuthenticationFilter) + throws Exception { + http.securityMatcher("/saml2/**", "/login/saml2/**"); + + SessionCreationPolicy sessionPolicy = + (securityProperties.isSaml2Active() && runningProOrHigher) + ? SessionCreationPolicy.IF_REQUIRED + : SessionCreationPolicy.STATELESS; + + return configureSecurity(http, rateLimitingFilter, jwtAuthenticationFilter, sessionPolicy); + } + + @Bean + @Order(2) public SecurityFilterChain filterChain( HttpSecurity http, @Lazy IPRateLimitingFilter rateLimitingFilter, @Lazy JwtAuthenticationFilter jwtAuthenticationFilter) throws Exception { + SessionCreationPolicy sessionPolicy = SessionCreationPolicy.STATELESS; + return configureSecurity(http, rateLimitingFilter, jwtAuthenticationFilter, sessionPolicy); + } + + private SecurityFilterChain configureSecurity( + HttpSecurity http, + @Lazy IPRateLimitingFilter rateLimitingFilter, + @Lazy JwtAuthenticationFilter jwtAuthenticationFilter, + SessionCreationPolicy sessionPolicy) + throws Exception { // Enable CORS only if we have configured origins CorsConfigurationSource corsSource = corsConfigurationSource(); if (corsSource != null) { @@ -233,9 +262,7 @@ public class SecurityConfiguration { .addFilterBefore(jwtAuthenticationFilter, UserAuthenticationFilter.class); http.sessionManagement( - sessionManagement -> - sessionManagement.sessionCreationPolicy( - SessionCreationPolicy.STATELESS)); + sessionManagement -> sessionManagement.sessionCreationPolicy(sessionPolicy)); http.authenticationProvider(daoAuthenticationProvider()); http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache())); diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java deleted file mode 100644 index d0508151c..000000000 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java +++ /dev/null @@ -1,135 +0,0 @@ -package stirling.software.proprietary.security.saml2; - -import java.util.HashMap; -import java.util.Map; - -import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; -import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; - -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; - -import lombok.extern.slf4j.Slf4j; - -import stirling.software.proprietary.security.service.JwtServiceInterface; - -@Slf4j -public class JwtSaml2AuthenticationRequestRepository - implements Saml2AuthenticationRequestRepository { - private final Map tokenStore; - private final JwtServiceInterface jwtService; - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; - - private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; - - public JwtSaml2AuthenticationRequestRepository( - Map tokenStore, - JwtServiceInterface jwtService, - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { - this.tokenStore = tokenStore; - this.jwtService = jwtService; - this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; - } - - @Override - public void saveAuthenticationRequest( - Saml2PostAuthenticationRequest authRequest, - HttpServletRequest request, - HttpServletResponse response) { - if (!jwtService.isJwtEnabled()) { - log.debug("V2 is not enabled, skipping SAMLRequest token storage"); - return; - } - - if (authRequest == null) { - removeAuthenticationRequest(request, response); - return; - } - - Map claims = serializeSamlRequest(authRequest); - String token = jwtService.generateToken("", claims); - String relayState = authRequest.getRelayState(); - - tokenStore.put(relayState, token); - request.setAttribute(SAML_REQUEST_TOKEN, relayState); - response.addHeader(SAML_REQUEST_TOKEN, relayState); - - log.debug("Saved SAMLRequest token with RelayState: {}", relayState); - } - - @Override - public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { - String token = extractTokenFromStore(request); - - if (token == null) { - log.debug("No SAMLResponse token found in RelayState"); - return null; - } - - Map claims = jwtService.extractClaims(token); - return deserializeSamlRequest(claims); - } - - @Override - public Saml2PostAuthenticationRequest removeAuthenticationRequest( - HttpServletRequest request, HttpServletResponse response) { - Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request); - - String relayStateId = request.getParameter("RelayState"); - if (relayStateId != null) { - tokenStore.remove(relayStateId); - log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId); - } - - return authRequest; - } - - private String extractTokenFromStore(HttpServletRequest request) { - String authnRequestId = request.getParameter("RelayState"); - - if (authnRequestId != null && !authnRequestId.isEmpty()) { - String token = tokenStore.get(authnRequestId); - - if (token != null) { - tokenStore.remove(authnRequestId); - log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId); - return token; - } else { - log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId); - } - } - - return null; - } - - private Map serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) { - Map claims = new HashMap<>(); - - claims.put("id", authRequest.getId()); - claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId()); - claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri()); - claims.put("samlRequest", authRequest.getSamlRequest()); - claims.put("relayState", authRequest.getRelayState()); - - return claims; - } - - private Saml2PostAuthenticationRequest deserializeSamlRequest(Map claims) { - String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId"); - RelyingPartyRegistration relyingPartyRegistration = - relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId); - - if (relyingPartyRegistration == null) { - return null; - } - - return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration) - .id((String) claims.get("id")) - .authenticationRequestUri((String) claims.get("authenticationRequestUri")) - .samlRequest((String) claims.get("samlRequest")) - .relayState((String) claims.get("relayState")) - .build(); - } -} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java index 6ccffa1da..9cbde954f 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java @@ -3,7 +3,6 @@ package stirling.software.proprietary.security.saml2; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; import org.opensaml.saml.saml2.core.AuthnRequest; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -12,12 +11,10 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.Resource; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; -import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; import jakarta.servlet.http.HttpServletRequest; @@ -27,7 +24,6 @@ import lombok.extern.slf4j.Slf4j; import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties.Security.SAML2; -import stirling.software.proprietary.security.service.JwtServiceInterface; @Configuration @Slf4j @@ -153,22 +149,10 @@ public class Saml2Configuration { return new InMemoryRelyingPartyRegistrationRepository(rp); } - @Bean - @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") - public Saml2AuthenticationRequestRepository - saml2AuthenticationRequestRepository( - JwtServiceInterface jwtService, - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { - return new JwtSaml2AuthenticationRequestRepository( - new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository); - } - @Bean @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver( - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, - Saml2AuthenticationRequestRepository - saml2AuthenticationRequestRepository) { + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { OpenSaml4AuthenticationRequestResolver resolver = new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository); @@ -176,30 +160,10 @@ public class Saml2Configuration { customizer -> { HttpServletRequest request = customizer.getRequest(); AuthnRequest authnRequest = customizer.getAuthnRequest(); - Saml2PostAuthenticationRequest saml2AuthenticationRequest = - saml2AuthenticationRequestRepository.loadAuthenticationRequest(request); - if (saml2AuthenticationRequest != null) { - String sessionId = request.getSession(false).getId(); + // Generate a unique AuthnRequest ID for each SAML request + authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); - log.debug( - "Retrieving SAML 2 authentication request ID from the current HTTP session {}", - sessionId); - - String authenticationRequestId = saml2AuthenticationRequest.getId(); - - if (!authenticationRequestId.isBlank()) { - authnRequest.setID(authenticationRequestId); - } else { - log.warn( - "No authentication request found for HTTP session {}. Generating new ID", - sessionId); - authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); - } - } else { - log.debug("Generating new authentication request ID"); - authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); - } logAuthnRequestDetails(authnRequest); logHttpRequestDetails(request); }); diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java index f2bc3cc2c..dbff59997 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java @@ -67,6 +67,7 @@ public class UserService implements UserServiceInterface { private final ApplicationProperties.Security.OAUTH2 oAuth2; + @Transactional public void processSSOPostLogin( String username, String ssoProviderId, diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java deleted file mode 100644 index 915c97444..000000000 --- a/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java +++ /dev/null @@ -1,243 +0,0 @@ -package stirling.software.proprietary.security.saml2; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.NullAndEmptySource; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; -import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; - -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; - -import stirling.software.proprietary.security.service.JwtServiceInterface; - -@ExtendWith(MockitoExtension.class) -class JwtSaml2AuthenticationRequestRepositoryTest { - - private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; - - private Map tokenStore; - - @Mock private JwtServiceInterface jwtService; - - @Mock private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; - - private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository; - - @BeforeEach - void setUp() { - tokenStore = new ConcurrentHashMap<>(); - jwtSaml2AuthenticationRequestRepository = - new JwtSaml2AuthenticationRequestRepository( - tokenStore, jwtService, relyingPartyRegistrationRepository); - } - - @Test - void saveAuthenticationRequest() { - var authRequest = mock(Saml2PostAuthenticationRequest.class); - var request = mock(MockHttpServletRequest.class); - var response = mock(MockHttpServletResponse.class); - String token = "testToken"; - String id = "testId"; - String relayState = "testRelayState"; - String authnRequestUri = "example.com/authnRequest"; - String samlRequest = "testSamlRequest"; - String relyingPartyRegistrationId = "stirling-pdf"; - - when(jwtService.isJwtEnabled()).thenReturn(true); - when(authRequest.getRelayState()).thenReturn(relayState); - when(authRequest.getId()).thenReturn(id); - when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri); - when(authRequest.getSamlRequest()).thenReturn(samlRequest); - when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId); - when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token); - - jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest( - authRequest, request, response); - - verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState); - verify(response).addHeader(SAML_REQUEST_TOKEN, relayState); - } - - @Test - void saveAuthenticationRequestWithNullRequest() { - var request = mock(MockHttpServletRequest.class); - var response = mock(MockHttpServletResponse.class); - - jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response); - - assertTrue(tokenStore.isEmpty()); - } - - @Test - void loadAuthenticationRequest() { - var request = mock(MockHttpServletRequest.class); - var relyingPartyRegistration = mock(RelyingPartyRegistration.class); - var assertingPartyMetadata = mock(AssertingPartyMetadata.class); - String relayState = "testRelayState"; - String token = "testToken"; - Map claims = - Map.of( - "id", "testId", - "relyingPartyRegistrationId", "stirling-pdf", - "authenticationRequestUri", "example.com/authnRequest", - "samlRequest", "testSamlRequest", - "relayState", relayState); - - when(request.getParameter("RelayState")).thenReturn(relayState); - when(jwtService.extractClaims(token)).thenReturn(claims); - when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) - .thenReturn(relyingPartyRegistration); - when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); - when(relyingPartyRegistration.getAssertingPartyMetadata()) - .thenReturn(assertingPartyMetadata); - when(assertingPartyMetadata.getSingleSignOnServiceLocation()) - .thenReturn("https://example.com/sso"); - tokenStore.put(relayState, token); - - var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); - - assertNotNull(result); - assertFalse(tokenStore.containsKey(relayState)); - } - - @ParameterizedTest - @NullAndEmptySource - void loadAuthenticationRequestWithInvalidRelayState(String relayState) { - var request = mock(MockHttpServletRequest.class); - when(request.getParameter("RelayState")).thenReturn(relayState); - - var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); - - assertNull(result); - } - - @Test - void loadAuthenticationRequestWithNonExistentToken() { - var request = mock(MockHttpServletRequest.class); - when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); - - var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); - - assertNull(result); - } - - @Test - void loadAuthenticationRequestWithNullRelyingPartyRegistration() { - var request = mock(MockHttpServletRequest.class); - String relayState = "testRelayState"; - String token = "testToken"; - Map claims = - Map.of( - "id", "testId", - "relyingPartyRegistrationId", "stirling-pdf", - "authenticationRequestUri", "example.com/authnRequest", - "samlRequest", "testSamlRequest", - "relayState", relayState); - - when(request.getParameter("RelayState")).thenReturn(relayState); - when(jwtService.extractClaims(token)).thenReturn(claims); - when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) - .thenReturn(null); - tokenStore.put(relayState, token); - - var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); - - assertNull(result); - } - - @Test - void removeAuthenticationRequest() { - var request = mock(HttpServletRequest.class); - var response = mock(HttpServletResponse.class); - var relyingPartyRegistration = mock(RelyingPartyRegistration.class); - var assertingPartyMetadata = mock(AssertingPartyMetadata.class); - String relayState = "testRelayState"; - String token = "testToken"; - Map claims = - Map.of( - "id", "testId", - "relyingPartyRegistrationId", "stirling-pdf", - "authenticationRequestUri", "example.com/authnRequest", - "samlRequest", "testSamlRequest", - "relayState", relayState); - - when(request.getParameter("RelayState")).thenReturn(relayState); - when(jwtService.extractClaims(token)).thenReturn(claims); - when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) - .thenReturn(relyingPartyRegistration); - when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); - when(relyingPartyRegistration.getAssertingPartyMetadata()) - .thenReturn(assertingPartyMetadata); - when(assertingPartyMetadata.getSingleSignOnServiceLocation()) - .thenReturn("https://example.com/sso"); - tokenStore.put(relayState, token); - - var result = - jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( - request, response); - - assertNotNull(result); - assertFalse(tokenStore.containsKey(relayState)); - } - - @Test - void removeAuthenticationRequestWithNullRelayState() { - var request = mock(HttpServletRequest.class); - var response = mock(HttpServletResponse.class); - when(request.getParameter("RelayState")).thenReturn(null); - - var result = - jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( - request, response); - - assertNull(result); - } - - @Test - void removeAuthenticationRequestWithNonExistentToken() { - var request = mock(HttpServletRequest.class); - var response = mock(HttpServletResponse.class); - when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); - - var result = - jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( - request, response); - - assertNull(result); - } - - @Test - void removeAuthenticationRequestWithOnlyRelayState() { - var request = mock(HttpServletRequest.class); - var response = mock(HttpServletResponse.class); - String relayState = "testRelayState"; - - when(request.getParameter("RelayState")).thenReturn(relayState); - - var result = - jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( - request, response); - - assertNull(result); - assertFalse(tokenStore.containsKey(relayState)); - } -}