Fixed missing AuthnRequest bug (#5606)

This commit is contained in:
Dario Ghunney Ware
2026-01-30 16:27:31 +00:00
committed by GitHub
parent 1cc562a6b1
commit 6fee27739c
5 changed files with 34 additions and 420 deletions

View File

@@ -8,6 +8,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.annotation.Order;
import org.springframework.security.authentication.ProviderManager;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
@@ -185,11 +186,39 @@ public class SecurityConfiguration {
}
@Bean
@Order(1)
public SecurityFilterChain samlFilterChain(
HttpSecurity http,
@Lazy IPRateLimitingFilter rateLimitingFilter,
@Lazy JwtAuthenticationFilter jwtAuthenticationFilter)
throws Exception {
http.securityMatcher("/saml2/**", "/login/saml2/**");
SessionCreationPolicy sessionPolicy =
(securityProperties.isSaml2Active() && runningProOrHigher)
? SessionCreationPolicy.IF_REQUIRED
: SessionCreationPolicy.STATELESS;
return configureSecurity(http, rateLimitingFilter, jwtAuthenticationFilter, sessionPolicy);
}
@Bean
@Order(2)
public SecurityFilterChain filterChain(
HttpSecurity http,
@Lazy IPRateLimitingFilter rateLimitingFilter,
@Lazy JwtAuthenticationFilter jwtAuthenticationFilter)
throws Exception {
SessionCreationPolicy sessionPolicy = SessionCreationPolicy.STATELESS;
return configureSecurity(http, rateLimitingFilter, jwtAuthenticationFilter, sessionPolicy);
}
private SecurityFilterChain configureSecurity(
HttpSecurity http,
@Lazy IPRateLimitingFilter rateLimitingFilter,
@Lazy JwtAuthenticationFilter jwtAuthenticationFilter,
SessionCreationPolicy sessionPolicy)
throws Exception {
// Enable CORS only if we have configured origins
CorsConfigurationSource corsSource = corsConfigurationSource();
if (corsSource != null) {
@@ -233,9 +262,7 @@ public class SecurityConfiguration {
.addFilterBefore(jwtAuthenticationFilter, UserAuthenticationFilter.class);
http.sessionManagement(
sessionManagement ->
sessionManagement.sessionCreationPolicy(
SessionCreationPolicy.STATELESS));
sessionManagement -> sessionManagement.sessionCreationPolicy(sessionPolicy));
http.authenticationProvider(daoAuthenticationProvider());
http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()));

View File

@@ -1,135 +0,0 @@
package stirling.software.proprietary.security.saml2;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Slf4j
public class JwtSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest> {
private final Map<String, String> tokenStore;
private final JwtServiceInterface jwtService;
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
public JwtSaml2AuthenticationRequestRepository(
Map<String, String> tokenStore,
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this.tokenStore = tokenStore;
this.jwtService = jwtService;
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
}
@Override
public void saveAuthenticationRequest(
Saml2PostAuthenticationRequest authRequest,
HttpServletRequest request,
HttpServletResponse response) {
if (!jwtService.isJwtEnabled()) {
log.debug("V2 is not enabled, skipping SAMLRequest token storage");
return;
}
if (authRequest == null) {
removeAuthenticationRequest(request, response);
return;
}
Map<String, Object> claims = serializeSamlRequest(authRequest);
String token = jwtService.generateToken("", claims);
String relayState = authRequest.getRelayState();
tokenStore.put(relayState, token);
request.setAttribute(SAML_REQUEST_TOKEN, relayState);
response.addHeader(SAML_REQUEST_TOKEN, relayState);
log.debug("Saved SAMLRequest token with RelayState: {}", relayState);
}
@Override
public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
String token = extractTokenFromStore(request);
if (token == null) {
log.debug("No SAMLResponse token found in RelayState");
return null;
}
Map<String, Object> claims = jwtService.extractClaims(token);
return deserializeSamlRequest(claims);
}
@Override
public Saml2PostAuthenticationRequest removeAuthenticationRequest(
HttpServletRequest request, HttpServletResponse response) {
Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request);
String relayStateId = request.getParameter("RelayState");
if (relayStateId != null) {
tokenStore.remove(relayStateId);
log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId);
}
return authRequest;
}
private String extractTokenFromStore(HttpServletRequest request) {
String authnRequestId = request.getParameter("RelayState");
if (authnRequestId != null && !authnRequestId.isEmpty()) {
String token = tokenStore.get(authnRequestId);
if (token != null) {
tokenStore.remove(authnRequestId);
log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId);
return token;
} else {
log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId);
}
}
return null;
}
private Map<String, Object> serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) {
Map<String, Object> claims = new HashMap<>();
claims.put("id", authRequest.getId());
claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId());
claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri());
claims.put("samlRequest", authRequest.getSamlRequest());
claims.put("relayState", authRequest.getRelayState());
return claims;
}
private Saml2PostAuthenticationRequest deserializeSamlRequest(Map<String, Object> claims) {
String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId");
RelyingPartyRegistration relyingPartyRegistration =
relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId);
if (relyingPartyRegistration == null) {
return null;
}
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration)
.id((String) claims.get("id"))
.authenticationRequestUri((String) claims.get("authenticationRequestUri"))
.samlRequest((String) claims.get("samlRequest"))
.relayState((String) claims.get("relayState"))
.build();
}
}

