This commit is contained in:
Dario Ghunney Ware 2025-07-28 12:03:28 +01:00 committed by GitHub
commit 34cc980687
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 2690 additions and 191 deletions

View File

@ -5,7 +5,13 @@
"Bash(mkdir:*)",
"Bash(./gradlew:*)",
"Bash(grep:*)",
"Bash(cat:*)"
"Bash(cat:*)",
"Bash(find:*)",
"Bash(grep:*)",
"Bash(rg:*)",
"Bash(strings:*)",
"Bash(pkill:*)",
"Bash(true)"
],
"deny": []
}

View File

@ -8,6 +8,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@ -51,6 +52,14 @@ public class AppConfig {
@Value("${server.port:8080}")
private String serverPort;
@Value("${v2}")
public boolean v2Enabled;
@Bean
public boolean v2Enabled() {
return v2Enabled;
}
@Bean
@ConditionalOnProperty(name = "system.customHTMLFiles", havingValue = "true")
public SpringTemplateEngine templateEngine(ResourceLoader resourceLoader) {
@ -120,7 +129,7 @@ public class AppConfig {
public boolean rateLimit() {
String rateLimit = System.getProperty("rateLimit");
if (rateLimit == null) rateLimit = System.getenv("rateLimit");
return (rateLimit != null) ? Boolean.valueOf(rateLimit) : false;
return Boolean.parseBoolean(rateLimit);
}
@Bean(name = "RunningInDocker")
@ -140,8 +149,8 @@ public class AppConfig {
if (!Files.exists(mountInfo)) {
return true;
}
try {
return Files.lines(mountInfo).anyMatch(line -> line.contains(" /configs "));
try (Stream<String> lines = Files.lines(mountInfo)) {
return lines.anyMatch(line -> line.contains(" /configs "));
} catch (IOException e) {
return false;
}

View File

@ -113,6 +113,7 @@ public class ApplicationProperties {
private long loginResetTimeMinutes;
private String loginMethod = "all";
private String customGlobalAPIKey;
private Jwt jwt = new Jwt();
public Boolean isAltLogin() {
return saml2.getEnabled() || oauth2.getEnabled();
@ -275,6 +276,12 @@ public class ApplicationProperties {
}
}
}
@Data
public static class Jwt {
private boolean enableKeystore = true;
private boolean enableKeyRotation = false;
}
}
@Data

View File

@ -14,8 +14,10 @@ public class RequestUriUtils {
|| requestURI.startsWith(contextPath + "/images/")
|| requestURI.startsWith(contextPath + "/public/")
|| requestURI.startsWith(contextPath + "/pdfjs/")
|| requestURI.startsWith(contextPath + "/pdfjs-legacy/")
|| requestURI.startsWith(contextPath + "/login")
|| requestURI.startsWith(contextPath + "/error")
|| requestURI.startsWith(contextPath + "/favicon")
|| requestURI.endsWith(".svg")
|| requestURI.endsWith(".png")
|| requestURI.endsWith(".ico")

View File

@ -5,7 +5,7 @@ logging.level.org.eclipse.jetty=WARN
#logging.level.org.springframework.security.saml2=TRACE
#logging.level.org.springframework.security=DEBUG
#logging.level.org.opensaml=DEBUG
#logging.level.stirling.software.SPDF.config.security: DEBUG
#logging.level.stirling.software.proprietary.security: DEBUG
logging.level.com.zaxxer.hikari=WARN
spring.jpa.open-in-view=false
server.forward-headers-strategy=NATIVE
@ -47,4 +47,7 @@ posthog.host=https://eu.i.posthog.com
spring.main.allow-bean-definition-overriding=true
# Set up a consistent temporary directory location
java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf}
java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf}
# V2 features
v2=true

View File

@ -861,7 +861,7 @@ login.rememberme=Remember me
login.invalid=Invalid username or password.
login.locked=Your account has been locked.
login.signinTitle=Please sign in
login.ssoSignIn=Login via Single Sign-on
login.ssoSignIn=Login via Single Sign-On
login.oAuth2AutoCreateDisabled=OAUTH2 Auto-Create User Disabled
login.oAuth2AdminBlockedUser=Registration or logging in of non-registered users is currently blocked. Please contact the administrator.
login.oauth2RequestNotFound=Authorization request not found
@ -876,6 +876,7 @@ login.alreadyLoggedIn=You are already logged in to
login.alreadyLoggedIn2=devices. Please log out of the devices and try again.
login.toManySessions=You have too many active sessions
login.logoutMessage=You have been logged out.
login.invalidInResponseTo=The requested SAML response is invalid or has expired. Please contact the administrator.
#auto-redact
autoRedact.title=Auto Redact

View File

@ -31,7 +31,7 @@ security:
google:
clientId: '' # client ID for Google OAuth2
clientSecret: '' # client secret for Google OAuth2
scopes: email, profile # scopes for Google OAuth2
scopes: https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile # scopes for Google OAuth2
useAsUsername: email # field to use as the username for Google OAuth2. Available options are: [email | name | given_name | family_name]
github:
clientId: '' # client ID for GitHub OAuth2
@ -51,20 +51,22 @@ security:
provider: '' # The name of your Provider
autoCreateUser: true # set to 'true' to allow auto-creation of non-existing users
blockRegistration: false # set to 'true' to deny login with SSO without prior registration by an admin
registrationId: stirling # The name of your Service Provider (SP) app name. Should match the name in the path for your SSO & SLO URLs
idpMetadataUri: https://dev-XXXXXXXX.okta.com/app/externalKey/sso/saml/metadata # The uri for your Provider's metadata
idpSingleLoginUrl: https://dev-XXXXXXXX.okta.com/app/dev-XXXXXXXX_stirlingpdf_1/externalKey/sso/saml # The URL for initiating SSO. Provided by your Provider
idpSingleLogoutUrl: https://dev-XXXXXXXX.okta.com/app/dev-XXXXXXXX_stirlingpdf_1/externalKey/slo/saml # The URL for initiating SLO. Provided by your Provider
idpIssuer: '' # The ID of your Provider
idpCert: classpath:okta.cert # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider
privateKey: classpath:saml-private-key.key # Your private key. Generated from your keypair
spCert: classpath:saml-public-cert.crt # Your signing certificate. Generated from your keypair
registrationId: stirlingpdf-dario-saml # The name of your Service Provider (SP) app name. Should match the name in the path for your SSO & SLO URLs
idpMetadataUri: https://authentik.dev.stirlingpdf.com/api/v3/providers/saml/5/metadata/ # The uri for your Provider's metadata
idpSingleLoginUrl: https://authentik.dev.stirlingpdf.com/application/saml/stirlingpdf-dario-saml/sso/binding/post/ # The URL for initiating SSO. Provided by your Provider
idpSingleLogoutUrl: https://authentik.dev.stirlingpdf.com/application/saml/stirlingpdf-dario-saml/slo/binding/post/ # The URL for initiating SLO. Provided by your Provider
idpIssuer: authentik # The ID of your Provider
idpCert: classpath:authentik-Self-signed_Certificate_certificate.pem # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider
privateKey: classpath:private_key.key # Your private key. Generated from your keypair
spCert: classpath:certificate.crt # Your signing certificate. Generated from your keypair
jwt:
enableKeyStore: true # Set to 'true' to enable JWT key store
enableKeyRotation: true # Set to 'true' to enable JWT key rotation
premium:
key: 00000000-0000-0000-0000-000000000000
key: 3R3T-WFPY-UNRW-LJFA-MMXM-YVJK-WCKY-PCRT # fixme: remove
enabled: false # Enable license key checks for pro/enterprise features
proFeatures:
database: true # Enable database features
SSOAutoLogin: false
CustomMetadata:
autoUpdateMetadata: false

View File

@ -46,10 +46,9 @@ export class DecryptFile {
formData.append('password', password);
}
// Send decryption request
const response = await fetch('/api/v1/security/remove-password', {
const response = await fetchWithCsrf('/api/v1/security/remove-password', {
method: 'POST',
body: formData,
headers: csrfToken ? {'X-XSRF-TOKEN': csrfToken} : undefined,
});
if (response.ok) {

View File

@ -218,7 +218,7 @@
formData.append('password', password);
// Use handleSingleDownload to send the request
const decryptionResult = await fetch(removePasswordUrl, {method: 'POST', body: formData});
const decryptionResult = await fetchWithCsrf(removePasswordUrl, {method: 'POST', body: formData});
if (decryptionResult && decryptionResult.blob) {
const decryptedBlob = await decryptionResult.blob();

View File

@ -1,3 +1,76 @@
// JWT Management Utility
window.JWTManager = {
JWT_STORAGE_KEY: 'stirling_jwt',
// Store JWT token in localStorage
storeToken: function(token) {
if (token) {
localStorage.setItem(this.JWT_STORAGE_KEY, token);
}
},
// Get JWT token from localStorage
getToken: function() {
return localStorage.getItem(this.JWT_STORAGE_KEY);
},
// Remove JWT token from localStorage
removeToken: function() {
localStorage.removeItem(this.JWT_STORAGE_KEY);
},
// Extract JWT from Authorization header in response
extractTokenFromResponse: function(response) {
const authHeader = response.headers.get('Authorization');
if (authHeader && authHeader.startsWith('Bearer ')) {
const token = authHeader.substring(7); // Remove 'Bearer ' prefix
this.storeToken(token);
return token;
}
return null;
},
// Check if user is authenticated (has valid JWT)
isAuthenticated: function() {
const token = this.getToken();
if (!token) return false;
try {
// Basic JWT expiration check (decode payload)
const payload = JSON.parse(atob(token.split('.')[1]));
const now = Date.now() / 1000;
return payload.exp > now;
} catch (error) {
console.warn('Invalid JWT token:', error);
this.removeToken();
return false;
}
},
// Logout - remove token and redirect to login
logout: function() {
this.removeToken();
// Clear all possible token storage locations
localStorage.removeItem(this.JWT_STORAGE_KEY);
sessionStorage.removeItem(this.JWT_STORAGE_KEY);
// Clear JWT cookie manually (fallback)
document.cookie = 'stirling_jwt=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure';
// Perform logout request to clear server-side session
fetch('/logout', {
method: 'POST',
credentials: 'include'
}).then(() => {
window.location.href = '/login';
}).catch(() => {
// Even if logout fails, redirect to login
window.location.href = '/login';
});
}
};
window.fetchWithCsrf = async function(url, options = {}) {
function getCsrfToken() {
const cookieValue = document.cookie
@ -24,5 +97,31 @@ window.fetchWithCsrf = async function(url, options = {}) {
fetchOptions.headers['X-XSRF-TOKEN'] = csrfToken;
}
return fetch(url, fetchOptions);
// Add JWT token to Authorization header if available
const jwtToken = window.JWTManager.getToken();
if (jwtToken) {
fetchOptions.headers['Authorization'] = `Bearer ${jwtToken}`;
// Include credentials when JWT is enabled
fetchOptions.credentials = 'include';
}
// Make the request
const response = await fetch(url, fetchOptions);
// Extract JWT from response if present
window.JWTManager.extractTokenFromResponse(response);
// Handle 401 responses (unauthorized)
if (response.status === 401) {
console.warn('Authentication failed, redirecting to login');
window.JWTManager.logout();
return response;
}
return response;
}
// Enhanced fetch function that always includes JWT
window.fetchWithJWT = async function(url, options = {}) {
return window.fetchWithCsrf(url, options);
}

View File

@ -0,0 +1,121 @@
// JWT Initialization Script
// This script handles JWT token extraction during OAuth/Login flows and initializes the JWT manager
(function() {
// Extract JWT token from URL parameters (for OAuth redirects)
function extractTokenFromUrl() {
const urlParams = new URLSearchParams(window.location.search);
const token = urlParams.get('jwt') || urlParams.get('token');
if (token) {
window.JWTManager.storeToken(token);
// Clean up URL by removing token parameter
urlParams.delete('jwt');
urlParams.delete('token');
const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : '');
window.history.replaceState({}, '', newUrl);
}
}
// Extract JWT token from cookie on page load (fallback)
function extractTokenFromCookie() {
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith('stirling_jwt='))
?.split('=')[1];
if (cookieValue) {
window.JWTManager.storeToken(cookieValue);
// Clear the cookie since we're using localStorage with consistent SameSite policy
document.cookie = 'stirling_jwt=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure';
}
}
// Initialize JWT handling when page loads
function initializeJWT() {
// Try to extract token from URL first (OAuth flow)
extractTokenFromUrl();
// If no token in URL, try cookie (login flow)
if (!window.JWTManager.getToken()) {
extractTokenFromCookie();
}
// Check if user is authenticated
if (window.JWTManager.isAuthenticated()) {
console.log('User is authenticated with JWT');
} else {
console.log('User is not authenticated or token expired');
// Only redirect to login if we're not already on login/register pages
const currentPath = window.location.pathname;
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml') &&
!currentPath.includes('/error')) {
// Redirect to login after a short delay to allow other scripts to load
setTimeout(() => {
window.location.href = '/login';
}, 100);
}
}
}
// Override form submissions to include JWT
function enhanceFormSubmissions() {
// Override form submit for login forms
document.addEventListener('submit', function(event) {
const form = event.target;
// Add JWT to form data if available
const jwtToken = window.JWTManager.getToken();
if (jwtToken && form.method && form.method.toLowerCase() !== 'get') {
// Create a hidden input for JWT
const jwtInput = document.createElement('input');
jwtInput.type = 'hidden';
jwtInput.name = 'jwt';
jwtInput.value = jwtToken;
form.appendChild(jwtInput);
}
});
}
// Add logout functionality to logout buttons
function enhanceLogoutButtons() {
document.addEventListener('click', function(event) {
const element = event.target;
// Check if clicked element is a logout button/link
if (element.matches('a[href="/logout"], button[data-action="logout"], .logout-btn')) {
event.preventDefault();
window.JWTManager.logout();
}
});
}
// Initialize when DOM is ready
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', function() {
initializeJWT();
enhanceFormSubmissions();
enhanceLogoutButtons();
});
} else {
initializeJWT();
enhanceFormSubmissions();
enhanceLogoutButtons();
}
// Handle page visibility changes to check token expiration
document.addEventListener('visibilitychange', function() {
if (!document.hidden && !window.JWTManager.isAuthenticated()) {
// Token expired while page was hidden, redirect to login
const currentPath = window.location.pathname;
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml')) {
window.location.href = '/login';
}
}
});
})();

View File

@ -138,5 +138,19 @@ document.addEventListener('DOMContentLoaded', () => {
tooltipSetup();
setupDropdowns();
fixNavbarDropdownStyles();
// Setup logout button functionality
const logoutButton = document.querySelector('a[href="/logout"]');
if (logoutButton) {
logoutButton.addEventListener('click', function(event) {
event.preventDefault();
if (window.JWTManager) {
window.JWTManager.logout();
} else {
// Fallback if JWTManager is not available
window.location.href = '/logout';
}
});
}
});
window.addEventListener('resize', fixNavbarDropdownStyles);

View File

@ -102,7 +102,7 @@ async function fetchEndpointData() {
refreshBtn.classList.add('refreshing');
refreshBtn.disabled = true;
const response = await fetch('/api/v1/info/load/all');
const response = await fetchWithCsrf('/api/v1/info/load/all');
if (!response.ok) {
throw new Error('Network response was not ok');
}

View File

@ -1,9 +1,15 @@
repositories {
maven { url = "https://build.shibboleth.net/maven/releases" }
}
ext {
jwtVersion = '0.12.6'
}
bootRun {
enabled = false
}
spotless {
java {
target sourceSets.main.allJava
@ -38,6 +44,10 @@ dependencies {
implementation 'org.thymeleaf.extras:thymeleaf-extras-springsecurity5:3.1.3.RELEASE'
api 'io.micrometer:micrometer-registry-prometheus'
implementation 'com.unboundid.product.scim2:scim2-sdk-client:4.0.0'
api "io.jsonwebtoken:jjwt-api:$jwtVersion"
runtimeOnly "io.jsonwebtoken:jjwt-impl:$jwtVersion"
runtimeOnly "io.jsonwebtoken:jjwt-jackson:$jwtVersion"
runtimeOnly 'com.h2database:h2:2.3.232' // Don't upgrade h2database
runtimeOnly 'org.postgresql:postgresql:42.7.7'
constraints {

View File

@ -1,6 +1,7 @@
package stirling.software.proprietary.security;
import java.io.IOException;
import java.util.Map;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
@ -17,6 +18,8 @@ import stirling.software.common.util.RequestUriUtils;
import stirling.software.proprietary.audit.AuditEventType;
import stirling.software.proprietary.audit.AuditLevel;
import stirling.software.proprietary.audit.Audited;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.LoginAttemptService;
import stirling.software.proprietary.security.service.UserService;
@ -24,13 +27,17 @@ import stirling.software.proprietary.security.service.UserService;
public class CustomAuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private UserService userService;
private final LoginAttemptService loginAttemptService;
private final UserService userService;
private final JwtServiceInterface jwtService;
public CustomAuthenticationSuccessHandler(
LoginAttemptService loginAttemptService, UserService userService) {
LoginAttemptService loginAttemptService,
UserService userService,
JwtServiceInterface jwtService) {
this.loginAttemptService = loginAttemptService;
this.userService = userService;
this.jwtService = jwtService;
}
@Override
@ -46,23 +53,35 @@ public class CustomAuthenticationSuccessHandler
}
loginAttemptService.loginSucceeded(userName);
// Get the saved request
HttpSession session = request.getSession(false);
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
if (jwtService.isJwtEnabled()) {
try {
String jwt =
jwtService.generateToken(
authentication, Map.of("authType", AuthenticationType.WEB));
jwtService.addTokenToResponse(response, jwt);
log.debug("JWT generated for user: {}", userName);
} catch (Exception e) {
log.error("Failed to generate JWT token for user: {}", userName, e);
}
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(
request.getContextPath(), savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
// Redirect to the root URL (considering context path)
getRedirectStrategy().sendRedirect(request, response, "/");
}
} else {
// Get the saved request
HttpSession session = request.getSession(false);
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
// super.onAuthenticationSuccess(request, response, authentication);
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(
request.getContextPath(), savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
// No saved request or it's a static resource, redirect to home page
getRedirectStrategy().sendRedirect(request, response, "/");
}
}
}
}

View File

@ -33,6 +33,7 @@ import stirling.software.proprietary.audit.AuditLevel;
import stirling.software.proprietary.audit.Audited;
import stirling.software.proprietary.security.saml2.CertificateUtils;
import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Slf4j
@RequiredArgsConstructor
@ -40,15 +41,18 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
public static final String LOGOUT_PATH = "/login?logout=true";
private final ApplicationProperties applicationProperties;
private final ApplicationProperties.Security securityProperties;
private final AppConfig appConfig;
private final JwtServiceInterface jwtService;
@Override
@Audited(type = AuditEventType.USER_LOGOUT, level = AuditLevel.BASIC)
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
if (!response.isCommitted()) {
if (authentication != null) {
if (authentication instanceof Saml2Authentication samlAuthentication) {
@ -67,6 +71,9 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
authentication.getClass().getSimpleName());
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
}
} else if (!jwtService.extractTokenFromRequest(request).isBlank()) {
jwtService.clearTokenFromResponse(response);
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
} else {
// Redirect to login page after logout
String path = checkForErrors(request);
@ -82,7 +89,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
Saml2Authentication samlAuthentication)
throws IOException {
SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
SAML2 samlConf = securityProperties.getSaml2();
String registrationId = samlConf.getRegistrationId();
CustomSaml2AuthenticatedPrincipal principal =
@ -127,7 +134,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
OAuth2AuthenticationToken oAuthToken)
throws IOException {
String registrationId;
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
OAUTH2 oauth = securityProperties.getOauth2();
String path = checkForErrors(request);
String redirectUrl = UrlUtils.getOrigin(request) + "/login?" + path;

View File

@ -0,0 +1,22 @@
package stirling.software.proprietary.security;
import java.io.IOException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.stereotype.Component;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
@Component
public class JwtAuthenticationEntryPoint implements AuthenticationEntryPoint {
@Override
public void commence(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException authException)
throws IOException {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage());
}
}

View File

@ -184,7 +184,7 @@ public class AccountWebController {
errorOAuth = "login.relyingPartyRegistrationNotFound";
// Valid InResponseTo was not available from the validation context, unable to
// evaluate
case "invalid_in_response_to" -> errorOAuth = "login.invalid_in_response_to";
case "invalid_in_response_to" -> errorOAuth = "login.invalidInResponseTo";
case "not_authentication_provider_found" ->
errorOAuth = "login.not_authentication_provider_found";
}

View File

@ -8,11 +8,14 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.ProviderManager;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
@ -35,10 +38,12 @@ import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.CustomAuthenticationFailureHandler;
import stirling.software.proprietary.security.CustomAuthenticationSuccessHandler;
import stirling.software.proprietary.security.CustomLogoutSuccessHandler;
import stirling.software.proprietary.security.JwtAuthenticationEntryPoint;
import stirling.software.proprietary.security.database.repository.JPATokenRepositoryImpl;
import stirling.software.proprietary.security.database.repository.PersistentLoginRepository;
import stirling.software.proprietary.security.filter.FirstLoginFilter;
import stirling.software.proprietary.security.filter.IPRateLimitingFilter;
import stirling.software.proprietary.security.filter.JwtAuthenticationFilter;
import stirling.software.proprietary.security.filter.UserAuthenticationFilter;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.oauth2.CustomOAuth2AuthenticationFailureHandler;
@ -48,6 +53,7 @@ import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticationSuc
import stirling.software.proprietary.security.saml2.CustomSaml2ResponseAuthenticationConverter;
import stirling.software.proprietary.security.service.CustomOAuth2UserService;
import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.LoginAttemptService;
import stirling.software.proprietary.security.service.UserService;
import stirling.software.proprietary.security.session.SessionPersistentRegistry;
@ -64,9 +70,11 @@ public class SecurityConfiguration {
private final boolean loginEnabledValue;
private final boolean runningProOrHigher;
private final ApplicationProperties applicationProperties;
private final ApplicationProperties.Security securityProperties;
private final AppConfig appConfig;
private final UserAuthenticationFilter userAuthenticationFilter;
private final JwtServiceInterface jwtService;
private final JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint;
private final LoginAttemptService loginAttemptService;
private final FirstLoginFilter firstLoginFilter;
private final SessionPersistentRegistry sessionRegistry;
@ -82,8 +90,10 @@ public class SecurityConfiguration {
@Qualifier("loginEnabled") boolean loginEnabledValue,
@Qualifier("runningProOrHigher") boolean runningProOrHigher,
AppConfig appConfig,
ApplicationProperties applicationProperties,
ApplicationProperties.Security securityProperties,
UserAuthenticationFilter userAuthenticationFilter,
JwtServiceInterface jwtService,
JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint,
LoginAttemptService loginAttemptService,
FirstLoginFilter firstLoginFilter,
SessionPersistentRegistry sessionRegistry,
@ -97,8 +107,10 @@ public class SecurityConfiguration {
this.loginEnabledValue = loginEnabledValue;
this.runningProOrHigher = runningProOrHigher;
this.appConfig = appConfig;
this.applicationProperties = applicationProperties;
this.securityProperties = securityProperties;
this.userAuthenticationFilter = userAuthenticationFilter;
this.jwtService = jwtService;
this.jwtAuthenticationEntryPoint = jwtAuthenticationEntryPoint;
this.loginAttemptService = loginAttemptService;
this.firstLoginFilter = firstLoginFilter;
this.sessionRegistry = sessionRegistry;
@ -115,14 +127,28 @@ public class SecurityConfiguration {
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
if (applicationProperties.getSecurity().getCsrfDisabled() || !loginEnabledValue) {
http.csrf(csrf -> csrf.disable());
if (securityProperties.getCsrfDisabled() || !loginEnabledValue) {
http.csrf(CsrfConfigurer::disable);
}
if (loginEnabledValue) {
boolean v2Enabled = appConfig.v2Enabled();
if (v2Enabled) {
http.addFilterBefore(
jwtAuthenticationFilter(),
UsernamePasswordAuthenticationFilter.class)
.exceptionHandling(
exceptionHandling ->
exceptionHandling.authenticationEntryPoint(
jwtAuthenticationEntryPoint));
}
http.addFilterBefore(
userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
if (!applicationProperties.getSecurity().getCsrfDisabled()) {
userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class)
.addFilterAfter(rateLimitingFilter(), UserAuthenticationFilter.class)
.addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class);
if (!securityProperties.getCsrfDisabled()) {
CookieCsrfTokenRepository cookieRepo =
CookieCsrfTokenRepository.withHttpOnlyFalse();
CsrfTokenRequestAttributeHandler requestHandler =
@ -156,16 +182,22 @@ public class SecurityConfiguration {
.csrfTokenRepository(cookieRepo)
.csrfTokenRequestHandler(requestHandler));
}
http.addFilterBefore(rateLimitingFilter(), UsernamePasswordAuthenticationFilter.class);
http.addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class);
// Configure session management based on JWT setting
http.sessionManagement(
sessionManagement ->
sessionManagement -> {
if (v2Enabled && !securityProperties.isSaml2Active()) {
sessionManagement.sessionCreationPolicy(
SessionCreationPolicy.STATELESS);
} else {
sessionManagement
.sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED)
.maximumSessions(10)
.maxSessionsPreventsLogin(false)
.sessionRegistry(sessionRegistry)
.expiredUrl("/login?logout=true"));
.expiredUrl("/login?logout=true");
}
});
http.authenticationProvider(daoAuthenticationProvider());
http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()));
http.logout(
@ -175,10 +207,10 @@ public class SecurityConfiguration {
.matcher("/logout"))
.logoutSuccessHandler(
new CustomLogoutSuccessHandler(
applicationProperties, appConfig))
securityProperties, appConfig, jwtService))
.clearAuthentication(true)
.invalidateHttpSession(true)
.deleteCookies("JSESSIONID", "remember-me"));
.deleteCookies("JSESSIONID", "remember-me", "stirling_jwt"));
http.rememberMe(
rememberMeConfigurer -> // Use the configurator directly
rememberMeConfigurer
@ -200,6 +232,7 @@ public class SecurityConfiguration {
req -> {
String uri = req.getRequestURI();
String contextPath = req.getContextPath();
// Remove the context path from the URI
String trimmedUri =
uri.startsWith(contextPath)
@ -217,29 +250,33 @@ public class SecurityConfiguration {
|| trimmedUri.startsWith("/css/")
|| trimmedUri.startsWith("/fonts/")
|| trimmedUri.startsWith("/js/")
|| trimmedUri.startsWith("/favicon")
|| trimmedUri.startsWith(
"/api/v1/info/status");
"/api/v1/info/status")
|| trimmedUri.startsWith("/v1/api-docs")
|| uri.contains("/v1/api-docs");
})
.permitAll()
.anyRequest()
.authenticated());
// Handle User/Password Logins
if (applicationProperties.getSecurity().isUserPass()) {
if (securityProperties.isUserPass()) {
http.formLogin(
formLogin ->
formLogin
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService, userService))
loginAttemptService,
userService,
jwtService))
.failureHandler(
new CustomAuthenticationFailureHandler(
loginAttemptService, userService))
.defaultSuccessUrl("/")
.permitAll());
}
// Handle OAUTH2 Logins
if (applicationProperties.getSecurity().isOauth2Active()) {
if (securityProperties.isOauth2Active()) {
http.oauth2Login(
oauth2 ->
oauth2.loginPage("/oauth2")
@ -251,17 +288,18 @@ public class SecurityConfiguration {
.successHandler(
new CustomOAuth2AuthenticationSuccessHandler(
loginAttemptService,
applicationProperties,
userService))
securityProperties.getOauth2(),
userService,
jwtService))
.failureHandler(
new CustomOAuth2AuthenticationFailureHandler())
. // Add existing Authorities from the database
userInfoEndpoint(
// Add existing Authorities from the database
.userInfoEndpoint(
userInfoEndpoint ->
userInfoEndpoint
.oidcUserService(
new CustomOAuth2UserService(
applicationProperties,
securityProperties,
userService,
loginAttemptService))
.userAuthoritiesMapper(
@ -269,8 +307,7 @@ public class SecurityConfiguration {
.permitAll());
}
// Handle SAML
if (applicationProperties.getSecurity().isSaml2Active() && runningProOrHigher) {
// Configure the authentication provider
if (securityProperties.isSaml2Active() && runningProOrHigher) {
OpenSaml4AuthenticationProvider authenticationProvider =
new OpenSaml4AuthenticationProvider();
authenticationProvider.setResponseAuthenticationConverter(
@ -287,8 +324,9 @@ public class SecurityConfiguration {
.successHandler(
new CustomSaml2AuthenticationSuccessHandler(
loginAttemptService,
applicationProperties,
userService))
securityProperties.getSaml2(),
userService,
jwtService))
.failureHandler(
new CustomSaml2AuthenticationFailureHandler())
.authenticationRequestResolver(
@ -306,6 +344,12 @@ public class SecurityConfiguration {
return http.build();
}
@Bean
public AuthenticationManager authenticationManager(AuthenticationConfiguration configuration)
throws Exception {
return configuration.getAuthenticationManager();
}
public DaoAuthenticationProvider daoAuthenticationProvider() {
DaoAuthenticationProvider provider = new DaoAuthenticationProvider(userDetailsService);
provider.setPasswordEncoder(passwordEncoder());
@ -323,4 +367,14 @@ public class SecurityConfiguration {
public PersistentTokenRepository persistentTokenRepository() {
return new JPATokenRepositoryImpl(persistentLoginRepository);
}
@Bean
public JwtAuthenticationFilter jwtAuthenticationFilter() {
return new JwtAuthenticationFilter(
jwtService,
userService,
userDetailsService,
jwtAuthenticationEntryPoint,
securityProperties);
}
}

View File

@ -0,0 +1,18 @@
package stirling.software.proprietary.security.database.repository;
import java.util.Optional;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
import stirling.software.proprietary.security.model.JwtSigningKey;
@Repository
public interface JwtSigningKeyRepository extends JpaRepository<JwtSigningKey, Long> {
Optional<JwtSigningKey> findByIsActiveTrue();
Optional<JwtSigningKey> findByKeyId(String keyId);
Optional<JwtSigningKey> findByKeyIdAndIsActiveTrue(String keyId);
}

View File

@ -0,0 +1,175 @@
package stirling.software.proprietary.security.filter;
import static stirling.software.common.util.RequestUriUtils.isStaticResource;
import static stirling.software.proprietary.security.model.AuthenticationType.*;
import static stirling.software.proprietary.security.model.AuthenticationType.SAML2;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Map;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.web.filter.OncePerRequestFilter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.exception.UnsupportedProviderException;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.UserService;
@Slf4j
public class JwtAuthenticationFilter extends OncePerRequestFilter {
private final JwtServiceInterface jwtService;
private final UserService userService;
private final CustomUserDetailsService userDetailsService;
private final AuthenticationEntryPoint authenticationEntryPoint;
private final ApplicationProperties.Security securityProperties;
public JwtAuthenticationFilter(
JwtServiceInterface jwtService,
UserService userService,
CustomUserDetailsService userDetailsService,
AuthenticationEntryPoint authenticationEntryPoint,
ApplicationProperties.Security securityProperties) {
this.jwtService = jwtService;
this.userService = userService;
this.userDetailsService = userDetailsService;
this.authenticationEntryPoint = authenticationEntryPoint;
this.securityProperties = securityProperties;
}
@Override
protected void doFilterInternal(
HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if (!jwtService.isJwtEnabled()) {
filterChain.doFilter(request, response);
return;
}
if (isStaticResource(request.getContextPath(), request.getRequestURI())) {
filterChain.doFilter(request, response);
return;
}
String jwtToken = jwtService.extractTokenFromRequest(request);
if (jwtToken == null) {
// If they are unauthenticated and navigating to '/', redirect to '/login' instead of
// sending a 401
if ("/".equals(request.getRequestURI())
&& "GET".equalsIgnoreCase(request.getMethod())) {
response.sendRedirect("/login");
return;
}
handleAuthenticationFailure(
request,
response,
new AuthenticationFailureException("JWT is missing from the request"));
return;
}
try {
jwtService.validateToken(jwtToken);
} catch (AuthenticationFailureException e) {
// Clear invalid tokens from response
jwtService.clearTokenFromResponse(response);
handleAuthenticationFailure(request, response, e);
return;
}
Map<String, Object> claims = jwtService.extractAllClaims(jwtToken);
String tokenUsername = claims.get("sub").toString();
try {
Authentication authentication = createAuthentication(request, claims);
String jwt = jwtService.generateToken(authentication, claims);
jwtService.addTokenToResponse(response, jwt);
} catch (SQLException | UnsupportedProviderException e) {
log.error("Error processing user authentication for user: {}", tokenUsername, e);
handleAuthenticationFailure(
request,
response,
new AuthenticationFailureException("Error processing user authentication", e));
return;
}
filterChain.doFilter(request, response);
}
private Authentication createAuthentication(
HttpServletRequest request, Map<String, Object> claims)
throws SQLException, UnsupportedProviderException {
String username = claims.get("sub").toString();
if (username != null && SecurityContextHolder.getContext().getAuthentication() == null) {
processUserAuthenticationType(claims, username);
UserDetails userDetails = userDetailsService.loadUserByUsername(username);
if (userDetails != null) {
UsernamePasswordAuthenticationToken authToken =
new UsernamePasswordAuthenticationToken(
userDetails, null, userDetails.getAuthorities());
authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
SecurityContextHolder.getContext().setAuthentication(authToken);
log.info(
"JWT authentication successful for user: {} - Authentication set in SecurityContext",
username);
} else {
throw new UsernameNotFoundException("User not found: " + username);
}
}
return SecurityContextHolder.getContext().getAuthentication();
}
private void processUserAuthenticationType(Map<String, Object> claims, String username)
throws SQLException, UnsupportedProviderException {
AuthenticationType authenticationType =
AuthenticationType.valueOf(claims.getOrDefault("authType", WEB).toString());
log.debug("Processing {} login for {} user", authenticationType, username);
switch (authenticationType) {
case OAUTH2 -> {
ApplicationProperties.Security.OAUTH2 oauth2Properties =
securityProperties.getOauth2();
userService.processSSOPostLogin(
username, oauth2Properties.getAutoCreateUser(), OAUTH2);
}
case SAML2 -> {
ApplicationProperties.Security.SAML2 saml2Properties =
securityProperties.getSaml2();
userService.processSSOPostLogin(
username, saml2Properties.getAutoCreateUser(), SAML2);
}
}
}
private void handleAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException authException)
throws IOException, ServletException {
authenticationEntryPoint.commence(request, response, authException);
}
}

View File

@ -9,7 +9,6 @@ import org.springframework.context.annotation.Lazy;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.userdetails.UserDetails;
@ -64,7 +63,15 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
return;
}
String requestURI = request.getRequestURI();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
log.info(
"UserAuthenticationFilter - Authentication from SecurityContext: {}",
authentication != null
? authentication.getClass().getSimpleName()
+ " for "
+ authentication.getName()
: "null");
// Check for session expiration (unsure if needed)
// if (authentication != null && authentication.isAuthenticated()) {
@ -92,14 +99,9 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
response.getWriter().write("Invalid API Key.");
return;
}
List<SimpleGrantedAuthority> authorities =
user.get().getAuthorities().stream()
.map(
authority ->
new SimpleGrantedAuthority(
authority.getAuthority()))
.toList();
authentication = new ApiKeyAuthenticationToken(user.get(), apiKey, authorities);
authentication =
new ApiKeyAuthenticationToken(
user.get(), apiKey, user.get().getAuthorities());
SecurityContextHolder.getContext().setAuthentication(authentication);
} catch (AuthenticationException e) {
// If API key authentication fails, deny the request
@ -117,18 +119,18 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
if ("GET".equalsIgnoreCase(method) && !(contextPath + "/login").equals(requestURI)) {
response.sendRedirect(contextPath + "/login"); // redirect to the login page
return;
} else {
response.setStatus(HttpStatus.UNAUTHORIZED.value());
response.getWriter()
.write(
"Authentication required. Please provide a X-API-KEY in request"
+ " header.\n"
+ "This is found in Settings -> Account Settings -> API Key\n"
+ "Alternatively you can disable authentication if this is"
+ " unexpected");
return;
"""
Authentication required. Please provide a X-API-KEY in request\
header.
This is found in Settings -> Account Settings -> API Key
Alternatively you can disable authentication if this is\
unexpected""");
}
return;
}
// Check if the authenticated user is disabled and invalidate their session if so
@ -226,11 +228,12 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
}
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
protected boolean shouldNotFilter(HttpServletRequest request) {
String uri = request.getRequestURI();
String contextPath = request.getContextPath();
String[] permitAllPatterns = {
contextPath + "/login",
contextPath + "/signup",
contextPath + "/register",
contextPath + "/error",
contextPath + "/images/",
@ -247,6 +250,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
for (String pattern : permitAllPatterns) {
if (uri.startsWith(pattern)
|| uri.endsWith(".svg")
|| uri.endsWith(".mjs")
|| uri.endsWith(".png")
|| uri.endsWith(".ico")) {
return true;

View File

@ -2,5 +2,7 @@ package stirling.software.proprietary.security.model;
public enum AuthenticationType {
WEB,
SSO
SSO,
OAUTH2,
SAML2
}

View File

@ -2,6 +2,8 @@ package stirling.software.proprietary.security.model;
import java.io.Serializable;
import org.springframework.security.core.GrantedAuthority;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
@ -18,7 +20,7 @@ import lombok.Setter;
@Table(name = "authorities")
@Getter
@Setter
public class Authority implements Serializable {
public class Authority implements GrantedAuthority, Serializable {
private static final long serialVersionUID = 1L;

View File

@ -0,0 +1,62 @@
package stirling.software.proprietary.security.model;
import java.io.Serializable;
import java.time.LocalDateTime;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.ToString;
@Entity
@Getter
@Setter
@NoArgsConstructor
@Table(name = "signing_keys")
@ToString(onlyExplicitlyIncluded = true)
@EqualsAndHashCode(onlyExplicitlyIncluded = true)
public class JwtSigningKey implements Serializable {
private static final long serialVersionUID = 1L;
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "signing_key_id")
@EqualsAndHashCode.Include
@ToString.Include
private Long id;
@Column(name = "key_id", nullable = false, unique = true)
@ToString.Include
private String keyId;
@Column(name = "signing_key", columnDefinition = "TEXT", nullable = false)
private String signingKey;
@Column(name = "algorithm", nullable = false)
private String algorithm = "RS256";
@Column(name = "created_at", nullable = false)
@ToString.Include
private LocalDateTime createdAt;
@Column(name = "is_active", nullable = false)
@ToString.Include
private Boolean isActive = true;
public JwtSigningKey(String keyId, String signingKey, String algorithm) {
this.keyId = keyId;
this.signingKey = signingKey;
this.algorithm = algorithm;
this.createdAt = LocalDateTime.now();
this.isActive = true;
}
}

View File

@ -7,6 +7,8 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.security.core.userdetails.UserDetails;
import jakarta.persistence.*;
import lombok.EqualsAndHashCode;
@ -25,7 +27,7 @@ import stirling.software.proprietary.model.Team;
@Setter
@EqualsAndHashCode(onlyExplicitlyIncluded = true)
@ToString(onlyExplicitlyIncluded = true)
public class User implements Serializable {
public class User implements UserDetails, Serializable {
private static final long serialVersionUID = 1L;

View File

@ -0,0 +1,13 @@
package stirling.software.proprietary.security.model.exception;
import org.springframework.security.core.AuthenticationException;
public class AuthenticationFailureException extends AuthenticationException {
public AuthenticationFailureException(String message) {
super(message);
}
public AuthenticationFailureException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -1,7 +1,11 @@
package stirling.software.proprietary.security.oauth2;
import static stirling.software.proprietary.security.model.AuthenticationType.OAUTH2;
import static stirling.software.proprietary.security.model.AuthenticationType.SSO;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Map;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
@ -18,10 +22,10 @@ import jakarta.servlet.http.HttpSession;
import lombok.RequiredArgsConstructor;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.common.model.exception.UnsupportedProviderException;
import stirling.software.common.util.RequestUriUtils;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.LoginAttemptService;
import stirling.software.proprietary.security.service.UserService;
@ -30,8 +34,9 @@ public class CustomOAuth2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private final LoginAttemptService loginAttemptService;
private final ApplicationProperties applicationProperties;
private final ApplicationProperties.Security.OAUTH2 oauth2Properties;
private final UserService userService;
private final JwtServiceInterface jwtService;
@Override
public void onAuthenticationSuccess(
@ -60,8 +65,6 @@ public class CustomOAuth2AuthenticationSuccessHandler
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
OAUTH2 oAuth = applicationProperties.getSecurity().getOauth2();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
@ -69,7 +72,12 @@ public class CustomOAuth2AuthenticationSuccessHandler
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (jwtService.isJwtEnabled()) {
String jwt =
jwtService.generateToken(
authentication, Map.of("authType", AuthenticationType.OAUTH2));
jwtService.addTokenToResponse(response, jwt);
}
if (userService.isUserDisabled(username)) {
getRedirectStrategy()
.sendRedirect(request, response, "/logout?userIsDisabled=true");
@ -77,20 +85,22 @@ public class CustomOAuth2AuthenticationSuccessHandler
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(username, AuthenticationType.SSO)
&& oAuth.getAutoCreateUser()) {
&& (!userService.isAuthenticationTypeByUsername(username, SSO)
|| !userService.isAuthenticationTypeByUsername(username, OAUTH2))
&& oauth2Properties.getAutoCreateUser()) {
response.sendRedirect(contextPath + "/logout?oAuth2AuthenticationErrorWeb=true");
return;
}
try {
if (oAuth.getBlockRegistration()
if (oauth2Properties.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(contextPath + "/logout?oAuth2AdminBlockedUser=true");
return;
}
if (principal instanceof OAuth2User) {
userService.processSSOPostLogin(username, oAuth.getAutoCreateUser());
userService.processSSOPostLogin(
username, oauth2Properties.getAutoCreateUser(), OAUTH2);
}
response.sendRedirect(contextPath + "/");
} catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) {

View File

@ -34,6 +34,7 @@ import stirling.software.common.model.oauth2.GitHubProvider;
import stirling.software.common.model.oauth2.GoogleProvider;
import stirling.software.common.model.oauth2.KeycloakProvider;
import stirling.software.common.model.oauth2.Provider;
import stirling.software.proprietary.security.model.Authority;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.model.exception.NoProviderFoundException;
import stirling.software.proprietary.security.service.UserService;
@ -239,12 +240,14 @@ public class OAuth2Configuration {
Optional<User> userOpt =
userService.findByUsernameIgnoreCase(
(String) oAuth2Auth.getAttributes().get(useAsUsername));
if (userOpt.isPresent()) {
User user = userOpt.get();
mappedAuthorities.add(
new SimpleGrantedAuthority(
userService.findRole(user).getAuthority()));
}
userOpt.ifPresent(
user ->
mappedAuthorities.add(
new Authority(
userService
.findRole(user)
.getAuthority(),
user)));
}
});
return mappedAuthorities;

View File

@ -1,7 +1,11 @@
package stirling.software.proprietary.security.saml2;
import static stirling.software.proprietary.security.model.AuthenticationType.SAML2;
import static stirling.software.proprietary.security.model.AuthenticationType.SSO;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Map;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
@ -17,10 +21,10 @@ import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.ApplicationProperties.Security.SAML2;
import stirling.software.common.model.exception.UnsupportedProviderException;
import stirling.software.common.util.RequestUriUtils;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.LoginAttemptService;
import stirling.software.proprietary.security.service.UserService;
@ -30,8 +34,9 @@ public class CustomSaml2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private ApplicationProperties applicationProperties;
private ApplicationProperties.Security.SAML2 saml2Properties;
private UserService userService;
private final JwtServiceInterface jwtService;
@Override
public void onAuthenticationSuccess(
@ -65,10 +70,9 @@ public class CustomSaml2AuthenticationSuccessHandler
savedRequest.getRedirectUrl());
super.onAuthenticationSuccess(request, response, authentication);
} else {
SAML2 saml2 = applicationProperties.getSecurity().getSaml2();
log.debug(
"Processing SAML2 authentication with autoCreateUser: {}",
saml2.getAutoCreateUser());
saml2Properties.getAutoCreateUser());
if (loginAttemptService.isBlocked(username)) {
log.debug("User {} is blocked due to too many login attempts", username);
@ -82,17 +86,21 @@ public class CustomSaml2AuthenticationSuccessHandler
boolean userExists = userService.usernameExistsIgnoreCase(username);
boolean hasPassword = userExists && userService.hasPassword(username);
boolean isSSOUser =
userExists
&& userService.isAuthenticationTypeByUsername(
username, AuthenticationType.SSO);
userExists && userService.isAuthenticationTypeByUsername(username, SSO);
boolean isSAML2User =
userExists && userService.isAuthenticationTypeByUsername(username, SAML2);
log.debug(
"User status - Exists: {}, Has password: {}, Is SSO user: {}",
"User status - Exists: {}, Has password: {}, Is SSO user: {}, Is SAML2 user: {}",
userExists,
hasPassword,
isSSOUser);
isSSOUser,
isSAML2User);
if (userExists && hasPassword && !isSSOUser && saml2.getAutoCreateUser()) {
if (userExists
&& hasPassword
&& (!isSSOUser || !isSAML2User)
&& saml2Properties.getAutoCreateUser()) {
log.debug(
"User {} exists with password but is not SSO user, redirecting to logout",
username);
@ -102,15 +110,18 @@ public class CustomSaml2AuthenticationSuccessHandler
}
try {
if (saml2.getBlockRegistration() && !userExists) {
if (!userExists || saml2Properties.getBlockRegistration()) {
log.debug("Registration blocked for new user: {}", username);
response.sendRedirect(
contextPath + "/login?errorOAuth=oAuth2AdminBlockedUser");
return;
}
log.debug("Processing SSO post-login for user: {}", username);
userService.processSSOPostLogin(username, saml2.getAutoCreateUser());
userService.processSSOPostLogin(
username, saml2Properties.getAutoCreateUser(), SAML2);
log.debug("Successfully processed authentication for user: {}", username);
generateJWT(response, authentication);
response.sendRedirect(contextPath + "/");
} catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) {
log.debug(
@ -124,4 +135,13 @@ public class CustomSaml2AuthenticationSuccessHandler
super.onAuthenticationSuccess(request, response, authentication);
}
}
private void generateJWT(HttpServletResponse response, Authentication authentication) {
if (jwtService.isJwtEnabled()) {
String jwt =
jwtService.generateToken(
authentication, Map.of("authType", AuthenticationType.SAML2));
jwtService.addTokenToResponse(response, jwt);
}
}
}

View File

@ -0,0 +1,135 @@
package stirling.software.proprietary.security.saml2;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Slf4j
public class JwtSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest> {
private final Map<String, String> tokenStore;
private final JwtServiceInterface jwtService;
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
public JwtSaml2AuthenticationRequestRepository(
Map<String, String> tokenStore,
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this.tokenStore = tokenStore;
this.jwtService = jwtService;
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
}
@Override
public void saveAuthenticationRequest(
Saml2PostAuthenticationRequest authRequest,
HttpServletRequest request,
HttpServletResponse response) {
if (!jwtService.isJwtEnabled()) {
log.debug("V2 is not enabled, skipping SAMLRequest token storage");
return;
}
if (authRequest == null) {
removeAuthenticationRequest(request, response);
return;
}
Map<String, Object> claims = serializeSamlRequest(authRequest);
String token = jwtService.generateToken("", claims);
String relayState = authRequest.getRelayState();
tokenStore.put(relayState, token);
request.setAttribute(SAML_REQUEST_TOKEN, relayState);
response.addHeader(SAML_REQUEST_TOKEN, relayState);
log.debug("Saved SAMLRequest token with RelayState: {}", relayState);
}
@Override
public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
String token = extractTokenFromStore(request);
if (token == null) {
log.debug("No SAMLResponse token found in RelayState");
return null;
}
Map<String, Object> claims = jwtService.extractAllClaims(token);
return deserializeSamlRequest(claims);
}
@Override
public Saml2PostAuthenticationRequest removeAuthenticationRequest(
HttpServletRequest request, HttpServletResponse response) {
Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request);
String relayStateId = request.getParameter("RelayState");
if (relayStateId != null) {
tokenStore.remove(relayStateId);
log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId);
}
return authRequest;
}
private String extractTokenFromStore(HttpServletRequest request) {
String authnRequestId = request.getParameter("RelayState");
if (authnRequestId != null && !authnRequestId.isEmpty()) {
String token = tokenStore.get(authnRequestId);
if (token != null) {
tokenStore.remove(authnRequestId);
log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId);
return token;
} else {
log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId);
}
}
return null;
}
private Map<String, Object> serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) {
Map<String, Object> claims = new HashMap<>();
claims.put("id", authRequest.getId());
claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId());
claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri());
claims.put("samlRequest", authRequest.getSamlRequest());
claims.put("relayState", authRequest.getRelayState());
return claims;
}
private Saml2PostAuthenticationRequest deserializeSamlRequest(Map<String, Object> claims) {
String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId");
RelyingPartyRegistration relyingPartyRegistration =
relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId);
if (relyingPartyRegistration == null) {
return null;
}
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration)
.id((String) claims.get("id"))
.authenticationRequestUri((String) claims.get("authenticationRequestUri"))
.samlRequest((String) claims.get("samlRequest"))
.relayState((String) claims.get("relayState"))
.build();
}
}

