Merge branch 'V2' into settingsPageEnhanced

This commit is contained in:
Anthony Stirling
2025-11-06 16:44:06 +00:00
committed by GitHub
36 changed files with 1103 additions and 237 deletions

View File

@@ -49,4 +49,33 @@ public class RequestUriUtils {
|| requestURI.startsWith("/fonts")
|| requestURI.startsWith("/pdfjs"));
}
/**
* Checks if the request URI is a public authentication endpoint that doesn't require
* authentication. This includes login, signup, OAuth callbacks, and public config endpoints.
*
* @param requestURI The full request URI
* @param contextPath The servlet context path
* @return true if the endpoint is public and doesn't require authentication
*/
public static boolean isPublicAuthEndpoint(String requestURI, String contextPath) {
// Remove context path from URI to normalize path matching
String trimmedUri =
requestURI.startsWith(contextPath)
? requestURI.substring(contextPath.length())
: requestURI;
// Public auth endpoints that don't require authentication
return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/auth/")
|| trimmedUri.startsWith("/oauth2")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.contains("/login/oauth2/code/") // Spring Security OAuth2 callback
|| trimmedUri.contains("/oauth2/authorization/") // OAuth2 authorization endpoint
|| trimmedUri.startsWith("/api/v1/auth/login")
|| trimmedUri.startsWith("/api/v1/auth/refresh")
|| trimmedUri.startsWith("/api/v1/auth/logout")
|| trimmedUri.startsWith("/v1/api-docs")
|| trimmedUri.contains("/v1/api-docs");
}
}

View File

@@ -17,6 +17,20 @@ public class JwtAuthenticationEntryPoint implements AuthenticationEntryPoint {
HttpServletResponse response,
AuthenticationException authException)
throws IOException {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage());
String contextPath = request.getContextPath();
String requestURI = request.getRequestURI();
// For API requests, return JSON error
if (requestURI.startsWith(contextPath + "/api/")) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
String message =
authException != null ? authException.getMessage() : "Authentication required";
response.getWriter().write("{\"error\":\"" + message + "\"}");
} else {
// For non-API requests, use default behavior
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage());
}
}
}

View File