View File

@@ -3,7 +3,6 @@ package stirling.software.proprietary.security.saml2;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -12,12 +11,10 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import jakarta.servlet.http.HttpServletRequest;
@@ -27,7 +24,6 @@ import lombok.extern.slf4j.Slf4j;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.ApplicationProperties.Security.SAML2;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Configuration
@Slf4j
@@ -153,22 +149,10 @@ public class Saml2Configuration {
return new InMemoryRelyingPartyRegistrationRepository(rp);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository(
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
return new JwtSaml2AuthenticationRequestRepository(
new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository) {
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
OpenSaml4AuthenticationRequestResolver resolver =
new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository);
@@ -176,30 +160,10 @@ public class Saml2Configuration {
customizer -> {
HttpServletRequest request = customizer.getRequest();
AuthnRequest authnRequest = customizer.getAuthnRequest();
Saml2PostAuthenticationRequest saml2AuthenticationRequest =
saml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
if (saml2AuthenticationRequest != null) {
String sessionId = request.getSession(false).getId();
// Generate a unique AuthnRequest ID for each SAML request
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
log.debug(
"Retrieving SAML 2 authentication request ID from the current HTTP session {}",
sessionId);
String authenticationRequestId = saml2AuthenticationRequest.getId();
if (!authenticationRequestId.isBlank()) {
authnRequest.setID(authenticationRequestId);
} else {
log.warn(
"No authentication request found for HTTP session {}. Generating new ID",
sessionId);
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
}
} else {
log.debug("Generating new authentication request ID");
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
}
logAuthnRequestDetails(authnRequest);
logHttpRequestDetails(request);
});

View File

@@ -67,6 +67,7 @@ public class UserService implements UserServiceInterface {
private final ApplicationProperties.Security.OAUTH2 oAuth2;
@Transactional
public void processSSOPostLogin(
String username,
String ssoProviderId,

View File

@@ -1,243 +0,0 @@
package stirling.software.proprietary.security.saml2;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@ExtendWith(MockitoExtension.class)
class JwtSaml2AuthenticationRequestRepositoryTest {
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
private Map<String, String> tokenStore;
@Mock private JwtServiceInterface jwtService;
@Mock private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository;
@BeforeEach
void setUp() {
tokenStore = new ConcurrentHashMap<>();
jwtSaml2AuthenticationRequestRepository =
new JwtSaml2AuthenticationRequestRepository(
tokenStore, jwtService, relyingPartyRegistrationRepository);
}
@Test
void saveAuthenticationRequest() {
var authRequest = mock(Saml2PostAuthenticationRequest.class);
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
String token = "testToken";
String id = "testId";
String relayState = "testRelayState";
String authnRequestUri = "example.com/authnRequest";
String samlRequest = "testSamlRequest";
String relyingPartyRegistrationId = "stirling-pdf";
when(jwtService.isJwtEnabled()).thenReturn(true);
when(authRequest.getRelayState()).thenReturn(relayState);
when(authRequest.getId()).thenReturn(id);
when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri);
when(authRequest.getSamlRequest()).thenReturn(samlRequest);
when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId);
when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(
authRequest, request, response);
verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState);
verify(response).addHeader(SAML_REQUEST_TOKEN, relayState);
}
@Test
void saveAuthenticationRequestWithNullRequest() {
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response);
assertTrue(tokenStore.isEmpty());
}
@Test
void loadAuthenticationRequest() {
var request = mock(MockHttpServletRequest.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims =
Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata())
.thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation())
.thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@ParameterizedTest
@NullAndEmptySource
void loadAuthenticationRequestWithInvalidRelayState(String relayState) {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNonExistentToken() {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNullRelyingPartyRegistration() {
var request = mock(MockHttpServletRequest.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims =
Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(null);
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void removeAuthenticationRequest() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims =
Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata())
.thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation())
.thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@Test
void removeAuthenticationRequestWithNullRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn(null);
var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithNonExistentToken() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithOnlyRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
String relayState = "testRelayState";
when(request.getParameter("RelayState")).thenReturn(relayState);
var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
}