View File

@ -3,6 +3,7 @@ package stirling.software.proprietary.security.saml2;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@ -11,12 +12,12 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import jakarta.servlet.http.HttpServletRequest;
@ -26,12 +27,13 @@ import lombok.extern.slf4j.Slf4j;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.ApplicationProperties.Security.SAML2;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Configuration
@Slf4j
@ConditionalOnProperty(value = "security.saml2.enabled", havingValue = "true")
@RequiredArgsConstructor
public class SAML2Configuration {
public class Saml2Configuration {
private final ApplicationProperties applicationProperties;
@ -58,6 +60,7 @@ public class SAML2Configuration {
.assertionConsumerServiceBinding(Saml2MessageBinding.POST)
.assertionConsumerServiceLocation(
"{baseUrl}/login/saml2/sso/{registrationId}")
.authnRequestsSigned(true)
.assertingPartyMetadata(
metadata ->
metadata.entityId(samlConf.getIdpIssuer())
@ -71,15 +74,29 @@ public class SAML2Configuration {
Saml2MessageBinding.POST)
.singleLogoutServiceLocation(
samlConf.getIdpSingleLogoutUrl())
.singleLogoutServiceResponseLocation(
"http://localhost:8080/login")
.wantAuthnRequestsSigned(true))
.build();
return new InMemoryRelyingPartyRegistrationRepository(rp);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository(
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
return new JwtSaml2AuthenticationRequestRepository(
new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository) {
OpenSaml4AuthenticationRequestResolver resolver =
new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository);
@ -87,10 +104,8 @@ public class SAML2Configuration {
customizer -> {
HttpServletRequest request = customizer.getRequest();
AuthnRequest authnRequest = customizer.getAuthnRequest();
HttpSessionSaml2AuthenticationRequestRepository requestRepository =
new HttpSessionSaml2AuthenticationRequestRepository();
AbstractSaml2AuthenticationRequest saml2AuthenticationRequest =
requestRepository.loadAuthenticationRequest(request);
Saml2PostAuthenticationRequest saml2AuthenticationRequest =
saml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
if (saml2AuthenticationRequest != null) {
String sessionId = request.getSession(false).getId();
@ -113,7 +128,6 @@ public class SAML2Configuration {
log.debug("Generating new authentication request ID");
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
}
logAuthnRequestDetails(authnRequest);
logHttpRequestDetails(request);
});

View File

@ -27,13 +27,13 @@ public class CustomOAuth2UserService implements OAuth2UserService<OidcUserReques
private final LoginAttemptService loginAttemptService;
private final ApplicationProperties applicationProperties;
private final ApplicationProperties.Security securityProperties;
public CustomOAuth2UserService(
ApplicationProperties applicationProperties,
ApplicationProperties.Security securityProperties,
UserService userService,
LoginAttemptService loginAttemptService) {
this.applicationProperties = applicationProperties;
this.securityProperties = securityProperties;
this.userService = userService;
this.loginAttemptService = loginAttemptService;
}
@ -42,7 +42,7 @@ public class CustomOAuth2UserService implements OAuth2UserService<OidcUserReques
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
try {
OidcUser user = delegate.loadUser(userRequest);
OAUTH2 oauth2 = applicationProperties.getSecurity().getOauth2();
OAUTH2 oauth2 = securityProperties.getOauth2();
UsernameAttribute usernameAttribute =
UsernameAttribute.valueOf(oauth2.getUseAsUsername().toUpperCase());
String usernameAttributeKey = usernameAttribute.getName();

View File

@ -1,11 +1,6 @@
package stirling.software.proprietary.security.service;
import java.util.Collection;
import java.util.Set;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@ -14,7 +9,7 @@ import org.springframework.stereotype.Service;
import lombok.RequiredArgsConstructor;
import stirling.software.proprietary.security.database.repository.UserRepository;
import stirling.software.proprietary.security.model.Authority;
import stirling.software.proprietary.security.model.AuthenticationType;
import stirling.software.proprietary.security.model.User;
@Service
@ -34,26 +29,18 @@ public class CustomUserDetailsService implements UserDetailsService {
() ->
new UsernameNotFoundException(
"No user found with username: " + username));
if (loginAttemptService.isBlocked(username)) {
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (!user.hasPassword()) {
AuthenticationType userAuthenticationType =
AuthenticationType.valueOf(user.getAuthenticationType().toUpperCase());
if (!user.hasPassword() && userAuthenticationType == AuthenticationType.WEB) {
throw new IllegalArgumentException("Password must not be null");
}
return new org.springframework.security.core.userdetails.User(
user.getUsername(),
user.getPassword(),
user.isEnabled(),
true,
true,
true,
getAuthorities(user.getAuthorities()));
}
private Collection<? extends GrantedAuthority> getAuthorities(Set<Authority> authorities) {
return authorities.stream()
.map(authority -> new SimpleGrantedAuthority(authority.getAuthority()))
.toList();
return user;
}
}

View File

@ -0,0 +1,238 @@
package stirling.software.proprietary.security.service;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Base64;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.configuration.InstallationPathConfig;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository;
import stirling.software.proprietary.security.model.JwtSigningKey;
@Service
@Slf4j
public class JwtKeystoreService implements JwtKeystoreServiceInterface {
public static final String KEY_SUFFIX = ".key";
private final JwtSigningKeyRepository repository;
private final ApplicationProperties.Security.Jwt jwtProperties;
private final Path privateKeyDirectory;
private volatile KeyPair currentKeyPair;
private volatile String currentKeyId;
@Autowired
public JwtKeystoreService(
JwtSigningKeyRepository repository, ApplicationProperties applicationProperties) {
this.repository = repository;
this.jwtProperties = applicationProperties.getSecurity().getJwt();
this.privateKeyDirectory = Paths.get(InstallationPathConfig.getConfigPath(), "jwt-keys");
}
@PostConstruct
public void initializeKeystore() {
if (!isKeystoreEnabled()) {
log.info("JWT keystore is disabled, using in-memory key generation");
return;
}
try {
ensurePrivateKeyDirectoryExists();
loadOrGenerateKeypair();
} catch (Exception e) {
log.error("Failed to initialize JWT keystore, falling back to in-memory generation", e);
}
}
@Override
public KeyPair getActiveKeypair() {
if (!isKeystoreEnabled() || currentKeyPair == null) {
return generateRSAKeypair();
}
return currentKeyPair;
}
@Override
public Optional<KeyPair> getKeypairByKeyId(String keyId) {
if (!isKeystoreEnabled()) {
return Optional.empty();
}
try {
Optional<JwtSigningKey> signingKey = repository.findByKeyId(keyId);
if (signingKey.isEmpty()) {
return Optional.empty();
}
PrivateKey privateKey = loadPrivateKey(keyId);
PublicKey publicKey = decodePublicKey(signingKey.get().getSigningKey());
return Optional.of(new KeyPair(publicKey, privateKey));
} catch (Exception e) {
log.error("Failed to load keypair for keyId: {}", keyId, e);
return Optional.empty();
}
}
@Override
public String getActiveKeyId() {
return currentKeyId;
}
@Override
@Transactional
public void rotateKeypair() {
if (!isKeystoreEnabled()) {
log.warn("Cannot rotate keypair when keystore is disabled");
return;
}
try {
repository
.findByIsActiveTrue()
.ifPresent(
key -> {
key.setIsActive(false);
repository.save(key);
});
generateAndStoreKeypair();
log.info("Successfully rotated JWT keypair");
} catch (Exception e) {
log.error("Failed to rotate JWT keypair", e);
throw new RuntimeException("Keypair rotation failed", e);
}
}
@Override
public boolean isKeystoreEnabled() {
return jwtProperties.isEnableKeystore();
}
private void loadOrGenerateKeypair() {
Optional<JwtSigningKey> activeKey = repository.findByIsActiveTrue();
if (activeKey.isPresent()) {
try {
currentKeyId = activeKey.get().getKeyId();
PrivateKey privateKey = loadPrivateKey(currentKeyId);
PublicKey publicKey = decodePublicKey(activeKey.get().getSigningKey());
currentKeyPair = new KeyPair(publicKey, privateKey);
log.info("Loaded existing JWT keypair with keyId: {}", currentKeyId);
} catch (Exception e) {
log.error("Failed to load existing keypair, generating new one", e);
generateAndStoreKeypair();
}
} else {
generateAndStoreKeypair();
}
}
private void generateAndStoreKeypair() {
try {
KeyPair keyPair = generateRSAKeypair();
String keyId = generateKeyId();
storePrivateKey(keyId, keyPair.getPrivate());
JwtSigningKey signingKey =
new JwtSigningKey(keyId, encodePublicKey(keyPair.getPublic()), "RS256");
repository.save(signingKey);
currentKeyPair = keyPair;
currentKeyId = keyId;
log.info("Generated and stored new JWT keypair with keyId: {}", keyId);
} catch (Exception e) {
log.error("Failed to generate and store keypair", e);
throw new RuntimeException("Keypair generation failed", e);
}
}
private KeyPair generateRSAKeypair() {
KeyPairGenerator keyPairGenerator = null;
try {
keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Failed to initialize RSA key pair generator", e);
}
return keyPairGenerator.generateKeyPair();
}
private String generateKeyId() {
return "jwt-key-"
+ LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd-HHmmss"));
}
private void ensurePrivateKeyDirectoryExists() throws IOException {
if (!Files.exists(privateKeyDirectory)) {
Files.createDirectories(privateKeyDirectory);
log.info("Created JWT private key directory: {}", privateKeyDirectory);
}
}
private void storePrivateKey(String keyId, PrivateKey privateKey) throws IOException {
Path keyFile = privateKeyDirectory.resolve(keyId + KEY_SUFFIX);
String encodedKey = Base64.getEncoder().encodeToString(privateKey.getEncoded());
Files.writeString(keyFile, encodedKey);
// Set read/write to only the owner
try {
keyFile.toFile().setReadable(true, true);
keyFile.toFile().setWritable(true, true);
keyFile.toFile().setExecutable(false, false);
} catch (Exception e) {
log.warn("Failed to set permissions on private key file: {}", keyFile, e);
}
}
private PrivateKey loadPrivateKey(String keyId)
throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
Path keyFile = privateKeyDirectory.resolve(keyId + KEY_SUFFIX);
if (!Files.exists(keyFile)) {
throw new IOException("Private key file not found: " + keyFile);
}
String encodedKey = Files.readString(keyFile);
byte[] keyBytes = Base64.getDecoder().decode(encodedKey);
PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes);
KeyFactory keyFactory = KeyFactory.getInstance("RSA");
return keyFactory.generatePrivate(keySpec);
}
private String encodePublicKey(PublicKey publicKey) {
return Base64.getEncoder().encodeToString(publicKey.getEncoded());
}
private PublicKey decodePublicKey(String encodedKey)
throws NoSuchAlgorithmException, InvalidKeySpecException {
byte[] keyBytes = Base64.getDecoder().decode(encodedKey);
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(keyBytes);
KeyFactory keyFactory = KeyFactory.getInstance("RSA");
return keyFactory.generatePublic(keySpec);
}
}

View File

@ -0,0 +1,17 @@
package stirling.software.proprietary.security.service;
import java.security.KeyPair;
import java.util.Optional;
public interface JwtKeystoreServiceInterface {
KeyPair getActiveKeypair();
Optional<KeyPair> getKeypairByKeyId(String keyId);
String getActiveKeyId();
void rotateKeypair();
boolean isKeystoreEnabled();
}

View File

@ -0,0 +1,242 @@
package stirling.software.proprietary.security.service;
import java.security.KeyPair;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.ResponseCookie;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import io.github.pixee.security.Newlines;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.SignatureException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal;
@Slf4j
@Service
public class JwtService implements JwtServiceInterface {
private static final String JWT_COOKIE_NAME = "stirling_jwt";
private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String BEARER_PREFIX = "Bearer ";
private static final String ISSUER = "Stirling PDF";
private static final long EXPIRATION = 3600000;
private final JwtKeystoreServiceInterface keystoreService;
private final boolean v2Enabled;
@Autowired
public JwtService(
@Qualifier("v2Enabled") boolean v2Enabled,
JwtKeystoreServiceInterface keystoreService) {
this.v2Enabled = v2Enabled;
this.keystoreService = keystoreService;
}
@Override
public String generateToken(Authentication authentication, Map<String, Object> claims) {
Object principal = authentication.getPrincipal();
String username = "";
if (principal instanceof UserDetails) {
username = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
username = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
}
return generateToken(username, claims);
}
@Override
public String generateToken(String username, Map<String, Object> claims) {
KeyPair keyPair = keystoreService.getActiveKeypair();
var builder =
Jwts.builder()
.claims(claims)
.subject(username)
.issuer(ISSUER)
.issuedAt(new Date())
.expiration(new Date(System.currentTimeMillis() + EXPIRATION))
.signWith(keyPair.getPrivate(), Jwts.SIG.RS256);
String keyId = keystoreService.getActiveKeyId();
if (keyId != null) {
builder.header().keyId(keyId);
}
return builder.compact();
}
@Override
public void validateToken(String token) throws AuthenticationFailureException {
extractAllClaimsFromToken(token);
if (isTokenExpired(token)) {
throw new AuthenticationFailureException("The token has expired");
}
}
@Override
public String extractUsername(String token) {
return extractClaim(token, Claims::getSubject);
}
@Override
public Map<String, Object> extractAllClaims(String token) {
Claims claims = extractAllClaimsFromToken(token);
return new HashMap<>(claims);
}
@Override
public boolean isTokenExpired(String token) {
return extractExpiration(token).before(new Date());
}
private Date extractExpiration(String token) {
return extractClaim(token, Claims::getExpiration);
}
private <T> T extractClaim(String token, Function<Claims, T> claimsResolver) {
final Claims claims = extractAllClaimsFromToken(token);
return claimsResolver.apply(claims);
}
private Claims extractAllClaimsFromToken(String token) {
try {
// Extract key ID from token header if present
String keyId = extractKeyIdFromToken(token);
KeyPair keyPair;
if (keyId != null) {
Optional<KeyPair> specificKeyPair = keystoreService.getKeypairByKeyId(keyId);
if (specificKeyPair.isPresent()) {
keyPair = specificKeyPair.get();
} else {
log.warn(
"Key ID {} not found in keystore, token may have been signed with a rotated key",
keyId);
throw new AuthenticationFailureException(
"JWT token signed with unknown key ID: " + keyId);
}
} else {
keyPair = keystoreService.getActiveKeypair();
}
return Jwts.parser()
.verifyWith(keyPair.getPublic())
.build()
.parseSignedClaims(token)
.getPayload();
} catch (SignatureException e) {
log.warn("Invalid signature: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid signature", e);
} catch (MalformedJwtException e) {
log.warn("Invalid token: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid token", e);
} catch (ExpiredJwtException e) {
log.warn("The token has expired: {}", e.getMessage());
throw new AuthenticationFailureException("The token has expired", e);
} catch (UnsupportedJwtException e) {
log.warn("The token is unsupported: {}", e.getMessage());
throw new AuthenticationFailureException("The token is unsupported", e);
} catch (IllegalArgumentException e) {
log.warn("Claims are empty: {}", e.getMessage());
throw new AuthenticationFailureException("Claims are empty", e);
}
}
@Override
public String extractTokenFromRequest(HttpServletRequest request) {
String authHeader = request.getHeader(AUTHORIZATION_HEADER);
if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) {
return authHeader.substring(BEARER_PREFIX.length());
}
Cookie[] cookies = request.getCookies();
if (cookies != null) {
for (Cookie cookie : cookies) {
if (JWT_COOKIE_NAME.equals(cookie.getName())) {
return cookie.getValue();
}
}
}
return null;
}
@Override
public void addTokenToResponse(HttpServletResponse response, String token) {
response.setHeader(AUTHORIZATION_HEADER, Newlines.stripAll(BEARER_PREFIX + token));
ResponseCookie cookie =
ResponseCookie.from(JWT_COOKIE_NAME, Newlines.stripAll(token))
.httpOnly(true)
.secure(true)
.sameSite("None")
.maxAge(EXPIRATION / 1000)
.path("/")
.build();
response.addHeader("Set-Cookie", cookie.toString());
}
@Override
public void clearTokenFromResponse(HttpServletResponse response) {
response.setHeader(AUTHORIZATION_HEADER, null);
ResponseCookie cookie =
ResponseCookie.from(JWT_COOKIE_NAME, "")
.httpOnly(true)
.secure(true)
.sameSite("None")
.maxAge(0)
.path("/")
.build();
response.addHeader("Set-Cookie", cookie.toString());
}
@Override
public boolean isJwtEnabled() {
return v2Enabled;
}
private String extractKeyIdFromToken(String token) {
try {
return (String)
Jwts.parser()
.unsecured()
.build()
.parseUnsecuredClaims(token)
.getHeader()
.get("kid");
} catch (Exception e) {
log.debug("Failed to extract key ID from token header: {}", e.getMessage());
return null;
}
}
}

View File

@ -0,0 +1,90 @@
package stirling.software.proprietary.security.service;
import java.util.Map;
import org.springframework.security.core.Authentication;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
public interface JwtServiceInterface {
/**
* Generate a JWT token for the authenticated user
*
* @param authentication Spring Security authentication object
* @return JWT token as a string
*/
String generateToken(Authentication authentication, Map<String, Object> claims);
/**
* Generate a JWT token for a specific username
*
* @param username the username for which to generate the token
* @param claims additional claims to include in the token
* @return JWT token as a string
*/
String generateToken(String username, Map<String, Object> claims);
/**
* Validate a JWT token
*
* @param token the JWT token to validate
* @return true if token is valid, false otherwise
*/
void validateToken(String token);
/**
* Extract username from JWT token
*
* @param token the JWT token
* @return username extracted from token
*/
String extractUsername(String token);
/**
* Extract all claims from JWT token
*
* @param token the JWT token
* @return map of claims
*/
Map<String, Object> extractAllClaims(String token);
/**
* Check if token is expired
*
* @param token the JWT token
* @return true if token is expired, false otherwise
*/
boolean isTokenExpired(String token);
/**
* Extract JWT token from HTTP request (header or cookie)
*
* @param request HTTP servlet request
* @return JWT token if found, null otherwise
*/
String extractTokenFromRequest(HttpServletRequest request);
/**
* Add JWT token to HTTP response (header and cookie)
*
* @param response HTTP servlet response
* @param token JWT token to add
*/
void addTokenToResponse(HttpServletResponse response, String token);
/**
* Clear JWT token from HTTP response (remove cookie)
*
* @param response HTTP servlet response
*/
void clearTokenFromResponse(HttpServletResponse response);
/**
* Check if JWT authentication is enabled
*
* @return true if JWT is enabled, false otherwise
*/
boolean isJwtEnabled();
}

View File

@ -1,5 +1,8 @@
package stirling.software.proprietary.security.service;
import static stirling.software.proprietary.security.model.AuthenticationType.OAUTH2;
import static stirling.software.proprietary.security.model.AuthenticationType.SSO;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
@ -15,7 +18,6 @@ import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.userdetails.UserDetails;
@ -64,16 +66,17 @@ public class UserService implements UserServiceInterface {
@Transactional
public void migrateOauth2ToSSO() {
userRepository
.findByAuthenticationTypeIgnoreCase("OAUTH2")
.findByAuthenticationTypeIgnoreCase(OAUTH2.toString())
.forEach(
user -> {
user.setAuthenticationType(AuthenticationType.SSO);
user.setAuthenticationType(SSO);
userRepository.save(user);
});
}
// Handle OAUTH2 login and user auto creation.
public void processSSOPostLogin(String username, boolean autoCreateUser)
public void processSSOPostLogin(
String username, boolean autoCreateUser, AuthenticationType type)
throws IllegalArgumentException, SQLException, UnsupportedProviderException {
if (!isUsernameValid(username)) {
return;
@ -83,7 +86,7 @@ public class UserService implements UserServiceInterface {
return;
}
if (autoCreateUser) {
saveUser(username, AuthenticationType.SSO);
saveUser(username, type);
}
}
@ -100,10 +103,7 @@ public class UserService implements UserServiceInterface {
}
private Collection<? extends GrantedAuthority> getAuthorities(User user) {
// Convert each Authority object into a SimpleGrantedAuthority object.
return user.getAuthorities().stream()
.map((Authority authority) -> new SimpleGrantedAuthority(authority.getAuthority()))
.toList();
return user.getAuthorities();
}
private String generateApiKey() {

View File

@ -230,7 +230,7 @@ function loadAuditData(targetPage, realPageSize) {
document.getElementById('page-indicator').textContent = `Page ${requestedPage + 1} of ?`;
}
fetch(url)
fetchWithCsrf(url)
.then(response => {
return response.json();
})
@ -302,7 +302,7 @@ function loadStats(days) {
showLoading('user-chart-loading');
showLoading('time-chart-loading');
fetch(`/audit/stats?days=${days}`)
fetchWithCsrf(`/audit/stats?days=${days}`)
.then(response => response.json())
.then(data => {
document.getElementById('total-events').textContent = data.totalEvents;
@ -835,7 +835,7 @@ function hideLoading(id) {
// Load event types from the server for filter dropdowns
function loadEventTypes() {
fetch('/audit/types')
fetchWithCsrf('/audit/types')
.then(response => response.json())
.then(types => {
if (!types || types.length === 0) {

View File

@ -7,16 +7,21 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import stirling.software.common.configuration.AppConfig;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class CustomLogoutSuccessHandlerTest {
@Mock private ApplicationProperties applicationProperties;
@Mock private ApplicationProperties.Security securityProperties;
@Mock private AppConfig appConfig;
@Mock private JwtServiceInterface jwtService;
@InjectMocks private CustomLogoutSuccessHandler customLogoutSuccessHandler;
@ -24,9 +29,12 @@ class CustomLogoutSuccessHandlerTest {
void testSuccessfulLogout() throws IOException {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
String logoutPath = "logout=true";
String token = "token";
String logoutPath = "/login?logout=true";
when(response.isCommitted()).thenReturn(false);
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doNothing().when(jwtService).clearTokenFromResponse(response);
when(request.getContextPath()).thenReturn("");
when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath);
@ -35,12 +43,30 @@ class CustomLogoutSuccessHandlerTest {
verify(response).sendRedirect(logoutPath);
}
@Test
void testSuccessfulLogoutViaJWT() throws IOException {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
String logoutPath = "/login?logout=true";
String token = "token";
when(response.isCommitted()).thenReturn(false);
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doNothing().when(jwtService).clearTokenFromResponse(response);
when(request.getContextPath()).thenReturn("");
when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath);
customLogoutSuccessHandler.onLogoutSuccess(request, response, null);
verify(response).sendRedirect(logoutPath);
verify(jwtService).clearTokenFromResponse(response);
}
@Test
void testSuccessfulLogoutViaOAuth2() throws IOException {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken oAuth2AuthenticationToken = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -51,8 +77,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(oAuth2AuthenticationToken.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, oAuth2AuthenticationToken);
@ -67,7 +92,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -81,8 +105,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -98,7 +121,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -108,8 +130,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -124,7 +145,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -135,8 +155,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -151,7 +170,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -164,8 +182,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -180,7 +197,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -195,8 +211,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -211,7 +226,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -227,8 +241,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);
@ -243,7 +256,6 @@ class CustomLogoutSuccessHandlerTest {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
ApplicationProperties.Security security = mock(ApplicationProperties.Security.class);
ApplicationProperties.Security.OAUTH2 oauth =
mock(ApplicationProperties.Security.OAUTH2.class);
@ -256,8 +268,7 @@ class CustomLogoutSuccessHandlerTest {
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8080);
when(request.getContextPath()).thenReturn("");
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getOauth2()).thenReturn(oauth);
when(securityProperties.getOauth2()).thenReturn(oauth);
when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test");
customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication);

View File

@ -0,0 +1,40 @@
package stirling.software.proprietary.security;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class JwtAuthenticationEntryPointTest {
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private AuthenticationFailureException authException;
@InjectMocks
private JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint;
@Test
void testCommence() throws IOException {
String errorMessage = "Authentication failed";
when(authException.getMessage()).thenReturn(errorMessage);
jwtAuthenticationEntryPoint.commence(request, response, authException);
verify(response).sendError(HttpServletResponse.SC_UNAUTHORIZED, errorMessage);
}
}

View File

@ -0,0 +1,221 @@
package stirling.software.proprietary.security.filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.web.AuthenticationEntryPoint;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.UserService;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class JwtAuthenticationFilterTest {
@Mock
private JwtServiceInterface jwtService;
@Mock
private CustomUserDetailsService userDetailsService;
@Mock
private UserService userService;
@Mock
private ApplicationProperties.Security securityProperties;
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private FilterChain filterChain;
@Mock
private UserDetails userDetails;
@Mock
private SecurityContext securityContext;
@Mock
private AuthenticationEntryPoint authenticationEntryPoint;
@InjectMocks
private JwtAuthenticationFilter jwtAuthenticationFilter;
@Test
void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(false);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(filterChain).doFilter(request, response);
verify(jwtService, never()).extractTokenFromRequest(any());
}
@Test
void shouldNotFilterWhenPageIsLogin() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/login");
when(request.getContextPath()).thenReturn("/login");
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(filterChain, never()).doFilter(request, response);
}
@Test
void testDoFilterInternal() throws ServletException, IOException {
String token = "valid-jwt-token";
String newToken = "new-jwt-token";
String username = "testuser";
Map<String, Object> claims = Map.of("sub", username, "authType", "WEB");
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getContextPath()).thenReturn("/");
when(request.getRequestURI()).thenReturn("/protected");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doNothing().when(jwtService).validateToken(token);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(userDetails.getAuthorities()).thenReturn(Collections.emptyList());
when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) {
UsernamePasswordAuthenticationToken authToken =
new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken);
mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext);
when(jwtService.generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims))).thenReturn(newToken);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token);
verify(jwtService).extractAllClaims(token);
verify(userDetailsService).loadUserByUsername(username);
verify(securityContext).setAuthentication(any(UsernamePasswordAuthenticationToken.class));
verify(jwtService).generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims));
verify(jwtService).addTokenToResponse(response, newToken);
verify(filterChain).doFilter(request, response);
}
}
@Test
void testDoFilterInternalWithMissingTokenForRootPath() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/");
when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(null);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(response).sendRedirect("/login");
verify(filterChain, never()).doFilter(request, response);
}
@Test
void validationFailsWithInvalidToken() throws ServletException, IOException {
String token = "invalid-jwt-token";
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doThrow(new AuthenticationFailureException("Invalid token")).when(jwtService).validateToken(token);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token);
verify(authenticationEntryPoint).commence(eq(request), eq(response), any(AuthenticationFailureException.class));
verify(filterChain, never()).doFilter(request, response);
}
@Test
void validationFailsWithExpiredToken() throws ServletException, IOException {
String token = "expired-jwt-token";
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doThrow(new AuthenticationFailureException("The token has expired")).when(jwtService).validateToken(token);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token);
verify(authenticationEntryPoint).commence(eq(request), eq(response), any());
verify(filterChain, never()).doFilter(request, response);
}
@Test
void exceptinonThrown_WhenUserNotFound() throws ServletException, IOException {
String token = "valid-jwt-token";
String username = "nonexistentuser";
Map<String, Object> claims = Map.of("sub", username, "authType", "WEB");
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doNothing().when(jwtService).validateToken(token);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(userDetailsService.loadUserByUsername(username)).thenReturn(null);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) {
when(securityContext.getAuthentication()).thenReturn(null);
mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext);
UsernameNotFoundException result = assertThrows(UsernameNotFoundException.class, () -> jwtAuthenticationFilter.doFilterInternal(request, response, filterChain));
assertEquals("User not found: " + username, result.getMessage());
verify(userDetailsService).loadUserByUsername(username);
verify(filterChain, never()).doFilter(request, response);
}
}
@Test
void testAuthenticationEntryPointCalledWithCorrectException() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/");
when(jwtService.extractTokenFromRequest(request)).thenReturn(null);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(authenticationEntryPoint).commence(eq(request), eq(response), argThat(exception ->
exception.getMessage().equals("JWT is missing from the request")
));
verify(filterChain, never()).doFilter(request, response);
}
}