@@ -1,5 +1,6 @@
package stirling.software.proprietary.security.configuration;
import java.util.List;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
@@ -28,11 +29,15 @@ import org.springframework.security.web.csrf.CookieCsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.savedrequest.NullRequestCache;
import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.configuration.AppConfig;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.util.RequestUriUtils;
import stirling.software.proprietary.security.CustomAuthenticationFailureHandler;
import stirling.software.proprietary.security.CustomAuthenticationSuccessHandler;
import stirling.software.proprietary.security.CustomLogoutSuccessHandler;
@@ -67,6 +72,7 @@ public class SecurityConfiguration {
private final boolean loginEnabledValue;
private final boolean runningProOrHigher;
private final ApplicationProperties applicationProperties;
private final ApplicationProperties.Security securityProperties;
private final AppConfig appConfig;
private final UserAuthenticationFilter userAuthenticationFilter;
@@ -86,6 +92,7 @@ public class SecurityConfiguration {
@Qualifier("loginEnabled") boolean loginEnabledValue,
@Qualifier("runningProOrHigher") boolean runningProOrHigher,
AppConfig appConfig,
ApplicationProperties applicationProperties,
ApplicationProperties.Security securityProperties,
UserAuthenticationFilter userAuthenticationFilter,
JwtServiceInterface jwtService,
@@ -102,6 +109,7 @@ public class SecurityConfiguration {
this.loginEnabledValue = loginEnabledValue;
this.runningProOrHigher = runningProOrHigher;
this.appConfig = appConfig;
this.applicationProperties = applicationProperties;
this.securityProperties = securityProperties;
this.userAuthenticationFilter = userAuthenticationFilter;
this.jwtService = jwtService;
@@ -120,7 +128,79 @@ public class SecurityConfiguration {
}
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
public CorsConfigurationSource corsConfigurationSource() {
// Read CORS allowed origins from settings
if (applicationProperties.getSystem() != null
&& applicationProperties.getSystem().getCorsAllowedOrigins() != null
&& !applicationProperties.getSystem().getCorsAllowedOrigins().isEmpty()) {
List<String> allowedOrigins = applicationProperties.getSystem().getCorsAllowedOrigins();
CorsConfiguration cfg = new CorsConfiguration();
// Use setAllowedOriginPatterns for better wildcard and port support
cfg.setAllowedOriginPatterns(allowedOrigins);
log.debug(
"CORS configured with allowed origin patterns from settings.yml: {}",
allowedOrigins);
// Set allowed methods explicitly (including OPTIONS for preflight)
cfg.setAllowedMethods(List.of("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"));
// Set allowed headers explicitly
cfg.setAllowedHeaders(
List.of(
"Authorization",
"Content-Type",
"X-Requested-With",
"Accept",
"Origin",
"X-API-KEY",
"X-CSRF-TOKEN"));
// Set exposed headers (headers that the browser can access)
cfg.setExposedHeaders(
List.of(
"WWW-Authenticate",
"X-Total-Count",
"X-Page-Number",
"X-Page-Size",
"Content-Disposition",
"Content-Type"));
// Allow credentials (cookies, authorization headers)
cfg.setAllowCredentials(true);
// Set max age for preflight cache
cfg.setMaxAge(3600L);
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.registerCorsConfiguration("/**", cfg);
return source;
} else {
// No CORS origins configured - return null to disable CORS processing entirely
// This avoids empty CORS policy that unexpectedly rejects preflights
log.info(
"CORS is disabled - no allowed origins configured in settings.yml (system.corsAllowedOrigins)");
return null;
}
}
@Bean
public SecurityFilterChain filterChain(
HttpSecurity http,
@Lazy IPRateLimitingFilter rateLimitingFilter,
@Lazy JwtAuthenticationFilter jwtAuthenticationFilter)
throws Exception {
// Enable CORS only if we have configured origins
CorsConfigurationSource corsSource = corsConfigurationSource();
if (corsSource != null) {
http.cors(cors -> cors.configurationSource(corsSource));
} else {
// Explicitly disable CORS when no origins are configured
http.cors(cors -> cors.disable());
}
if (securityProperties.getCsrfDisabled() || !loginEnabledValue) {
http.csrf(CsrfConfigurer::disable);
}
@@ -130,12 +210,8 @@ public class SecurityConfiguration {
http.addFilterBefore(
userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class)
.addFilterBefore(
rateLimitingFilter(), UsernamePasswordAuthenticationFilter.class);
if (v2Enabled) {
http.addFilterBefore(jwtAuthenticationFilter(), UserAuthenticationFilter.class);
}
.addFilterBefore(rateLimitingFilter, UsernamePasswordAuthenticationFilter.class)
.addFilterBefore(jwtAuthenticationFilter, UserAuthenticationFilter.class);
if (!securityProperties.getCsrfDisabled()) {
CookieCsrfTokenRepository cookieRepo =
@@ -195,6 +271,18 @@ public class SecurityConfiguration {
});
http.authenticationProvider(daoAuthenticationProvider());
http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()));
// Configure exception handling for API endpoints
http.exceptionHandling(
exceptions ->
exceptions.defaultAuthenticationEntryPointFor(
jwtAuthenticationEntryPoint,
request -> {
String contextPath = request.getContextPath();
String requestURI = request.getRequestURI();
return requestURI.startsWith(contextPath + "/api/");
}));
http.logout(
logout ->
logout.logoutRequestMatcher(
@@ -227,49 +315,12 @@ public class SecurityConfiguration {
req -> {
String uri = req.getRequestURI();
String contextPath = req.getContextPath();
// Remove the context path from the URI
String trimmedUri =
uri.startsWith(contextPath)
? uri.substring(
contextPath.length())
: uri;
return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/oauth")
|| trimmedUri.startsWith("/oauth2")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.endsWith(".svg")
|| trimmedUri.startsWith("/register")
|| trimmedUri.startsWith("/signup")
|| trimmedUri.startsWith("/invite")
|| trimmedUri.startsWith("/auth/callback")
|| trimmedUri.startsWith("/error")
|| trimmedUri.startsWith("/images/")
|| trimmedUri.startsWith("/public/")
|| trimmedUri.startsWith("/css/")
|| trimmedUri.startsWith("/fonts/")
|| trimmedUri.startsWith("/js/")
|| trimmedUri.startsWith("/pdfjs/")
|| trimmedUri.startsWith("/pdfjs-legacy/")
|| trimmedUri.startsWith("/favicon")
|| trimmedUri.startsWith(
"/api/v1/info/status")
|| trimmedUri.startsWith("/api/v1/config")
|| trimmedUri.startsWith(
"/api/v1/auth/register")
|| trimmedUri.startsWith(
"/api/v1/user/register")
|| trimmedUri.startsWith(
"/api/v1/auth/login")
|| trimmedUri.startsWith(
"/api/v1/auth/refresh")
|| trimmedUri.startsWith("/api/v1/auth/me")
|| trimmedUri.startsWith(
"/api/v1/invite/validate")
|| trimmedUri.startsWith(
"/api/v1/invite/accept")
|| trimmedUri.startsWith("/v1/api-docs")
|| uri.contains("/v1/api-docs");
// Check if it's a public auth endpoint or static
// resource
return RequestUriUtils.isStaticResource(
contextPath, uri)
|| RequestUriUtils.isPublicAuthEndpoint(
uri, contextPath);
})
.permitAll()
.anyRequest()
@@ -338,8 +389,12 @@ public class SecurityConfiguration {
.saml2Login(
saml2 -> {
try {
saml2.loginPage("/saml2")
.relyingPartyRegistrationRepository(
// Only set login page for v1/Thymeleaf mode
if (!v2Enabled) {
saml2.loginPage("/saml2");
}
saml2.relyingPartyRegistrationRepository(
saml2RelyingPartyRegistrations)
.authenticationManager(
new ProviderManager(authenticationProvider))

View File

@@ -21,11 +21,15 @@ import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.audit.AuditEventType;
import stirling.software.proprietary.audit.AuditLevel;
import stirling.software.proprietary.audit.Audited;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.model.api.user.UsernameAndPass;
import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.LoginAttemptService;
import stirling.software.proprietary.security.service.UserService;
/** REST API Controller for authentication operations. */
@@ -39,6 +43,7 @@ public class AuthController {
private final UserService userService;
private final JwtServiceInterface jwtService;
private final CustomUserDetailsService userDetailsService;
private final LoginAttemptService loginAttemptService;
/**
* Login endpoint - replaces Supabase signInWithPassword
@@ -49,8 +54,11 @@ public class AuthController {
*/
@PreAuthorize("!hasAuthority('ROLE_DEMO_USER')")
@PostMapping("/login")
@Audited(type = AuditEventType.USER_LOGIN, level = AuditLevel.BASIC)
public ResponseEntity<?> login(
@RequestBody UsernameAndPass request, HttpServletResponse response) {
@RequestBody UsernameAndPass request,
HttpServletRequest httpRequest,
HttpServletResponse response) {
try {
// Validate input parameters
if (request.getUsername() == null || request.getUsername().trim().isEmpty()) {
@@ -67,20 +75,30 @@ public class AuthController {
.body(Map.of("error", "Password is required"));
}
log.debug("Login attempt for user: {}", request.getUsername());
String username = request.getUsername().trim();
String ip = httpRequest.getRemoteAddr();
UserDetails userDetails =
userDetailsService.loadUserByUsername(request.getUsername().trim());
// Check if account is blocked due to too many failed attempts
if (loginAttemptService.isBlocked(username)) {
log.warn("Blocked account login attempt for user: {} from IP: {}", username, ip);
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(Map.of("error", "Account is locked due to too many failed attempts"));
}
log.debug("Login attempt for user: {} from IP: {}", username, ip);
UserDetails userDetails = userDetailsService.loadUserByUsername(username);
User user = (User) userDetails;
if (!userService.isPasswordCorrect(user, request.getPassword())) {
log.warn("Invalid password for user: {}", request.getUsername());
log.warn("Invalid password for user: {} from IP: {}", username, ip);
loginAttemptService.loginFailed(username);
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(Map.of("error", "Invalid credentials"));
}
if (!user.isEnabled()) {
log.warn("Disabled user attempted login: {}", request.getUsername());
log.warn("Disabled user attempted login: {} from IP: {}", username, ip);
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(Map.of("error", "User account is disabled"));
}
@@ -91,7 +109,9 @@ public class AuthController {
String token = jwtService.generateToken(user.getUsername(), claims);
log.info("Login successful for user: {}", request.getUsername());
// Record successful login
loginAttemptService.loginSucceeded(username);
log.info("Login successful for user: {} from IP: {}", username, ip);
return ResponseEntity.ok(
Map.of(
@@ -99,11 +119,15 @@ public class AuthController {
"session", Map.of("access_token", token, "expires_in", 3600)));
} catch (UsernameNotFoundException e) {
log.warn("User not found: {}", request.getUsername());
String username = request.getUsername();
log.warn("User not found: {}", username);
loginAttemptService.loginFailed(username);
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(Map.of("error", "Invalid username or password"));
} catch (AuthenticationException e) {
log.error("Authentication failed for user: {}", request.getUsername(), e);
String username = request.getUsername();
log.error("Authentication failed for user: {}", username, e);
loginAttemptService.loginFailed(username);
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(Map.of("error", "Invalid credentials"));
} catch (Exception e) {
@@ -228,11 +252,4 @@ public class AuthController {
return userMap;
}
// ===========================
// Request/Response DTOs
// ===========================
/** Login request DTO */
public record LoginRequest(String email, String password) {}
}

View File

@@ -1,5 +1,6 @@
package stirling.software.proprietary.security.filter;
import static stirling.software.common.util.RequestUriUtils.isPublicAuthEndpoint;
import static stirling.software.common.util.RequestUriUtils.isStaticResource;
import static stirling.software.proprietary.security.model.AuthenticationType.OAUTH2;
import static stirling.software.proprietary.security.model.AuthenticationType.SAML2;
@@ -80,20 +81,7 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
String requestURI = request.getRequestURI();
String contextPath = request.getContextPath();
// Public auth endpoints that don't require JWT
boolean isPublicAuthEndpoint =
requestURI.startsWith(contextPath + "/login")
|| requestURI.startsWith(contextPath + "/signup")
|| requestURI.startsWith(contextPath + "/invite")
|| requestURI.startsWith(contextPath + "/auth/")
|| requestURI.startsWith(contextPath + "/oauth2")
|| requestURI.startsWith(contextPath + "/api/v1/auth/login")
|| requestURI.startsWith(contextPath + "/api/v1/auth/register")
|| requestURI.startsWith(contextPath + "/api/v1/auth/refresh")
|| requestURI.startsWith(contextPath + "/api/v1/invite/validate")
|| requestURI.startsWith(contextPath + "/api/v1/invite/accept");
if (!isPublicAuthEndpoint) {
if (!isPublicAuthEndpoint(requestURI, contextPath)) {
// For API requests, return 401 JSON
String acceptHeader = request.getHeader("Accept");
if (requestURI.startsWith(contextPath + "/api/")

View File

@@ -1,5 +1,7 @@
package stirling.software.proprietary.security.filter;
import static stirling.software.common.util.RequestUriUtils.isPublicAuthEndpoint;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
@@ -105,11 +107,17 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
}
}
// If we still don't have any authentication, deny the request
// If we still don't have any authentication, check if it's a public endpoint. If not, deny the request
if (authentication == null || !authentication.isAuthenticated()) {
String method = request.getMethod();
String contextPath = request.getContextPath();
// Allow public auth endpoints to pass through without authentication
if (isPublicAuthEndpoint(requestURI, contextPath)) {
filterChain.doFilter(request, response);
return;
}
if ("GET".equalsIgnoreCase(method) && !requestURI.startsWith(contextPath + "/login")) {
response.sendRedirect(contextPath + "/login"); // redirect to the login page
} else {
@@ -200,6 +208,23 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
private static boolean isPublicAuthEndpoint(String requestURI, String contextPath) {
// Remove context path from URI to normalize path matching
String trimmedUri =
requestURI.startsWith(contextPath)
? requestURI.substring(contextPath.length())
: requestURI;
// Public auth endpoints that don't require authentication
return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/auth/")
|| trimmedUri.startsWith("/oauth2")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.startsWith("/api/v1/auth/login")
|| trimmedUri.startsWith("/api/v1/auth/refresh")
|| trimmedUri.startsWith("/api/v1/auth/logout");
}
private enum UserLoginType {
USERDETAILS("UserDetails"),
OAUTH2USER("OAuth2User"),
@@ -225,7 +250,6 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
String contextPath = request.getContextPath();
String[] permitAllPatterns = {
contextPath + "/login",
contextPath + "/signup",
contextPath + "/register",
contextPath + "/invite",
contextPath + "/error",
@@ -238,7 +262,6 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
contextPath + "/pdfjs-legacy/",
contextPath + "/api/v1/info/status",
contextPath + "/api/v1/auth/login",
contextPath + "/api/v1/auth/register",
contextPath + "/api/v1/auth/refresh",
contextPath + "/api/v1/auth/me",
contextPath + "/api/v1/invite/validate",

View File

@@ -4,9 +4,14 @@ import static stirling.software.proprietary.security.model.AuthenticationType.OA
import static stirling.software.proprietary.security.model.AuthenticationType.SSO;
import java.io.IOException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.Map;
import java.util.Optional;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
@@ -16,6 +21,7 @@ import org.springframework.security.web.authentication.SavedRequestAwareAuthenti
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
@@ -37,6 +43,9 @@ import stirling.software.proprietary.security.service.UserService;
public class CustomOAuth2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private static final String SPA_REDIRECT_COOKIE = "stirling_redirect_path";
private static final String DEFAULT_CALLBACK_PATH = "/auth/callback";
private final LoginAttemptService loginAttemptService;
private final ApplicationProperties.Security.OAUTH2 oauth2Properties;
private final UserService userService;
@@ -119,7 +128,8 @@ public class CustomOAuth2AuthenticationSuccessHandler
authentication, Map.of("authType", AuthenticationType.OAUTH2));
// Build context-aware redirect URL based on the original request
String redirectUrl = buildContextAwareRedirectUrl(request, contextPath, jwt);
String redirectUrl =
buildContextAwareRedirectUrl(request, response, contextPath, jwt);
response.sendRedirect(redirectUrl);
} else {
@@ -149,30 +159,110 @@ public class CustomOAuth2AuthenticationSuccessHandler
* Builds a context-aware redirect URL based on the request's origin
*
* @param request The HTTP request
* @param response HTTP response (used to clear redirect cookies)
* @param contextPath The application context path
* @param jwt The JWT token to include
* @return The appropriate redirect URL
*/
private String buildContextAwareRedirectUrl(
HttpServletRequest request, String contextPath, String jwt) {
// Try to get the origin from the Referer header first
HttpServletRequest request,
HttpServletResponse response,
String contextPath,
String jwt) {
String redirectPath = resolveRedirectPath(request, contextPath);
String origin =
resolveForwardedOrigin(request)
.orElseGet(
() ->
resolveOriginFromReferer(request)
.orElseGet(() -> buildOriginFromRequest(request)));
clearRedirectCookie(response);
return origin + redirectPath + "#access_token=" + jwt;
}
private String resolveRedirectPath(HttpServletRequest request, String contextPath) {
return extractRedirectPathFromCookie(request)
.filter(path -> path.startsWith("/"))
.orElseGet(() -> defaultCallbackPath(contextPath));
}
private Optional<String> extractRedirectPathFromCookie(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return Optional.empty();
}
for (Cookie cookie : cookies) {
if (SPA_REDIRECT_COOKIE.equals(cookie.getName())) {
String value = URLDecoder.decode(cookie.getValue(), StandardCharsets.UTF_8).trim();
if (!value.isEmpty()) {
return Optional.of(value);
}
}
}
return Optional.empty();
}
private String defaultCallbackPath(String contextPath) {
if (contextPath == null
|| contextPath.isBlank()
|| "/".equals(contextPath)
|| "\\".equals(contextPath)) {
return DEFAULT_CALLBACK_PATH;
}
return contextPath + DEFAULT_CALLBACK_PATH;
}
private Optional<String> resolveForwardedOrigin(HttpServletRequest request) {
String forwardedHostHeader = request.getHeader("X-Forwarded-Host");
if (forwardedHostHeader == null || forwardedHostHeader.isBlank()) {
return Optional.empty();
}
String host = forwardedHostHeader.split(",")[0].trim();
if (host.isEmpty()) {
return Optional.empty();
}
String forwardedProtoHeader = request.getHeader("X-Forwarded-Proto");
String proto =
(forwardedProtoHeader == null || forwardedProtoHeader.isBlank())
? request.getScheme()
: forwardedProtoHeader.split(",")[0].trim();
if (!host.contains(":")) {
String forwardedPort = request.getHeader("X-Forwarded-Port");
if (forwardedPort != null
&& !forwardedPort.isBlank()
&& !isDefaultPort(proto, forwardedPort.trim())) {
host = host + ":" + forwardedPort.trim();
}
}
return Optional.of(proto + "://" + host);
}
private Optional<String> resolveOriginFromReferer(HttpServletRequest request) {
String referer = request.getHeader("Referer");
if (referer != null && !referer.isEmpty()) {
try {
java.net.URL refererUrl = new java.net.URL(referer);
String origin = refererUrl.getProtocol() + "://" + refererUrl.getHost();
if (refererUrl.getPort() != -1
&& refererUrl.getPort() != 80
&& refererUrl.getPort() != 443) {
origin += ":" + refererUrl.getPort();
String refererHost = refererUrl.getHost().toLowerCase();
if (!isOAuthProviderDomain(refererHost)) {
String origin = refererUrl.getProtocol() + "://" + refererUrl.getHost();
if (refererUrl.getPort() != -1
&& refererUrl.getPort() != 80
&& refererUrl.getPort() != 443) {
origin += ":" + refererUrl.getPort();
}
return Optional.of(origin);
}
return origin + "/auth/callback#access_token=" + jwt;
} catch (java.net.MalformedURLException e) {
// Fall back to other methods if referer is malformed
// ignore and fall back
}
}
return Optional.empty();
}
// Fall back to building from request host/port
private String buildOriginFromRequest(HttpServletRequest request) {
String scheme = request.getScheme();
String serverName = request.getServerName();
int serverPort = request.getServerPort();
@@ -180,12 +270,50 @@ public class CustomOAuth2AuthenticationSuccessHandler
StringBuilder origin = new StringBuilder();
origin.append(scheme).append("://").append(serverName);
// Only add port if it's not the default port for the scheme
if ((!"http".equals(scheme) || serverPort != 80)
&& (!"https".equals(scheme) || serverPort != 443)) {
if ((!"http".equalsIgnoreCase(scheme) || serverPort != 80)
&& (!"https".equalsIgnoreCase(scheme) || serverPort != 443)) {
origin.append(":").append(serverPort);
}
return origin.toString() + "/auth/callback#access_token=" + jwt;
return origin.toString();
}
private boolean isDefaultPort(String scheme, String port) {
if (port == null) {
return true;
}
try {
int parsedPort = Integer.parseInt(port);
return ("http".equalsIgnoreCase(scheme) && parsedPort == 80)
|| ("https".equalsIgnoreCase(scheme) && parsedPort == 443);
} catch (NumberFormatException e) {
return false;
}
}
private void clearRedirectCookie(HttpServletResponse response) {
ResponseCookie cookie =
ResponseCookie.from(SPA_REDIRECT_COOKIE, "")
.path("/")
.sameSite("Lax")
.maxAge(0)
.build();
response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString());
}
/**
* Checks if the given hostname belongs to a known OAuth provider.
*
* @param hostname The hostname to check
* @return true if it's an OAuth provider domain, false otherwise
*/
private boolean isOAuthProviderDomain(String hostname) {
return hostname.contains("google.com")
|| hostname.contains("googleapis.com")
|| hostname.contains("github.com")
|| hostname.contains("microsoft.com")
|| hostname.contains("microsoftonline.com")
|| hostname.contains("linkedin.com")
|| hostname.contains("apple.com");
}
}

View File

@@ -165,12 +165,7 @@ public class OAuth2Configuration {
githubClient.getUseAsUsername());
boolean isValid = validateProvider(github);
log.info(
"GitHub OAuth2 provider validation: {} (clientId: {}, clientSecret: {}, scopes: {})",
isValid,
githubClient.getClientId(),
githubClient.getClientSecret() != null ? "***" : "null",
githubClient.getScopes());
log.info("Initialised GitHub OAuth2 provider");
return isValid
? Optional.of(

View File

@@ -4,15 +4,21 @@ import static stirling.software.proprietary.security.model.AuthenticationType.SA
import static stirling.software.proprietary.security.model.AuthenticationType.SSO;
import java.io.IOException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.Map;
import java.util.Optional;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
@@ -36,6 +42,9 @@ import stirling.software.proprietary.security.service.UserService;
public class CustomSaml2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private static final String SPA_REDIRECT_COOKIE = "stirling_redirect_path";
private static final String DEFAULT_CALLBACK_PATH = "/auth/callback";
private LoginAttemptService loginAttemptService;
private ApplicationProperties.Security.SAML2 saml2Properties;
private UserService userService;
@@ -148,7 +157,7 @@ public class CustomSaml2AuthenticationSuccessHandler
// Build context-aware redirect URL based on the original request
String redirectUrl =
buildContextAwareRedirectUrl(request, contextPath, jwt);
buildContextAwareRedirectUrl(request, response, contextPath, jwt);
response.sendRedirect(redirectUrl);
} else {
@@ -177,8 +186,81 @@ public class CustomSaml2AuthenticationSuccessHandler
* @return The appropriate redirect URL
*/
private String buildContextAwareRedirectUrl(
HttpServletRequest request, String contextPath, String jwt) {
// Try to get the origin from the Referer header first
HttpServletRequest request,
HttpServletResponse response,
String contextPath,
String jwt) {
String redirectPath = resolveRedirectPath(request, contextPath);
String origin =
resolveForwardedOrigin(request)
.orElseGet(
() ->
resolveOriginFromReferer(request)
.orElseGet(() -> buildOriginFromRequest(request)));
clearRedirectCookie(response);
return origin + redirectPath + "#access_token=" + jwt;
}
private String resolveRedirectPath(HttpServletRequest request, String contextPath) {
return extractRedirectPathFromCookie(request)
.filter(path -> path.startsWith("/"))
.orElseGet(() -> defaultCallbackPath(contextPath));
}
private Optional<String> extractRedirectPathFromCookie(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return Optional.empty();
}
for (Cookie cookie : cookies) {
if (SPA_REDIRECT_COOKIE.equals(cookie.getName())) {
String value = URLDecoder.decode(cookie.getValue(), StandardCharsets.UTF_8).trim();
if (!value.isEmpty()) {
return Optional.of(value);
}
}
}
return Optional.empty();
}
private String defaultCallbackPath(String contextPath) {
if (contextPath == null
|| contextPath.isBlank()
|| "/".equals(contextPath)
|| "\\".equals(contextPath)) {
return DEFAULT_CALLBACK_PATH;
}
return contextPath + DEFAULT_CALLBACK_PATH;
}
private Optional<String> resolveForwardedOrigin(HttpServletRequest request) {
String forwardedHostHeader = request.getHeader("X-Forwarded-Host");
if (forwardedHostHeader == null || forwardedHostHeader.isBlank()) {
return Optional.empty();
}
String host = forwardedHostHeader.split(",")[0].trim();
if (host.isEmpty()) {
return Optional.empty();
}
String forwardedProtoHeader = request.getHeader("X-Forwarded-Proto");
String proto =
(forwardedProtoHeader == null || forwardedProtoHeader.isBlank())
? request.getScheme()
: forwardedProtoHeader.split(",")[0].trim();
if (!host.contains(":")) {
String forwardedPort = request.getHeader("X-Forwarded-Port");
if (forwardedPort != null
&& !forwardedPort.isBlank()
&& !isDefaultPort(proto, forwardedPort.trim())) {
host = host + ":" + forwardedPort.trim();
}
}
return Optional.of(proto + "://" + host);
}
private Optional<String> resolveOriginFromReferer(HttpServletRequest request) {
String referer = request.getHeader("Referer");
if (referer != null && !referer.isEmpty()) {
try {
@@ -189,14 +271,16 @@ public class CustomSaml2AuthenticationSuccessHandler
&& refererUrl.getPort() != 443) {
origin += ":" + refererUrl.getPort();
}
return origin + "/auth/callback#access_token=" + jwt;
return Optional.of(origin);
} catch (java.net.MalformedURLException e) {
log.debug(
"Malformed referer URL: {}, falling back to request-based origin", referer);
}
}
return Optional.empty();
}
// Fall back to building from request host/port
private String buildOriginFromRequest(HttpServletRequest request) {
String scheme = request.getScheme();
String serverName = request.getServerName();
int serverPort = request.getServerPort();
@@ -204,12 +288,34 @@ public class CustomSaml2AuthenticationSuccessHandler
StringBuilder origin = new StringBuilder();
origin.append(scheme).append("://").append(serverName);
// Only add port if it's not the default port for the scheme
if ((!"http".equals(scheme) || serverPort != 80)
&& (!"https".equals(scheme) || serverPort != 443)) {
if ((!"http".equalsIgnoreCase(scheme) || serverPort != 80)
&& (!"https".equalsIgnoreCase(scheme) || serverPort != 443)) {
origin.append(":").append(serverPort);
}
return origin + "/auth/callback#access_token=" + jwt;
return origin.toString();
}
private boolean isDefaultPort(String scheme, String port) {
if (port == null) {
return true;
}
try {
int parsedPort = Integer.parseInt(port);
return ("http".equalsIgnoreCase(scheme) && parsedPort == 80)
|| ("https".equalsIgnoreCase(scheme) && parsedPort == 443);
} catch (NumberFormatException e) {
return false;
}
}
private void clearRedirectCookie(HttpServletResponse response) {
ResponseCookie cookie =
ResponseCookie.from(SPA_REDIRECT_COOKIE, "")
.path("/")
.sameSite("Lax")
.maxAge(0)
.build();
response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString());
}
}

View File

@@ -29,6 +29,8 @@ class JwtAuthenticationEntryPointTest {
@Test
void testCommence() throws IOException {
String errorMessage = "Authentication failed";
when(request.getRequestURI()).thenReturn("/redact");
when(authException.getMessage()).thenReturn(errorMessage);
jwtAuthenticationEntryPoint.commence(request, response, authException);