View File

@ -0,0 +1,230 @@
package stirling.software.proprietary.security.saml2;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class JwtSaml2AuthenticationRequestRepositoryTest {
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
private Map<String, String> tokenStore;
@Mock
private JwtServiceInterface jwtService;
@Mock
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository;
@BeforeEach
void setUp() {
tokenStore = new ConcurrentHashMap<>();
jwtSaml2AuthenticationRequestRepository = new JwtSaml2AuthenticationRequestRepository(
tokenStore, jwtService, relyingPartyRegistrationRepository);
}
@Test
void saveAuthenticationRequest() {
var authRequest = mock(Saml2PostAuthenticationRequest.class);
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
String token = "testToken";
String id = "testId";
String relayState = "testRelayState";
String authnRequestUri = "example.com/authnRequest";
Map<String, Object> claims = Map.of();
String samlRequest = "testSamlRequest";
String relyingPartyRegistrationId = "stirling-pdf";
when(jwtService.isJwtEnabled()).thenReturn(true);
when(authRequest.getRelayState()).thenReturn(relayState);
when(authRequest.getId()).thenReturn(id);
when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri);
when(authRequest.getSamlRequest()).thenReturn(samlRequest);
when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId);
when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(authRequest, request, response);
verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState);
verify(response).addHeader(SAML_REQUEST_TOKEN, relayState);
}
@Test
void saveAuthenticationRequestWithNullRequest() {
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response);
assertTrue(tokenStore.isEmpty());
}
@Test
void loadAuthenticationRequest() {
var request = mock(MockHttpServletRequest.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@ParameterizedTest
@NullAndEmptySource
void loadAuthenticationRequestWithInvalidRelayState(String relayState) {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNonExistentToken() {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNullRelyingPartyRegistration() {
var request = mock(MockHttpServletRequest.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(null);
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void removeAuthenticationRequest() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@Test
void removeAuthenticationRequestWithNullRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn(null);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithNonExistentToken() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithOnlyRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
String relayState = "testRelayState";
when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
}

View File

@ -0,0 +1,258 @@
package stirling.software.proprietary.security.service;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import stirling.software.common.configuration.InstallationPathConfig;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository;
import stirling.software.proprietary.security.model.JwtSigningKey;
@ExtendWith(MockitoExtension.class)
class JwtKeystoreServiceInterfaceTest {
@Mock
private JwtSigningKeyRepository repository;
@Mock
private ApplicationProperties applicationProperties;
@Mock
private ApplicationProperties.Security security;
@Mock
private ApplicationProperties.Security.Jwt jwtConfig;
@TempDir
Path tempDir;
private JwtKeystoreService keystoreService;
private KeyPair testKeyPair;
@BeforeEach
void setUp() throws NoSuchAlgorithmException {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
testKeyPair = keyPairGenerator.generateKeyPair();
when(applicationProperties.getSecurity()).thenReturn(security);
when(security.getJwt()).thenReturn(jwtConfig);
when(jwtConfig.isEnableKeystore()).thenReturn(true);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testKeystoreEnabled(boolean keystoreEnabled) {
when(jwtConfig.isEnableKeystore()).thenReturn(keystoreEnabled);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
assertEquals(keystoreEnabled, keystoreService.isKeystoreEnabled());
}
}
@Test
void testGetActiveKeypairWhenKeystoreDisabled() {
when(jwtConfig.isEnableKeystore()).thenReturn(false);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
KeyPair result = keystoreService.getActiveKeypair();
assertNotNull(result);
assertNotNull(result.getPublic());
assertNotNull(result.getPrivate());
}
}
@Test
void testGetActiveKeypairWhenNoActiveKeyExists() {
when(repository.findByIsActiveTrue()).thenReturn(Optional.empty());
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
keystoreService.initializeKeystore();
KeyPair result = keystoreService.getActiveKeypair();
assertNotNull(result);
verify(repository).save(any(JwtSigningKey.class));
}
}
@Test
void testGetActiveKeypairWithExistingKey() throws Exception {
String keyId = "test-key-2024-01-01-120000";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded());
JwtSigningKey existingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256");
when(repository.findByIsActiveTrue()).thenReturn(Optional.of(existingKey));
Path keyFile = tempDir.resolve("jwt-keys").resolve(keyId + ".key");
Files.createDirectories(keyFile.getParent());
Files.writeString(keyFile, privateKeyBase64);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
keystoreService.initializeKeystore();
KeyPair result = keystoreService.getActiveKeypair();
assertNotNull(result);
assertEquals(keyId, keystoreService.getActiveKeyId());
}
}
@Test
void testGetKeypairByKeyId() throws Exception {
String keyId = "test-key-123";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded());
JwtSigningKey signingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256");
when(repository.findByKeyId(keyId)).thenReturn(Optional.of(signingKey));
Path keyFile = tempDir.resolve("jwt-keys").resolve(keyId + ".key");
Files.createDirectories(keyFile.getParent());
Files.writeString(keyFile, privateKeyBase64);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
Optional<KeyPair> result = keystoreService.getKeypairByKeyId(keyId);
assertTrue(result.isPresent());
assertNotNull(result.get().getPublic());
assertNotNull(result.get().getPrivate());
}
}
@Test
void testGetKeypairByKeyIdNotFound() {
String keyId = "non-existent-key";
when(repository.findByKeyId(keyId)).thenReturn(Optional.empty());
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
Optional<KeyPair> result = keystoreService.getKeypairByKeyId(keyId);
assertFalse(result.isPresent());
}
}
@Test
void testGetKeypairByKeyIdWhenKeystoreDisabled() {
when(jwtConfig.isEnableKeystore()).thenReturn(false);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
Optional<KeyPair> result = keystoreService.getKeypairByKeyId("any-key");
assertFalse(result.isPresent());
}
}
@Test
void testRotateKeypair() {
String oldKeyId = "old-key-123";
JwtSigningKey oldKey = new JwtSigningKey(oldKeyId, "old-public-key", "RS256");
when(repository.findByIsActiveTrue()).thenReturn(Optional.of(oldKey));
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
keystoreService.initializeKeystore();
keystoreService.rotateKeypair();
assertFalse(oldKey.getIsActive());
verify(repository, atLeast(2)).save(any(JwtSigningKey.class)); // At least one for deactivation, one for new key
assertNotNull(keystoreService.getActiveKeyId());
assertNotEquals(oldKeyId, keystoreService.getActiveKeyId());
}
}
@Test
void testRotateKeypairWhenKeystoreDisabled() {
when(jwtConfig.isEnableKeystore()).thenReturn(false);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
assertDoesNotThrow(() -> keystoreService.rotateKeypair());
verify(repository, never()).save(any());
}
}
@Test
void testInitializeKeystoreCreatesDirectory() throws IOException {
when(repository.findByIsActiveTrue()).thenReturn(Optional.empty());
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
keystoreService.initializeKeystore();
Path jwtKeysDir = tempDir.resolve("jwt-keys");
assertTrue(Files.exists(jwtKeysDir));
assertTrue(Files.isDirectory(jwtKeysDir));
}
}
@Test
void testLoadExistingKeypairWithMissingPrivateKeyFile() {
String keyId = "test-key-missing-file";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
JwtSigningKey existingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256");
when(repository.findByIsActiveTrue()).thenReturn(Optional.of(existingKey));
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) {
mockedStatic.when(InstallationPathConfig::getConfigPath).thenReturn(tempDir.toString());
keystoreService = new JwtKeystoreService(repository, applicationProperties);
keystoreService.initializeKeystore();
KeyPair result = keystoreService.getActiveKeypair();
assertNotNull(result);
verify(repository).save(any(JwtSigningKey.class));
}
}
}

View File

@ -0,0 +1,330 @@
package stirling.software.proprietary.security.service;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.contains;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class JwtServiceTest {
@Mock
private ApplicationProperties.Security securityProperties;
@Mock
private Authentication authentication;
@Mock
private User userDetails;
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private JwtKeystoreServiceInterface keystoreService;
private JwtService jwtService;
private KeyPair testKeyPair;
@BeforeEach
void setUp() throws NoSuchAlgorithmException {
// Generate a test keypair
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
testKeyPair = keyPairGenerator.generateKeyPair();
jwtService = new JwtService(true, keystoreService);
}
@Test
void testGenerateTokenWithAuthentication() {
String username = "testuser";
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, Collections.emptyMap());
assertNotNull(token);
assertFalse(token.isEmpty());
assertEquals(username, jwtService.extractUsername(token));
}
@Test
void testGenerateTokenWithUsernameAndClaims() {
String username = "testuser";
Map<String, Object> claims = new HashMap<>();
claims.put("role", "admin");
claims.put("department", "IT");
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, claims);
assertNotNull(token);
assertFalse(token.isEmpty());
assertEquals(username, jwtService.extractUsername(token));
Map<String, Object> extractedClaims = jwtService.extractAllClaims(token);
assertEquals("admin", extractedClaims.get("role"));
assertEquals("IT", extractedClaims.get("department"));
}
@Test
void testValidateTokenSuccess() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn("testuser");
String token = jwtService.generateToken(authentication, new HashMap<>());
assertDoesNotThrow(() -> jwtService.validateToken(token));
}
@Test
void testValidateTokenWithInvalidToken() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("invalid-token");
});
}
@Test
void testValidateTokenWithMalformedToken() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("malformed.token");
});
assertTrue(exception.getMessage().contains("Invalid"));
}
@Test
void testValidateTokenWithEmptyToken() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("");
});
assertTrue(exception.getMessage().contains("Claims are empty") || exception.getMessage().contains("Invalid"));
}
@Test
void testExtractUsername() {
String username = "testuser";
User user = mock(User.class);
Map<String, Object> claims = Map.of("sub", "testuser", "authType", "WEB");
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(user);
when(user.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, claims);
assertEquals(username, jwtService.extractUsername(token));
}
@Test
void testExtractUsernameWithInvalidToken() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
assertThrows(AuthenticationFailureException.class, () -> jwtService.extractUsername("invalid-token"));
}
@Test
void testExtractAllClaims() {
String username = "testuser";
Map<String, Object> claims = Map.of("role", "admin", "department", "IT");
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, claims);
Map<String, Object> extractedClaims = jwtService.extractAllClaims(token);
assertEquals("admin", extractedClaims.get("role"));
assertEquals("IT", extractedClaims.get("department"));
assertEquals(username, extractedClaims.get("sub"));
assertEquals("Stirling PDF", extractedClaims.get("iss"));
}
@Test
void testExtractAllClaimsWithInvalidToken() {
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
assertThrows(AuthenticationFailureException.class, () -> jwtService.extractAllClaims("invalid-token"));
}
@Test
void testExtractTokenFromRequestWithAuthorizationHeader() {
String token = "test-token";
when(request.getHeader("Authorization")).thenReturn("Bearer " + token);
assertEquals(token, jwtService.extractTokenFromRequest(request));
}
@Test
void testExtractTokenFromRequestWithCookie() {
String token = "test-token";
Cookie[] cookies = { new Cookie("stirling_jwt", token) };
when(request.getHeader("Authorization")).thenReturn(null);
when(request.getCookies()).thenReturn(cookies);
assertEquals(token, jwtService.extractTokenFromRequest(request));
}
@Test
void testExtractTokenFromRequestWithNoCookies() {
when(request.getHeader("Authorization")).thenReturn(null);
when(request.getCookies()).thenReturn(null);
assertNull(jwtService.extractTokenFromRequest(request));
}
@Test
void testExtractTokenFromRequestWithWrongCookie() {
Cookie[] cookies = {new Cookie("OTHER_COOKIE", "value")};
when(request.getHeader("Authorization")).thenReturn(null);
when(request.getCookies()).thenReturn(cookies);
assertNull(jwtService.extractTokenFromRequest(request));
}
@Test
void testExtractTokenFromRequestWithInvalidAuthorizationHeader() {
when(request.getHeader("Authorization")).thenReturn("Basic token");
when(request.getCookies()).thenReturn(null);
assertNull(jwtService.extractTokenFromRequest(request));
}
@Test
void testAddTokenToResponse() {
String token = "test-token";
jwtService.addTokenToResponse(response, token);
verify(response).setHeader("Authorization", "Bearer " + token);
verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt=" + token));
verify(response).addHeader(eq("Set-Cookie"), contains("HttpOnly"));
verify(response).addHeader(eq("Set-Cookie"), contains("Secure"));
}
@Test
void testClearTokenFromResponse() {
jwtService.clearTokenFromResponse(response);
verify(response).setHeader("Authorization", null);
verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt="));
verify(response).addHeader(eq("Set-Cookie"), contains("Max-Age=0"));
}
@Test
void testGenerateTokenWithKeyId() {
String username = "testuser";
Map<String, Object> claims = new HashMap<>();
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, claims);
assertNotNull(token);
assertFalse(token.isEmpty());
// Verify that the keystore service was called
verify(keystoreService).getActiveKeypair();
verify(keystoreService).getActiveKeyId();
}
@Test
void testTokenVerificationWithSpecificKeyId() throws NoSuchAlgorithmException {
String username = "testuser";
Map<String, Object> claims = new HashMap<>();
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
// Generate token with key ID
String token = jwtService.generateToken(authentication, claims);
// Mock extraction of key ID and verification (lenient to avoid unused stubbing)
lenient().when(keystoreService.getKeypairByKeyId("test-key-id")).thenReturn(Optional.of(testKeyPair));
// Verify token can be validated
assertDoesNotThrow(() -> jwtService.validateToken(token));
assertEquals(username, jwtService.extractUsername(token));
}
@Test
void testTokenVerificationFallsBackToActiveKeyWhenKeyIdNotFound() {
String username = "testuser";
Map<String, Object> claims = new HashMap<>();
when(keystoreService.getActiveKeypair()).thenReturn(testKeyPair);
when(keystoreService.getActiveKeyId()).thenReturn("test-key-id");
when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username);
String token = jwtService.generateToken(authentication, claims);
// Mock scenario where specific key ID is not found (lenient to avoid unused stubbing)
lenient().when(keystoreService.getKeypairByKeyId("test-key-id")).thenReturn(Optional.empty());
// Should still work using active keypair
assertDoesNotThrow(() -> jwtService.validateToken(token));
assertEquals(username, jwtService.extractUsername(token));
// Verify fallback to active keypair was used (called multiple times during token operations)
verify(keystoreService, atLeast(1)).getActiveKeypair();
}
}