diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 6e006423a..bc5358b85 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -5,7 +5,13 @@ "Bash(mkdir:*)", "Bash(./gradlew:*)", "Bash(grep:*)", - "Bash(cat:*)" + "Bash(cat:*)", + "Bash(find:*)", + "Bash(grep:*)", + "Bash(rg:*)", + "Bash(strings:*)", + "Bash(pkill:*)", + "Bash(true)" ], "deny": [] } diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index db847f570..1487b1f3d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -147,6 +147,8 @@ jobs: - name: Generate OpenAPI documentation run: ./gradlew :stirling-pdf:generateOpenApiDocs + env: + DISABLE_ADDITIONAL_FEATURES: true - name: Upload OpenAPI Documentation uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 diff --git a/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java b/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java index f611f42ca..e24a92d6a 100644 --- a/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java +++ b/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Locale; import java.util.Properties; import java.util.function.Predicate; +import java.util.stream.Stream; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -51,6 +52,14 @@ public class AppConfig { @Value("${server.port:8080}") private String serverPort; + @Value("${v2}") + public boolean v2Enabled; + + @Bean + public boolean v2Enabled() { + return v2Enabled; + } + @Bean @ConditionalOnProperty(name = "system.customHTMLFiles", havingValue = "true") public SpringTemplateEngine templateEngine(ResourceLoader resourceLoader) { @@ -120,7 +129,7 @@ public class AppConfig { public boolean rateLimit() { String rateLimit = System.getProperty("rateLimit"); if (rateLimit == null) rateLimit = System.getenv("rateLimit"); - return (rateLimit != null) ? Boolean.valueOf(rateLimit) : false; + return Boolean.parseBoolean(rateLimit); } @Bean(name = "RunningInDocker") @@ -140,8 +149,8 @@ public class AppConfig { if (!Files.exists(mountInfo)) { return true; } - try { - return Files.lines(mountInfo).anyMatch(line -> line.contains(" /configs ")); + try (Stream lines = Files.lines(mountInfo)) { + return lines.anyMatch(line -> line.contains(" /configs ")); } catch (IOException e) { return false; } diff --git a/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java b/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java index 247a012ad..95de76dbd 100644 --- a/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java +++ b/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java @@ -25,6 +25,7 @@ public class InstallationPathConfig { private static final String STATIC_PATH; private static final String TEMPLATES_PATH; private static final String SIGNATURES_PATH; + private static final String PRIVATE_KEY_PATH; static { BASE_PATH = initializeBasePath(); @@ -45,6 +46,7 @@ public class InstallationPathConfig { STATIC_PATH = CUSTOM_FILES_PATH + "static" + File.separator; TEMPLATES_PATH = CUSTOM_FILES_PATH + "templates" + File.separator; SIGNATURES_PATH = CUSTOM_FILES_PATH + "signatures" + File.separator; + PRIVATE_KEY_PATH = CONFIG_PATH + "keys" + File.separator; } private static String initializeBasePath() { @@ -120,4 +122,8 @@ public class InstallationPathConfig { public static String getSignaturesPath() { return SIGNATURES_PATH; } + + public static String getPrivateKeyPath() { + return PRIVATE_KEY_PATH; + } } diff --git a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java index 802a55831..48df4f948 100644 --- a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java +++ b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java @@ -119,6 +119,7 @@ public class ApplicationProperties { private long loginResetTimeMinutes; private String loginMethod = "all"; private String customGlobalAPIKey; + private Jwt jwt = new Jwt(); public Boolean isAltLogin() { return saml2.getEnabled() || oauth2.getEnabled(); @@ -297,6 +298,15 @@ public class ApplicationProperties { } } } + + @Data + public static class Jwt { + private boolean enableKeystore = true; + private boolean enableKeyRotation = false; + private boolean enableKeyCleanup = true; + private int keyRetentionDays = 7; + private int cleanupBatchSize = 100; + } } @Data diff --git a/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java b/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java index 654c78fe9..239976b66 100644 --- a/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java +++ b/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java @@ -14,8 +14,10 @@ public class RequestUriUtils { || requestURI.startsWith(contextPath + "/images/") || requestURI.startsWith(contextPath + "/public/") || requestURI.startsWith(contextPath + "/pdfjs/") + || requestURI.startsWith(contextPath + "/pdfjs-legacy/") || requestURI.startsWith(contextPath + "/login") || requestURI.startsWith(contextPath + "/error") + || requestURI.startsWith(contextPath + "/favicon") || requestURI.endsWith(".svg") || requestURI.endsWith(".png") || requestURI.endsWith(".ico") diff --git a/app/core/src/main/resources/application.properties b/app/core/src/main/resources/application.properties index ea30bf78e..ec3a0a390 100644 --- a/app/core/src/main/resources/application.properties +++ b/app/core/src/main/resources/application.properties @@ -5,7 +5,7 @@ logging.level.org.eclipse.jetty=WARN #logging.level.org.springframework.security.saml2=TRACE #logging.level.org.springframework.security=DEBUG #logging.level.org.opensaml=DEBUG -#logging.level.stirling.software.SPDF.config.security: DEBUG +#logging.level.stirling.software.proprietary.security: DEBUG logging.level.com.zaxxer.hikari=WARN spring.jpa.open-in-view=false server.forward-headers-strategy=NATIVE @@ -47,4 +47,7 @@ posthog.host=https://eu.i.posthog.com spring.main.allow-bean-definition-overriding=true # Set up a consistent temporary directory location -java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf} \ No newline at end of file +java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf} + +# V2 features +v2=true diff --git a/app/core/src/main/resources/messages_en_GB.properties b/app/core/src/main/resources/messages_en_GB.properties index f78e80b65..32f1bb6ea 100644 --- a/app/core/src/main/resources/messages_en_GB.properties +++ b/app/core/src/main/resources/messages_en_GB.properties @@ -861,7 +861,7 @@ login.rememberme=Remember me login.invalid=Invalid username or password. login.locked=Your account has been locked. login.signinTitle=Please sign in -login.ssoSignIn=Login via Single Sign-on +login.ssoSignIn=Login via Single Sign-On login.oAuth2AutoCreateDisabled=OAUTH2 Auto-Create User Disabled login.oAuth2AdminBlockedUser=Registration or logging in of non-registered users is currently blocked. Please contact the administrator. login.oauth2RequestNotFound=Authorization request not found @@ -876,6 +876,7 @@ login.alreadyLoggedIn=You are already logged in to login.alreadyLoggedIn2=devices. Please log out of the devices and try again. login.toManySessions=You have too many active sessions login.logoutMessage=You have been logged out. +login.invalidInResponseTo=The requested SAML response is invalid or has expired. Please contact the administrator. #auto-redact autoRedact.title=Auto Redact diff --git a/app/core/src/main/resources/settings.yml.template b/app/core/src/main/resources/settings.yml.template index cf22262e4..3d5608bd5 100644 --- a/app/core/src/main/resources/settings.yml.template +++ b/app/core/src/main/resources/settings.yml.template @@ -59,12 +59,18 @@ security: idpCert: classpath:okta.cert # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider privateKey: classpath:saml-private-key.key # Your private key. Generated from your keypair spCert: classpath:saml-public-cert.crt # Your signing certificate. Generated from your keypair + jwt: + persistence: true # Set to 'true' to enable JWT key store + enableKeyRotation: true # Set to 'true' to enable key pair rotation + enableKeyCleanup: true # Set to 'true' to enable key pair cleanup + keyRetentionDays: 7 # Number of days to retain old keys. The default is 7 days. + cleanupBatchSize: 100 # Number of keys to clean up in each batch. The default is 100. + secureCookie: false # Set to 'true' to use secure cookies for JWTs premium: key: 00000000-0000-0000-0000-000000000000 enabled: false # Enable license key checks for pro/enterprise features proFeatures: - database: true # Enable database features SSOAutoLogin: false CustomMetadata: autoUpdateMetadata: false diff --git a/app/core/src/main/resources/static/js/DecryptFiles.js b/app/core/src/main/resources/static/js/DecryptFiles.js index 67349a012..0e5b58a92 100644 --- a/app/core/src/main/resources/static/js/DecryptFiles.js +++ b/app/core/src/main/resources/static/js/DecryptFiles.js @@ -46,10 +46,9 @@ export class DecryptFile { formData.append('password', password); } // Send decryption request - const response = await fetch('/api/v1/security/remove-password', { + const response = await fetchWithCsrf('/api/v1/security/remove-password', { method: 'POST', body: formData, - headers: csrfToken ? {'X-XSRF-TOKEN': csrfToken} : undefined, }); if (response.ok) { diff --git a/app/core/src/main/resources/static/js/downloader.js b/app/core/src/main/resources/static/js/downloader.js index 42ba0c357..b5324dd82 100644 --- a/app/core/src/main/resources/static/js/downloader.js +++ b/app/core/src/main/resources/static/js/downloader.js @@ -218,7 +218,7 @@ formData.append('password', password); // Use handleSingleDownload to send the request - const decryptionResult = await fetch(removePasswordUrl, {method: 'POST', body: formData}); + const decryptionResult = await fetchWithCsrf(removePasswordUrl, {method: 'POST', body: formData}); if (decryptionResult && decryptionResult.blob) { const decryptedBlob = await decryptionResult.blob(); diff --git a/app/core/src/main/resources/static/js/fetch-utils.js b/app/core/src/main/resources/static/js/fetch-utils.js index dfe2604a8..73c26ffb3 100644 --- a/app/core/src/main/resources/static/js/fetch-utils.js +++ b/app/core/src/main/resources/static/js/fetch-utils.js @@ -1,3 +1,80 @@ +// JWT Management Utility +window.JWTManager = { + JWT_STORAGE_KEY: 'stirling_jwt', + + // Store JWT token in localStorage + storeToken: function(token) { + if (token) { + localStorage.setItem(this.JWT_STORAGE_KEY, token); + } + }, + + // Get JWT token from localStorage + getToken: function() { + return localStorage.getItem(this.JWT_STORAGE_KEY); + }, + + // Remove JWT token from localStorage + removeToken: function() { + localStorage.removeItem(this.JWT_STORAGE_KEY); + }, + + // Extract JWT from Authorization header in response + extractTokenFromResponse: function(response) { + const authHeader = response.headers.get('Authorization'); + if (authHeader && authHeader.startsWith('Bearer ')) { + const token = authHeader.substring(7); // Remove 'Bearer ' prefix + this.storeToken(token); + return token; + } + return null; + }, + + // Check if user is authenticated (has valid JWT) + isAuthenticated: function() { + const token = this.getToken(); + if (!token) return false; + + try { + // Basic JWT expiration check (decode payload) + const payload = JSON.parse(atob(token.split('.')[1])); + const now = Date.now() / 1000; + return payload.exp > now; + } catch (error) { + console.warn('Invalid JWT token:', error); + this.removeToken(); + return false; + } + }, + + // Logout - remove token and redirect to login + logout: function() { + this.removeToken(); + + // Clear all possible token storage locations + localStorage.removeItem(this.JWT_STORAGE_KEY); + sessionStorage.removeItem(this.JWT_STORAGE_KEY); + + // Clear JWT cookie manually (fallback) + document.cookie = 'stirling_jwt=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure'; + + // Perform logout request to clear server-side session + fetch('/logout', { + method: 'POST', + credentials: 'include' + }).then(response => { + if (response.redirected) { + window.location.href = response.url; + } else { + window.location.href = '/login?logout=true'; + } + }).catch(() => { + // If logout fails, let server handle it + window.location.href = '/logout'; + }); + } +}; + window.fetchWithCsrf = async function(url, options = {}) { function getCsrfToken() { const cookieValue = document.cookie @@ -24,5 +101,31 @@ window.fetchWithCsrf = async function(url, options = {}) { fetchOptions.headers['X-XSRF-TOKEN'] = csrfToken; } - return fetch(url, fetchOptions); + // Add JWT token to Authorization header if available + const jwtToken = window.JWTManager.getToken(); + if (jwtToken) { + fetchOptions.headers['Authorization'] = `Bearer ${jwtToken}`; + // Include credentials when JWT is enabled + fetchOptions.credentials = 'include'; + } + + // Make the request + const response = await fetch(url, fetchOptions); + + // Extract JWT from response if present + window.JWTManager.extractTokenFromResponse(response); + + // Handle 401 responses (unauthorized) + if (response.status === 401) { + console.warn('Authentication failed, redirecting to login'); + window.JWTManager.logout(); + return response; + } + + return response; +} + +// Enhanced fetch function that always includes JWT +window.fetchWithJWT = async function(url, options = {}) { + return window.fetchWithCsrf(url, options); } diff --git a/app/core/src/main/resources/static/js/jwt-init.js b/app/core/src/main/resources/static/js/jwt-init.js new file mode 100644 index 000000000..a2733bf35 --- /dev/null +++ b/app/core/src/main/resources/static/js/jwt-init.js @@ -0,0 +1,124 @@ +// JWT Initialization Script +// This script handles JWT token extraction during OAuth/Login flows and initializes the JWT manager + +(function() { + // Extract JWT token from URL parameters (for OAuth redirects) + function extractTokenFromUrl() { + const urlParams = new URLSearchParams(window.location.search); + const token = urlParams.get('jwt') || urlParams.get('token'); + if (token) { + window.JWTManager.storeToken(token); + // Clean up URL by removing token parameter + urlParams.delete('jwt'); + urlParams.delete('token'); + const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : ''); + window.history.replaceState({}, '', newUrl); + } + } + + // Extract JWT token from cookie on page load (fallback) + function extractTokenFromCookie() { + const cookieValue = document.cookie + .split('; ') + .find(row => row.startsWith('stirling_jwt=')) + ?.split('=')[1]; + + if (cookieValue) { + window.JWTManager.storeToken(cookieValue); + // Clear the cookie since we're using localStorage with consistent SameSite policy + document.cookie = 'stirling_jwt=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure'; + } + } + + // Initialize JWT handling when page loads + function initializeJWT() { + // Try to extract token from URL first (OAuth flow) + extractTokenFromUrl(); + + // If no token in URL, try cookie (login flow) + if (!window.JWTManager.getToken()) { + extractTokenFromCookie(); + } + + // Check if user is authenticated + if (window.JWTManager.isAuthenticated()) { + console.log('User is authenticated with JWT'); + } else { + console.log('User is not authenticated or token expired'); + // Only redirect to login if we're not already on login/register pages + const currentPath = window.location.pathname; + const currentSearch = window.location.search; + // Don't redirect if we're on logout page or already being logged out + if (!currentPath.includes('/login') && + !currentPath.includes('/register') && + !currentPath.includes('/oauth') && + !currentPath.includes('/saml') && + !currentPath.includes('/error') && + !currentSearch.includes('logout=true')) { + // Redirect to login after a short delay to allow other scripts to load + setTimeout(() => { + window.location.href = '/login'; + }, 100); + } + } + } + + // Override form submissions to include JWT + function enhanceFormSubmissions() { + // Override form submit for login forms + document.addEventListener('submit', function(event) { + const form = event.target; + + // Add JWT to form data if available + const jwtToken = window.JWTManager.getToken(); + if (jwtToken && form.method && form.method.toLowerCase() !== 'get') { + // Create a hidden input for JWT + const jwtInput = document.createElement('input'); + jwtInput.type = 'hidden'; + jwtInput.name = 'jwt'; + jwtInput.value = jwtToken; + form.appendChild(jwtInput); + } + }); + } + + // Add logout functionality to logout buttons + function enhanceLogoutButtons() { + document.addEventListener('click', function(event) { + const element = event.target; + + // Check if clicked element is a logout button/link + if (element.matches('a[href="/logout"], button[data-action="logout"], .logout-btn')) { + event.preventDefault(); + window.JWTManager.logout(); + } + }); + } + + // Initialize when DOM is ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', function() { + initializeJWT(); + enhanceFormSubmissions(); + enhanceLogoutButtons(); + }); + } else { + initializeJWT(); + enhanceFormSubmissions(); + enhanceLogoutButtons(); + } + + // Handle page visibility changes to check token expiration + document.addEventListener('visibilitychange', function() { + if (!document.hidden && !window.JWTManager.isAuthenticated()) { + // Token expired while page was hidden, redirect to login + const currentPath = window.location.pathname; + if (!currentPath.includes('/login') && + !currentPath.includes('/register') && + !currentPath.includes('/oauth') && + !currentPath.includes('/saml')) { + window.location.href = '/login'; + } + } + }); +})(); \ No newline at end of file diff --git a/app/core/src/main/resources/static/js/navbar.js b/app/core/src/main/resources/static/js/navbar.js index a95ff1639..1fd46ed70 100644 --- a/app/core/src/main/resources/static/js/navbar.js +++ b/app/core/src/main/resources/static/js/navbar.js @@ -138,5 +138,19 @@ document.addEventListener('DOMContentLoaded', () => { tooltipSetup(); setupDropdowns(); fixNavbarDropdownStyles(); + // Setup logout button functionality + const logoutButton = document.querySelector('a[href="/logout"]'); + if (logoutButton) { + logoutButton.addEventListener('click', function(event) { + event.preventDefault(); + if (window.JWTManager) { + window.JWTManager.logout(); + } else { + // Fallback if JWTManager is not available + window.location.href = '/logout'; + } + }); + } + }); window.addEventListener('resize', fixNavbarDropdownStyles); diff --git a/app/core/src/main/resources/static/js/usage.js b/app/core/src/main/resources/static/js/usage.js index 624e4ec78..443a27ce1 100644 --- a/app/core/src/main/resources/static/js/usage.js +++ b/app/core/src/main/resources/static/js/usage.js @@ -102,7 +102,7 @@ async function fetchEndpointData() { refreshBtn.classList.add('refreshing'); refreshBtn.disabled = true; - const response = await fetch('/api/v1/info/load/all'); + const response = await fetchWithCsrf('/api/v1/info/load/all'); if (!response.ok) { throw new Error('Network response was not ok'); } diff --git a/app/core/src/main/resources/templates/account.html b/app/core/src/main/resources/templates/account.html index 33a0d9f47..db48bb3a5 100644 --- a/app/core/src/main/resources/templates/account.html +++ b/app/core/src/main/resources/templates/account.html @@ -390,8 +390,13 @@ key.includes('clientSubmissionOrder') || key.includes('lastSubmitTime') || key.includes('lastClientId') || - - + key.includes('stirling_jwt') || + key.includes('JSESSIONID') || + key.includes('XSRF-TOKEN') || + key.includes('remember-me') || + key.includes('auth') || + key.includes('token') || + key.includes('session') || key.includes('posthog') || key.includes('ssoRedirectAttempts') || key.includes('lastRedirectAttempt') || key.includes('surveyVersion') || key.includes('pageViews'); } diff --git a/app/proprietary/build.gradle b/app/proprietary/build.gradle index 2a72f8a65..4154ec14c 100644 --- a/app/proprietary/build.gradle +++ b/app/proprietary/build.gradle @@ -1,9 +1,15 @@ repositories { maven { url = "https://build.shibboleth.net/maven/releases" } } + +ext { + jwtVersion = '0.12.6' +} + bootRun { enabled = false } + spotless { java { target sourceSets.main.allJava @@ -38,6 +44,10 @@ dependencies { implementation 'org.thymeleaf.extras:thymeleaf-extras-springsecurity5:3.1.3.RELEASE' api 'io.micrometer:micrometer-registry-prometheus' implementation 'com.unboundid.product.scim2:scim2-sdk-client:4.0.0' + + api "io.jsonwebtoken:jjwt-api:$jwtVersion" + runtimeOnly "io.jsonwebtoken:jjwt-impl:$jwtVersion" + runtimeOnly "io.jsonwebtoken:jjwt-jackson:$jwtVersion" runtimeOnly 'com.h2database:h2:2.3.232' // Don't upgrade h2database runtimeOnly 'org.postgresql:postgresql:42.7.7' constraints { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java index d5180c321..2d36815ee 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java @@ -1,6 +1,7 @@ package stirling.software.proprietary.security; import java.io.IOException; +import java.util.Map; import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; @@ -17,6 +18,8 @@ import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.audit.AuditEventType; import stirling.software.proprietary.audit.AuditLevel; import stirling.software.proprietary.audit.Audited; +import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -24,13 +27,17 @@ import stirling.software.proprietary.security.service.UserService; public class CustomAuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { - private LoginAttemptService loginAttemptService; - private UserService userService; + private final LoginAttemptService loginAttemptService; + private final UserService userService; + private final JwtServiceInterface jwtService; public CustomAuthenticationSuccessHandler( - LoginAttemptService loginAttemptService, UserService userService) { + LoginAttemptService loginAttemptService, + UserService userService, + JwtServiceInterface jwtService) { this.loginAttemptService = loginAttemptService; this.userService = userService; + this.jwtService = jwtService; } @Override @@ -46,23 +53,35 @@ public class CustomAuthenticationSuccessHandler } loginAttemptService.loginSucceeded(userName); - // Get the saved request - HttpSession session = request.getSession(false); - SavedRequest savedRequest = - (session != null) - ? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST") - : null; + if (jwtService.isJwtEnabled()) { + try { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.WEB)); + jwtService.addToken(response, jwt); + log.debug("JWT generated for user: {}", userName); + } catch (Exception e) { + log.error("Failed to generate JWT token for user: {}", userName, e); + } - if (savedRequest != null - && !RequestUriUtils.isStaticResource( - request.getContextPath(), savedRequest.getRedirectUrl())) { - // Redirect to the original destination - super.onAuthenticationSuccess(request, response, authentication); - } else { - // Redirect to the root URL (considering context path) getRedirectStrategy().sendRedirect(request, response, "/"); - } + } else { + // Get the saved request + HttpSession session = request.getSession(false); + SavedRequest savedRequest = + (session != null) + ? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST") + : null; - // super.onAuthenticationSuccess(request, response, authentication); + if (savedRequest != null + && !RequestUriUtils.isStaticResource( + request.getContextPath(), savedRequest.getRedirectUrl())) { + // Redirect to the original destination + super.onAuthenticationSuccess(request, response, authentication); + } else { + // No saved request or it's a static resource, redirect to home page + getRedirectStrategy().sendRedirect(request, response, "/"); + } + } } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java index 033ea913c..136120528 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java @@ -33,6 +33,7 @@ import stirling.software.proprietary.audit.AuditLevel; import stirling.software.proprietary.audit.Audited; import stirling.software.proprietary.security.saml2.CertificateUtils; import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal; +import stirling.software.proprietary.security.service.JwtServiceInterface; @Slf4j @RequiredArgsConstructor @@ -40,15 +41,18 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { public static final String LOGOUT_PATH = "/login?logout=true"; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security securityProperties; private final AppConfig appConfig; + private final JwtServiceInterface jwtService; + @Override @Audited(type = AuditEventType.USER_LOGOUT, level = AuditLevel.BASIC) public void onLogoutSuccess( HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException { + if (!response.isCommitted()) { if (authentication != null) { if (authentication instanceof Saml2Authentication samlAuthentication) { @@ -67,6 +71,9 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { authentication.getClass().getSimpleName()); getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH); } + } else if (!jwtService.extractToken(request).isBlank()) { + jwtService.clearToken(response); + getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH); } else { // Redirect to login page after logout String path = checkForErrors(request); @@ -82,7 +89,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { Saml2Authentication samlAuthentication) throws IOException { - SAML2 samlConf = applicationProperties.getSecurity().getSaml2(); + SAML2 samlConf = securityProperties.getSaml2(); String registrationId = samlConf.getRegistrationId(); CustomSaml2AuthenticatedPrincipal principal = @@ -127,7 +134,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { OAuth2AuthenticationToken oAuthToken) throws IOException { String registrationId; - OAUTH2 oauth = applicationProperties.getSecurity().getOauth2(); + OAUTH2 oauth = securityProperties.getOauth2(); String path = checkForErrors(request); String redirectUrl = UrlUtils.getOrigin(request) + "/login?" + path; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java index 4b09fe0e9..e145e2754 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java @@ -43,7 +43,6 @@ public class InitialSecuritySetup { } } - userService.migrateOauth2ToSSO(); assignUsersToDefaultTeamIfMissing(); initializeInternalApiUser(); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java new file mode 100644 index 000000000..6805bcb54 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java @@ -0,0 +1,22 @@ +package stirling.software.proprietary.security; + +import java.io.IOException; + +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.stereotype.Component; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +@Component +public class JwtAuthenticationEntryPoint implements AuthenticationEntryPoint { + @Override + public void commence( + HttpServletRequest request, + HttpServletResponse response, + AuthenticationException authException) + throws IOException { + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage()); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java index 0d846fc3d..46d0e7d3d 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java @@ -77,8 +77,11 @@ public class AccountWebController { @GetMapping("/login") public String login(HttpServletRequest request, Model model, Authentication authentication) { - // If the user is already authenticated, redirect them to the home page. - if (authentication != null && authentication.isAuthenticated()) { + // If the user is already authenticated and it's not a logout scenario, redirect them to the + // home page. + if (authentication != null + && authentication.isAuthenticated() + && request.getParameter("logout") == null) { return "redirect:/"; } @@ -184,7 +187,7 @@ public class AccountWebController { errorOAuth = "login.relyingPartyRegistrationNotFound"; // Valid InResponseTo was not available from the validation context, unable to // evaluate - case "invalid_in_response_to" -> errorOAuth = "login.invalid_in_response_to"; + case "invalid_in_response_to" -> errorOAuth = "login.invalidInResponseTo"; case "not_authentication_provider_found" -> errorOAuth = "login.not_authentication_provider_found"; } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java index ab809a037..7f2afc735 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java @@ -8,11 +8,14 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.DependsOn; import org.springframework.context.annotation.Lazy; +import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.ProviderManager; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; +import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; @@ -35,10 +38,12 @@ import stirling.software.common.model.ApplicationProperties; import stirling.software.proprietary.security.CustomAuthenticationFailureHandler; import stirling.software.proprietary.security.CustomAuthenticationSuccessHandler; import stirling.software.proprietary.security.CustomLogoutSuccessHandler; +import stirling.software.proprietary.security.JwtAuthenticationEntryPoint; import stirling.software.proprietary.security.database.repository.JPATokenRepositoryImpl; import stirling.software.proprietary.security.database.repository.PersistentLoginRepository; import stirling.software.proprietary.security.filter.FirstLoginFilter; import stirling.software.proprietary.security.filter.IPRateLimitingFilter; +import stirling.software.proprietary.security.filter.JwtAuthenticationFilter; import stirling.software.proprietary.security.filter.UserAuthenticationFilter; import stirling.software.proprietary.security.model.User; import stirling.software.proprietary.security.oauth2.CustomOAuth2AuthenticationFailureHandler; @@ -48,6 +53,7 @@ import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticationSuc import stirling.software.proprietary.security.saml2.CustomSaml2ResponseAuthenticationConverter; import stirling.software.proprietary.security.service.CustomOAuth2UserService; import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; import stirling.software.proprietary.security.session.SessionPersistentRegistry; @@ -64,9 +70,11 @@ public class SecurityConfiguration { private final boolean loginEnabledValue; private final boolean runningProOrHigher; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security securityProperties; private final AppConfig appConfig; private final UserAuthenticationFilter userAuthenticationFilter; + private final JwtServiceInterface jwtService; + private final JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint; private final LoginAttemptService loginAttemptService; private final FirstLoginFilter firstLoginFilter; private final SessionPersistentRegistry sessionRegistry; @@ -82,8 +90,10 @@ public class SecurityConfiguration { @Qualifier("loginEnabled") boolean loginEnabledValue, @Qualifier("runningProOrHigher") boolean runningProOrHigher, AppConfig appConfig, - ApplicationProperties applicationProperties, + ApplicationProperties.Security securityProperties, UserAuthenticationFilter userAuthenticationFilter, + JwtServiceInterface jwtService, + JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint, LoginAttemptService loginAttemptService, FirstLoginFilter firstLoginFilter, SessionPersistentRegistry sessionRegistry, @@ -97,8 +107,10 @@ public class SecurityConfiguration { this.loginEnabledValue = loginEnabledValue; this.runningProOrHigher = runningProOrHigher; this.appConfig = appConfig; - this.applicationProperties = applicationProperties; + this.securityProperties = securityProperties; this.userAuthenticationFilter = userAuthenticationFilter; + this.jwtService = jwtService; + this.jwtAuthenticationEntryPoint = jwtAuthenticationEntryPoint; this.loginAttemptService = loginAttemptService; this.firstLoginFilter = firstLoginFilter; this.sessionRegistry = sessionRegistry; @@ -115,14 +127,28 @@ public class SecurityConfiguration { @Bean public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { - if (applicationProperties.getSecurity().getCsrfDisabled() || !loginEnabledValue) { - http.csrf(csrf -> csrf.disable()); + if (securityProperties.getCsrfDisabled() || !loginEnabledValue) { + http.csrf(CsrfConfigurer::disable); } if (loginEnabledValue) { + boolean v2Enabled = appConfig.v2Enabled(); + + if (v2Enabled) { + http.addFilterBefore( + jwtAuthenticationFilter(), + UsernamePasswordAuthenticationFilter.class) + .exceptionHandling( + exceptionHandling -> + exceptionHandling.authenticationEntryPoint( + jwtAuthenticationEntryPoint)); + } http.addFilterBefore( - userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class); - if (!applicationProperties.getSecurity().getCsrfDisabled()) { + userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class) + .addFilterAfter(rateLimitingFilter(), UserAuthenticationFilter.class) + .addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class); + + if (!securityProperties.getCsrfDisabled()) { CookieCsrfTokenRepository cookieRepo = CookieCsrfTokenRepository.withHttpOnlyFalse(); CsrfTokenRequestAttributeHandler requestHandler = @@ -156,16 +182,22 @@ public class SecurityConfiguration { .csrfTokenRepository(cookieRepo) .csrfTokenRequestHandler(requestHandler)); } - http.addFilterBefore(rateLimitingFilter(), UsernamePasswordAuthenticationFilter.class); - http.addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class); + + // Configure session management based on JWT setting http.sessionManagement( - sessionManagement -> + sessionManagement -> { + if (v2Enabled) { + sessionManagement.sessionCreationPolicy( + SessionCreationPolicy.STATELESS); + } else { sessionManagement .sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED) .maximumSessions(10) .maxSessionsPreventsLogin(false) .sessionRegistry(sessionRegistry) - .expiredUrl("/login?logout=true")); + .expiredUrl("/login?logout=true"); + } + }); http.authenticationProvider(daoAuthenticationProvider()); http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache())); http.logout( @@ -175,10 +207,10 @@ public class SecurityConfiguration { .matcher("/logout")) .logoutSuccessHandler( new CustomLogoutSuccessHandler( - applicationProperties, appConfig)) + securityProperties, appConfig, jwtService)) .clearAuthentication(true) .invalidateHttpSession(true) - .deleteCookies("JSESSIONID", "remember-me")); + .deleteCookies("JSESSIONID", "remember-me", "stirling_jwt")); http.rememberMe( rememberMeConfigurer -> // Use the configurator directly rememberMeConfigurer @@ -200,6 +232,7 @@ public class SecurityConfiguration { req -> { String uri = req.getRequestURI(); String contextPath = req.getContextPath(); + // Remove the context path from the URI String trimmedUri = uri.startsWith(contextPath) @@ -217,29 +250,35 @@ public class SecurityConfiguration { || trimmedUri.startsWith("/css/") || trimmedUri.startsWith("/fonts/") || trimmedUri.startsWith("/js/") + || trimmedUri.startsWith("/pdfjs/") + || trimmedUri.startsWith("/pdfjs-legacy/") + || trimmedUri.startsWith("/favicon") || trimmedUri.startsWith( - "/api/v1/info/status"); + "/api/v1/info/status") + || trimmedUri.startsWith("/v1/api-docs") + || uri.contains("/v1/api-docs"); }) .permitAll() .anyRequest() .authenticated()); // Handle User/Password Logins - if (applicationProperties.getSecurity().isUserPass()) { + if (securityProperties.isUserPass()) { http.formLogin( formLogin -> formLogin .loginPage("/login") .successHandler( new CustomAuthenticationSuccessHandler( - loginAttemptService, userService)) + loginAttemptService, + userService, + jwtService)) .failureHandler( new CustomAuthenticationFailureHandler( loginAttemptService, userService)) - .defaultSuccessUrl("/") .permitAll()); } // Handle OAUTH2 Logins - if (applicationProperties.getSecurity().isOauth2Active()) { + if (securityProperties.isOauth2Active()) { http.oauth2Login( oauth2 -> oauth2.loginPage("/oauth2") @@ -251,17 +290,18 @@ public class SecurityConfiguration { .successHandler( new CustomOAuth2AuthenticationSuccessHandler( loginAttemptService, - applicationProperties, - userService)) + securityProperties.getOauth2(), + userService, + jwtService)) .failureHandler( new CustomOAuth2AuthenticationFailureHandler()) - . // Add existing Authorities from the database - userInfoEndpoint( + // Add existing Authorities from the database + .userInfoEndpoint( userInfoEndpoint -> userInfoEndpoint .oidcUserService( new CustomOAuth2UserService( - applicationProperties, + securityProperties, userService, loginAttemptService)) .userAuthoritiesMapper( @@ -269,8 +309,7 @@ public class SecurityConfiguration { .permitAll()); } // Handle SAML - if (applicationProperties.getSecurity().isSaml2Active() && runningProOrHigher) { - // Configure the authentication provider + if (securityProperties.isSaml2Active() && runningProOrHigher) { OpenSaml4AuthenticationProvider authenticationProvider = new OpenSaml4AuthenticationProvider(); authenticationProvider.setResponseAuthenticationConverter( @@ -287,8 +326,9 @@ public class SecurityConfiguration { .successHandler( new CustomSaml2AuthenticationSuccessHandler( loginAttemptService, - applicationProperties, - userService)) + securityProperties.getSaml2(), + userService, + jwtService)) .failureHandler( new CustomSaml2AuthenticationFailureHandler()) .authenticationRequestResolver( @@ -306,6 +346,12 @@ public class SecurityConfiguration { return http.build(); } + @Bean + public AuthenticationManager authenticationManager(AuthenticationConfiguration configuration) + throws Exception { + return configuration.getAuthenticationManager(); + } + public DaoAuthenticationProvider daoAuthenticationProvider() { DaoAuthenticationProvider provider = new DaoAuthenticationProvider(userDetailsService); provider.setPasswordEncoder(passwordEncoder()); @@ -323,4 +369,14 @@ public class SecurityConfiguration { public PersistentTokenRepository persistentTokenRepository() { return new JPATokenRepositoryImpl(persistentLoginRepository); } + + @Bean + public JwtAuthenticationFilter jwtAuthenticationFilter() { + return new JwtAuthenticationFilter( + jwtService, + userService, + userDetailsService, + jwtAuthenticationEntryPoint, + securityProperties); + } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/database/repository/JwtSigningKeyRepository.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/database/repository/JwtSigningKeyRepository.java new file mode 100644 index 000000000..7fcb5503a --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/database/repository/JwtSigningKeyRepository.java @@ -0,0 +1,33 @@ +package stirling.software.proprietary.security.database.repository; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Optional; + +import org.springframework.data.domain.Pageable; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Modifying; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; +import org.springframework.stereotype.Repository; + +import stirling.software.proprietary.security.model.JwtSigningKey; + +@Repository +public interface JwtSigningKeyRepository extends JpaRepository { + + Optional findFirstByIsActiveTrueOrderByCreatedAtDesc(); + + Optional findByKeyId(String keyId); + + @Query("SELECT k FROM JwtSigningKey k WHERE k.createdAt < :cutoffDate ORDER BY k.createdAt ASC") + List findKeysOlderThan( + @Param("cutoffDate") LocalDateTime cutoffDate, Pageable pageable); + + @Query("SELECT COUNT(k) FROM JwtSigningKey k WHERE k.createdAt < :cutoffDate") + long countKeysEligibleForCleanup(@Param("cutoffDate") LocalDateTime cutoffDate); + + @Modifying + @Query("DELETE FROM JwtSigningKey k WHERE k.id IN :ids") + void deleteAllByIdInBatch(@Param("ids") List ids); +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java new file mode 100644 index 000000000..68bf44b3e --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java @@ -0,0 +1,171 @@ +package stirling.software.proprietary.security.filter; + +import static stirling.software.common.util.RequestUriUtils.isStaticResource; +import static stirling.software.proprietary.security.model.AuthenticationType.*; +import static stirling.software.proprietary.security.model.AuthenticationType.SAML2; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.Map; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.web.filter.OncePerRequestFilter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.model.ApplicationProperties; +import stirling.software.common.model.exception.UnsupportedProviderException; +import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import stirling.software.proprietary.security.service.UserService; + +@Slf4j +public class JwtAuthenticationFilter extends OncePerRequestFilter { + + private final JwtServiceInterface jwtService; + private final UserService userService; + private final CustomUserDetailsService userDetailsService; + private final AuthenticationEntryPoint authenticationEntryPoint; + private final ApplicationProperties.Security securityProperties; + + public JwtAuthenticationFilter( + JwtServiceInterface jwtService, + UserService userService, + CustomUserDetailsService userDetailsService, + AuthenticationEntryPoint authenticationEntryPoint, + ApplicationProperties.Security securityProperties) { + this.jwtService = jwtService; + this.userService = userService; + this.userDetailsService = userDetailsService; + this.authenticationEntryPoint = authenticationEntryPoint; + this.securityProperties = securityProperties; + } + + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + if (!jwtService.isJwtEnabled()) { + filterChain.doFilter(request, response); + return; + } + if (isStaticResource(request.getContextPath(), request.getRequestURI())) { + filterChain.doFilter(request, response); + return; + } + + String jwtToken = jwtService.extractToken(request); + // todo: X-API-KEY + if (jwtToken == null) { + // If they are unauthenticated and navigating to '/', redirect to '/login' instead of + // sending a 401 + // todo: any unauthenticated requests should redirect to login + if ("/".equals(request.getRequestURI()) + && "GET".equalsIgnoreCase(request.getMethod())) { + response.sendRedirect("/login"); + return; + } + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException("JWT is missing from the request")); + return; + } + + try { + jwtService.validateToken(jwtToken); + } catch (AuthenticationFailureException e) { + // Clear invalid tokens from response + jwtService.clearToken(response); + handleAuthenticationFailure(request, response, e); + return; + } + + Map claims = jwtService.extractClaims(jwtToken); + String tokenUsername = claims.get("sub").toString(); + + try { + Authentication authentication = createAuthentication(request, claims); + String jwt = jwtService.generateToken(authentication, claims); + + jwtService.addToken(response, jwt); + } catch (SQLException | UnsupportedProviderException e) { + log.error("Error processing user authentication for user: {}", tokenUsername, e); + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException("Error processing user authentication", e)); + return; + } + + filterChain.doFilter(request, response); + } + + private Authentication createAuthentication( + HttpServletRequest request, Map claims) + throws SQLException, UnsupportedProviderException { + String username = claims.get("sub").toString(); + + if (username != null && SecurityContextHolder.getContext().getAuthentication() == null) { + processUserAuthenticationType(claims, username); + UserDetails userDetails = userDetailsService.loadUserByUsername(username); + + if (userDetails != null) { + UsernamePasswordAuthenticationToken authToken = + new UsernamePasswordAuthenticationToken( + userDetails, null, userDetails.getAuthorities()); + + authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); + SecurityContextHolder.getContext().setAuthentication(authToken); + } else { + throw new UsernameNotFoundException("User not found: " + username); + } + } + + return SecurityContextHolder.getContext().getAuthentication(); + } + + private void processUserAuthenticationType(Map claims, String username) + throws SQLException, UnsupportedProviderException { + AuthenticationType authenticationType = + AuthenticationType.valueOf(claims.getOrDefault("authType", WEB).toString()); + log.debug("Processing {} login for {} user", authenticationType, username); + + switch (authenticationType) { + case OAUTH2 -> { + ApplicationProperties.Security.OAUTH2 oauth2Properties = + securityProperties.getOauth2(); + userService.processSSOPostLogin( + username, oauth2Properties.getAutoCreateUser(), OAUTH2); + } + case SAML2 -> { + ApplicationProperties.Security.SAML2 saml2Properties = + securityProperties.getSaml2(); + userService.processSSOPostLogin( + username, saml2Properties.getAutoCreateUser(), SAML2); + } + } + } + + private void handleAuthenticationFailure( + HttpServletRequest request, + HttpServletResponse response, + AuthenticationException authException) + throws IOException, ServletException { + authenticationEntryPoint.commence(request, response, authException); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java index e9addd239..f51a9d543 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java @@ -9,7 +9,6 @@ import org.springframework.context.annotation.Lazy; import org.springframework.http.HttpStatus; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.userdetails.UserDetails; @@ -64,6 +63,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { return; } String requestURI = request.getRequestURI(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); // Check for session expiration (unsure if needed) @@ -92,14 +92,9 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { response.getWriter().write("Invalid API Key."); return; } - List authorities = - user.get().getAuthorities().stream() - .map( - authority -> - new SimpleGrantedAuthority( - authority.getAuthority())) - .toList(); - authentication = new ApiKeyAuthenticationToken(user.get(), apiKey, authorities); + authentication = + new ApiKeyAuthenticationToken( + user.get(), apiKey, user.get().getAuthorities()); SecurityContextHolder.getContext().setAuthentication(authentication); } catch (AuthenticationException e) { // If API key authentication fails, deny the request @@ -115,20 +110,19 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { String method = request.getMethod(); String contextPath = request.getContextPath(); - if ("GET".equalsIgnoreCase(method) && !(contextPath + "/login").equals(requestURI)) { + if ("GET".equalsIgnoreCase(method) && !requestURI.startsWith(contextPath + "/login")) { response.sendRedirect(contextPath + "/login"); // redirect to the login page - return; } else { response.setStatus(HttpStatus.UNAUTHORIZED.value()); response.getWriter() .write( - "Authentication required. Please provide a X-API-KEY in request" - + " header.\n" - + "This is found in Settings -> Account Settings -> API Key\n" - + "Alternatively you can disable authentication if this is" - + " unexpected"); - return; + """ + Authentication required. Please provide a X-API-KEY in request header. + This is found in Settings -> Account Settings -> API Key + Alternatively you can disable authentication if this is unexpected. + """); } + return; } // Check if the authenticated user is disabled and invalidate their session if so @@ -226,11 +220,12 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { } @Override - protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + protected boolean shouldNotFilter(HttpServletRequest request) { String uri = request.getRequestURI(); String contextPath = request.getContextPath(); String[] permitAllPatterns = { contextPath + "/login", + contextPath + "/signup", contextPath + "/register", contextPath + "/error", contextPath + "/images/", @@ -247,6 +242,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { for (String pattern : permitAllPatterns) { if (uri.startsWith(pattern) || uri.endsWith(".svg") + || uri.endsWith(".mjs") || uri.endsWith(".png") || uri.endsWith(".ico")) { return true; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java index ca8140bca..c92c1655e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java @@ -2,5 +2,8 @@ package stirling.software.proprietary.security.model; public enum AuthenticationType { WEB, - SSO + @Deprecated(since = "1.0.2") + SSO, + OAUTH2, + SAML2 } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java index 382d3a71e..a32e7d7ca 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java @@ -2,6 +2,8 @@ package stirling.software.proprietary.security.model; import java.io.Serializable; +import org.springframework.security.core.GrantedAuthority; + import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.GeneratedValue; @@ -18,7 +20,7 @@ import lombok.Setter; @Table(name = "authorities") @Getter @Setter -public class Authority implements Serializable { +public class Authority implements GrantedAuthority, Serializable { private static final long serialVersionUID = 1L; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtSigningKey.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtSigningKey.java new file mode 100644 index 000000000..d1b78d8a9 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtSigningKey.java @@ -0,0 +1,62 @@ +package stirling.software.proprietary.security.model; + +import java.io.Serializable; +import java.time.LocalDateTime; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.Table; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.ToString; + +@Entity +@Getter +@Setter +@NoArgsConstructor +@Table(name = "signing_keys") +@ToString(onlyExplicitlyIncluded = true) +@EqualsAndHashCode(onlyExplicitlyIncluded = true) +public class JwtSigningKey implements Serializable { + + private static final long serialVersionUID = 1L; + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Column(name = "signing_key_id") + @EqualsAndHashCode.Include + @ToString.Include + private Long id; + + @Column(name = "key_id", nullable = false, unique = true) + @ToString.Include + private String keyId; + + @Column(name = "signing_key", columnDefinition = "TEXT", nullable = false) + private String signingKey; + + @Column(name = "algorithm", nullable = false) + private String algorithm = "RS256"; + + @Column(name = "created_at", nullable = false) + @ToString.Include + private LocalDateTime createdAt; + + @Column(name = "is_active", nullable = false) + @ToString.Include + private Boolean isActive = true; + + public JwtSigningKey(String keyId, String signingKey, String algorithm) { + this.keyId = keyId; + this.signingKey = signingKey; + this.algorithm = algorithm; + this.createdAt = LocalDateTime.now(); + this.isActive = true; + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java index d3e232f61..7d1b235cd 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java @@ -7,6 +7,8 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.springframework.security.core.userdetails.UserDetails; + import jakarta.persistence.*; import lombok.EqualsAndHashCode; @@ -25,7 +27,7 @@ import stirling.software.proprietary.model.Team; @Setter @EqualsAndHashCode(onlyExplicitlyIncluded = true) @ToString(onlyExplicitlyIncluded = true) -public class User implements Serializable { +public class User implements UserDetails, Serializable { private static final long serialVersionUID = 1L; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java new file mode 100644 index 000000000..f2cd5e242 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java @@ -0,0 +1,13 @@ +package stirling.software.proprietary.security.model.exception; + +import org.springframework.security.core.AuthenticationException; + +public class AuthenticationFailureException extends AuthenticationException { + public AuthenticationFailureException(String message) { + super(message); + } + + public AuthenticationFailureException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java index 71bd42a85..4e7ed9d9e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java @@ -1,7 +1,11 @@ package stirling.software.proprietary.security.oauth2; +import static stirling.software.proprietary.security.model.AuthenticationType.OAUTH2; +import static stirling.software.proprietary.security.model.AuthenticationType.SSO; + import java.io.IOException; import java.sql.SQLException; +import java.util.Map; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.Authentication; @@ -18,10 +22,10 @@ import jakarta.servlet.http.HttpSession; import lombok.RequiredArgsConstructor; import stirling.software.common.model.ApplicationProperties; -import stirling.software.common.model.ApplicationProperties.Security.OAUTH2; import stirling.software.common.model.exception.UnsupportedProviderException; import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -30,8 +34,9 @@ public class CustomOAuth2AuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { private final LoginAttemptService loginAttemptService; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security.OAUTH2 oauth2Properties; private final UserService userService; + private final JwtServiceInterface jwtService; @Override public void onAuthenticationSuccess( @@ -60,8 +65,6 @@ public class CustomOAuth2AuthenticationSuccessHandler // Redirect to the original destination super.onAuthenticationSuccess(request, response, authentication); } else { - OAUTH2 oAuth = applicationProperties.getSecurity().getOauth2(); - if (loginAttemptService.isBlocked(username)) { if (session != null) { session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST"); @@ -69,7 +72,12 @@ public class CustomOAuth2AuthenticationSuccessHandler throw new LockedException( "Your account has been locked due to too many failed login attempts."); } - + if (jwtService.isJwtEnabled()) { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.OAUTH2)); + jwtService.addToken(response, jwt); + } if (userService.isUserDisabled(username)) { getRedirectStrategy() .sendRedirect(request, response, "/logout?userIsDisabled=true"); @@ -77,20 +85,22 @@ public class CustomOAuth2AuthenticationSuccessHandler } if (userService.usernameExistsIgnoreCase(username) && userService.hasPassword(username) - && !userService.isAuthenticationTypeByUsername(username, AuthenticationType.SSO) - && oAuth.getAutoCreateUser()) { + && (!userService.isAuthenticationTypeByUsername(username, SSO) + || !userService.isAuthenticationTypeByUsername(username, OAUTH2)) + && oauth2Properties.getAutoCreateUser()) { response.sendRedirect(contextPath + "/logout?oAuth2AuthenticationErrorWeb=true"); return; } try { - if (oAuth.getBlockRegistration() + if (oauth2Properties.getBlockRegistration() && !userService.usernameExistsIgnoreCase(username)) { response.sendRedirect(contextPath + "/logout?oAuth2AdminBlockedUser=true"); return; } if (principal instanceof OAuth2User) { - userService.processSSOPostLogin(username, oAuth.getAutoCreateUser()); + userService.processSSOPostLogin( + username, oauth2Properties.getAutoCreateUser(), OAUTH2); } response.sendRedirect(contextPath + "/"); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java index 6516cc7d7..913dc458a 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java @@ -34,6 +34,7 @@ import stirling.software.common.model.oauth2.GitHubProvider; import stirling.software.common.model.oauth2.GoogleProvider; import stirling.software.common.model.oauth2.KeycloakProvider; import stirling.software.common.model.oauth2.Provider; +import stirling.software.proprietary.security.model.Authority; import stirling.software.proprietary.security.model.User; import stirling.software.proprietary.security.model.exception.NoProviderFoundException; import stirling.software.proprietary.security.service.UserService; @@ -239,12 +240,14 @@ public class OAuth2Configuration { Optional userOpt = userService.findByUsernameIgnoreCase( (String) oAuth2Auth.getAttributes().get(useAsUsername)); - if (userOpt.isPresent()) { - User user = userOpt.get(); - mappedAuthorities.add( - new SimpleGrantedAuthority( - userService.findRole(user).getAuthority())); - } + userOpt.ifPresent( + user -> + mappedAuthorities.add( + new Authority( + userService + .findRole(user) + .getAuthority(), + user))); } }); return mappedAuthorities; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java index 2170a9632..3255cbc15 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java @@ -1,7 +1,11 @@ package stirling.software.proprietary.security.saml2; +import static stirling.software.proprietary.security.model.AuthenticationType.SAML2; +import static stirling.software.proprietary.security.model.AuthenticationType.SSO; + import java.io.IOException; import java.sql.SQLException; +import java.util.Map; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.Authentication; @@ -17,10 +21,10 @@ import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import stirling.software.common.model.ApplicationProperties; -import stirling.software.common.model.ApplicationProperties.Security.SAML2; import stirling.software.common.model.exception.UnsupportedProviderException; import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -30,8 +34,9 @@ public class CustomSaml2AuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { private LoginAttemptService loginAttemptService; - private ApplicationProperties applicationProperties; + private ApplicationProperties.Security.SAML2 saml2Properties; private UserService userService; + private final JwtServiceInterface jwtService; @Override public void onAuthenticationSuccess( @@ -65,10 +70,9 @@ public class CustomSaml2AuthenticationSuccessHandler savedRequest.getRedirectUrl()); super.onAuthenticationSuccess(request, response, authentication); } else { - SAML2 saml2 = applicationProperties.getSecurity().getSaml2(); log.debug( "Processing SAML2 authentication with autoCreateUser: {}", - saml2.getAutoCreateUser()); + saml2Properties.getAutoCreateUser()); if (loginAttemptService.isBlocked(username)) { log.debug("User {} is blocked due to too many login attempts", username); @@ -82,17 +86,21 @@ public class CustomSaml2AuthenticationSuccessHandler boolean userExists = userService.usernameExistsIgnoreCase(username); boolean hasPassword = userExists && userService.hasPassword(username); boolean isSSOUser = - userExists - && userService.isAuthenticationTypeByUsername( - username, AuthenticationType.SSO); + userExists && userService.isAuthenticationTypeByUsername(username, SSO); + boolean isSAML2User = + userExists && userService.isAuthenticationTypeByUsername(username, SAML2); log.debug( - "User status - Exists: {}, Has password: {}, Is SSO user: {}", + "User status - Exists: {}, Has password: {}, Is SSO user: {}, Is SAML2 user: {}", userExists, hasPassword, - isSSOUser); + isSSOUser, + isSAML2User); - if (userExists && hasPassword && !isSSOUser && saml2.getAutoCreateUser()) { + if (userExists + && hasPassword + && (!isSSOUser || !isSAML2User) + && saml2Properties.getAutoCreateUser()) { log.debug( "User {} exists with password but is not SSO user, redirecting to logout", username); @@ -102,15 +110,18 @@ public class CustomSaml2AuthenticationSuccessHandler } try { - if (saml2.getBlockRegistration() && !userExists) { + if (!userExists || saml2Properties.getBlockRegistration()) { log.debug("Registration blocked for new user: {}", username); response.sendRedirect( contextPath + "/login?errorOAuth=oAuth2AdminBlockedUser"); return; } log.debug("Processing SSO post-login for user: {}", username); - userService.processSSOPostLogin(username, saml2.getAutoCreateUser()); + userService.processSSOPostLogin( + username, saml2Properties.getAutoCreateUser(), SAML2); log.debug("Successfully processed authentication for user: {}", username); + + generateJwt(response, authentication); response.sendRedirect(contextPath + "/"); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { log.debug( @@ -124,4 +135,13 @@ public class CustomSaml2AuthenticationSuccessHandler super.onAuthenticationSuccess(request, response, authentication); } } + + private void generateJwt(HttpServletResponse response, Authentication authentication) { + if (jwtService.isJwtEnabled()) { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.SAML2)); + jwtService.addToken(response, jwt); + } + } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java new file mode 100644 index 000000000..d0508151c --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java @@ -0,0 +1,135 @@ +package stirling.software.proprietary.security.saml2; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.service.JwtServiceInterface; + +@Slf4j +public class JwtSaml2AuthenticationRequestRepository + implements Saml2AuthenticationRequestRepository { + private final Map tokenStore; + private final JwtServiceInterface jwtService; + private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + public JwtSaml2AuthenticationRequestRepository( + Map tokenStore, + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + this.tokenStore = tokenStore; + this.jwtService = jwtService; + this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + } + + @Override + public void saveAuthenticationRequest( + Saml2PostAuthenticationRequest authRequest, + HttpServletRequest request, + HttpServletResponse response) { + if (!jwtService.isJwtEnabled()) { + log.debug("V2 is not enabled, skipping SAMLRequest token storage"); + return; + } + + if (authRequest == null) { + removeAuthenticationRequest(request, response); + return; + } + + Map claims = serializeSamlRequest(authRequest); + String token = jwtService.generateToken("", claims); + String relayState = authRequest.getRelayState(); + + tokenStore.put(relayState, token); + request.setAttribute(SAML_REQUEST_TOKEN, relayState); + response.addHeader(SAML_REQUEST_TOKEN, relayState); + + log.debug("Saved SAMLRequest token with RelayState: {}", relayState); + } + + @Override + public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + String token = extractTokenFromStore(request); + + if (token == null) { + log.debug("No SAMLResponse token found in RelayState"); + return null; + } + + Map claims = jwtService.extractClaims(token); + return deserializeSamlRequest(claims); + } + + @Override + public Saml2PostAuthenticationRequest removeAuthenticationRequest( + HttpServletRequest request, HttpServletResponse response) { + Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request); + + String relayStateId = request.getParameter("RelayState"); + if (relayStateId != null) { + tokenStore.remove(relayStateId); + log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId); + } + + return authRequest; + } + + private String extractTokenFromStore(HttpServletRequest request) { + String authnRequestId = request.getParameter("RelayState"); + + if (authnRequestId != null && !authnRequestId.isEmpty()) { + String token = tokenStore.get(authnRequestId); + + if (token != null) { + tokenStore.remove(authnRequestId); + log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId); + return token; + } else { + log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId); + } + } + + return null; + } + + private Map serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) { + Map claims = new HashMap<>(); + + claims.put("id", authRequest.getId()); + claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId()); + claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri()); + claims.put("samlRequest", authRequest.getSamlRequest()); + claims.put("relayState", authRequest.getRelayState()); + + return claims; + } + + private Saml2PostAuthenticationRequest deserializeSamlRequest(Map claims) { + String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId"); + RelyingPartyRegistration relyingPartyRegistration = + relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId); + + if (relyingPartyRegistration == null) { + return null; + } + + return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration) + .id((String) claims.get("id")) + .authenticationRequestUri((String) claims.get("authenticationRequestUri")) + .samlRequest((String) claims.get("samlRequest")) + .relayState((String) claims.get("relayState")) + .build(); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java similarity index 85% rename from app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java rename to app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java index 7fd4768b3..9d21f88a3 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java @@ -3,6 +3,7 @@ package stirling.software.proprietary.security.saml2; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import org.opensaml.saml.saml2.core.AuthnRequest; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -11,12 +12,12 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.Resource; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; -import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; import jakarta.servlet.http.HttpServletRequest; @@ -26,12 +27,13 @@ import lombok.extern.slf4j.Slf4j; import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties.Security.SAML2; +import stirling.software.proprietary.security.service.JwtServiceInterface; @Configuration @Slf4j @ConditionalOnProperty(value = "security.saml2.enabled", havingValue = "true") @RequiredArgsConstructor -public class SAML2Configuration { +public class Saml2Configuration { private final ApplicationProperties applicationProperties; @@ -58,6 +60,7 @@ public class SAML2Configuration { .assertionConsumerServiceBinding(Saml2MessageBinding.POST) .assertionConsumerServiceLocation( "{baseUrl}/login/saml2/sso/{registrationId}") + .authnRequestsSigned(true) .assertingPartyMetadata( metadata -> metadata.entityId(samlConf.getIdpIssuer()) @@ -71,15 +74,29 @@ public class SAML2Configuration { Saml2MessageBinding.POST) .singleLogoutServiceLocation( samlConf.getIdpSingleLogoutUrl()) + .singleLogoutServiceResponseLocation( + "http://localhost:8080/login") .wantAuthnRequestsSigned(true)) .build(); return new InMemoryRelyingPartyRegistrationRepository(rp); } + @Bean + @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") + public Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository( + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + return new JwtSaml2AuthenticationRequestRepository( + new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository); + } + @Bean @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver( - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository) { OpenSaml4AuthenticationRequestResolver resolver = new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository); @@ -87,10 +104,8 @@ public class SAML2Configuration { customizer -> { HttpServletRequest request = customizer.getRequest(); AuthnRequest authnRequest = customizer.getAuthnRequest(); - HttpSessionSaml2AuthenticationRequestRepository requestRepository = - new HttpSessionSaml2AuthenticationRequestRepository(); - AbstractSaml2AuthenticationRequest saml2AuthenticationRequest = - requestRepository.loadAuthenticationRequest(request); + Saml2PostAuthenticationRequest saml2AuthenticationRequest = + saml2AuthenticationRequestRepository.loadAuthenticationRequest(request); if (saml2AuthenticationRequest != null) { String sessionId = request.getSession(false).getId(); @@ -113,7 +128,6 @@ public class SAML2Configuration { log.debug("Generating new authentication request ID"); authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); } - logAuthnRequestDetails(authnRequest); logHttpRequestDetails(request); }); diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java index 0b286e894..8f9afbe3d 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java @@ -27,13 +27,13 @@ public class CustomOAuth2UserService implements OAuth2UserService new UsernameNotFoundException( "No user found with username: " + username)); + if (loginAttemptService.isBlocked(username)) { throw new LockedException( "Your account has been locked due to too many failed login attempts."); } - if (!user.hasPassword()) { + + AuthenticationType userAuthenticationType = + AuthenticationType.valueOf(user.getAuthenticationType().toUpperCase()); + if (!user.hasPassword() && userAuthenticationType == AuthenticationType.WEB) { throw new IllegalArgumentException("Password must not be null"); } - return new org.springframework.security.core.userdetails.User( - user.getUsername(), - user.getPassword(), - user.isEnabled(), - true, - true, - true, - getAuthorities(user.getAuthorities())); - } - private Collection getAuthorities(Set authorities) { - return authorities.stream() - .map(authority -> new SimpleGrantedAuthority(authority.getAuthority())) - .toList(); + return user; } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java new file mode 100644 index 000000000..32701aa59 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java @@ -0,0 +1,254 @@ +package stirling.software.proprietary.security.service; + +import java.security.KeyPair; +import java.security.PublicKey; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.ResponseCookie; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.stereotype.Service; + +import io.github.pixee.security.Newlines; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.ExpiredJwtException; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.MalformedJwtException; +import io.jsonwebtoken.UnsupportedJwtException; +import io.jsonwebtoken.security.SignatureException; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal; + +@Slf4j +@Service +public class JwtService implements JwtServiceInterface { + + private static final String JWT_COOKIE_NAME = "stirling_jwt"; + private static final String AUTHORIZATION_HEADER = "Authorization"; + private static final String BEARER_PREFIX = "Bearer "; + private static final String ISSUER = "Stirling PDF"; + private static final long EXPIRATION = 3600000; + + @Value("${stirling.security.jwt.secureCookie:true}") + private boolean secureCookie; + + private final KeystoreServiceInterface keystoreService; + private final boolean v2Enabled; + + @Autowired + public JwtService( + @Qualifier("v2Enabled") boolean v2Enabled, KeystoreServiceInterface keystoreService) { + this.v2Enabled = v2Enabled; + this.keystoreService = keystoreService; + } + + @Override + public String generateToken(Authentication authentication, Map claims) { + Object principal = authentication.getPrincipal(); + String username = ""; + + if (principal instanceof UserDetails) { + username = ((UserDetails) principal).getUsername(); + } else if (principal instanceof OAuth2User) { + username = ((OAuth2User) principal).getName(); + } else if (principal instanceof CustomSaml2AuthenticatedPrincipal) { + username = ((CustomSaml2AuthenticatedPrincipal) principal).getName(); + } + + return generateToken(username, claims); + } + + @Override + public String generateToken(String username, Map claims) { + KeyPair keyPair = keystoreService.getActiveKeyPair(); + + var builder = + Jwts.builder() + .claims(claims) + .subject(username) + .issuer(ISSUER) + .issuedAt(new Date()) + .expiration(new Date(System.currentTimeMillis() + EXPIRATION)) + .signWith(keyPair.getPrivate(), Jwts.SIG.RS256); + + String keyId = keystoreService.getActiveKeyId(); + if (keyId != null) { + builder.header().keyId(keyId); + } + + return builder.compact(); + } + + @Override + public void validateToken(String token) throws AuthenticationFailureException { + extractAllClaims(token); + + if (isTokenExpired(token)) { + throw new AuthenticationFailureException("The token has expired"); + } + } + + @Override + public String extractUsername(String token) { + return extractClaim(token, Claims::getSubject); + } + + @Override + public Map extractClaims(String token) { + Claims claims = extractAllClaims(token); + return new HashMap<>(claims); + } + + @Override + public boolean isTokenExpired(String token) { + return extractExpiration(token).before(new Date()); + } + + private Date extractExpiration(String token) { + return extractClaim(token, Claims::getExpiration); + } + + private T extractClaim(String token, Function claimsResolver) { + final Claims claims = extractAllClaims(token); + return claimsResolver.apply(claims); + } + + private Claims extractAllClaims(String token) { + try { + String keyId = extractKeyId(token); + KeyPair keyPair; + + if (keyId != null) { + log.debug("Looking up key pair for key ID: {}", keyId); + Optional specificKeyPair = + keystoreService.getKeyPairByKeyId(keyId); // todo: move to in-memory cache + + if (specificKeyPair.isPresent()) { + keyPair = specificKeyPair.get(); + log.debug("Successfully found key pair for key ID: {}", keyId); + } else { + log.warn( + "Key ID {} not found in keystore, token may have been signed with a rotated key", + keyId); + if (keystoreService.getActiveKeyId().equals(keyId)) { + log.debug("Rotating key pairs"); + keystoreService.refreshKeyPairs(); + } + + keyPair = keystoreService.getActiveKeyPair(); + } + } else { + log.debug("No key ID in token header, using active key pair"); + keyPair = keystoreService.getActiveKeyPair(); + } + + return Jwts.parser() + .verifyWith(keyPair.getPublic()) + .build() + .parseSignedClaims(token) + .getPayload(); + } catch (SignatureException e) { + log.warn("Invalid signature: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid signature", e); + } catch (MalformedJwtException e) { + log.warn("Invalid token: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid token", e); + } catch (ExpiredJwtException e) { + log.warn("The token has expired: {}", e.getMessage()); + throw new AuthenticationFailureException("The token has expired", e); + } catch (UnsupportedJwtException e) { + log.warn("The token is unsupported: {}", e.getMessage()); + throw new AuthenticationFailureException("The token is unsupported", e); + } catch (IllegalArgumentException e) { + log.warn("Claims are empty: {}", e.getMessage()); + throw new AuthenticationFailureException("Claims are empty", e); + } + } + + @Override + public String extractToken(HttpServletRequest request) { + Cookie[] cookies = request.getCookies(); + + if (cookies != null) { + for (Cookie cookie : cookies) { + if (JWT_COOKIE_NAME.equals(cookie.getName())) { + return cookie.getValue(); + } + } + } + + return null; + } + + @Override + public void addToken(HttpServletResponse response, String token) { + response.setHeader(AUTHORIZATION_HEADER, Newlines.stripAll(BEARER_PREFIX + token)); + + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, Newlines.stripAll(token)) + .httpOnly(true) + .secure(secureCookie) + .sameSite("Strict") + .maxAge(EXPIRATION / 1000) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public void clearToken(HttpServletResponse response) { + response.setHeader(AUTHORIZATION_HEADER, null); + + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, "") + .httpOnly(true) + .secure(secureCookie) + .sameSite("None") + .maxAge(0) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public boolean isJwtEnabled() { + return v2Enabled; + } + + private String extractKeyId(String token) { + try { + PublicKey signingKey = keystoreService.getActiveKeyPair().getPublic(); + + String keyId = + (String) + Jwts.parser() + .verifyWith(signingKey) + .build() + .parse(token) + .getHeader() + .get("kid"); + log.debug("Extracted key ID from token: {}", keyId); + return keyId; + } catch (Exception e) { + log.warn("Failed to extract key ID from token header: {}", e.getMessage()); + return null; + } + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java new file mode 100644 index 000000000..7cdca8209 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java @@ -0,0 +1,90 @@ +package stirling.software.proprietary.security.service; + +import java.util.Map; + +import org.springframework.security.core.Authentication; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public interface JwtServiceInterface { + + /** + * Generate a JWT token for the authenticated user + * + * @param authentication Spring Security authentication object + * @return JWT token as a string + */ + String generateToken(Authentication authentication, Map claims); + + /** + * Generate a JWT token for a specific username + * + * @param username the username for which to generate the token + * @param claims additional claims to include in the token + * @return JWT token as a string + */ + String generateToken(String username, Map claims); + + /** + * Validate a JWT token + * + * @param token the JWT token to validate + * @return true if token is valid, false otherwise + */ + void validateToken(String token); + + /** + * Extract username from JWT token + * + * @param token the JWT token + * @return username extracted from token + */ + String extractUsername(String token); + + /** + * Extract all claims from JWT token + * + * @param token the JWT token + * @return map of claims + */ + Map extractClaims(String token); + + /** + * Check if token is expired + * + * @param token the JWT token + * @return true if token is expired, false otherwise + */ + boolean isTokenExpired(String token); + + /** + * Extract JWT token from HTTP request (header or cookie) + * + * @param request HTTP servlet request + * @return JWT token if found, null otherwise + */ + String extractToken(HttpServletRequest request); + + /** + * Add JWT token to HTTP response (header and cookie) + * + * @param response HTTP servlet response + * @param token JWT token to add + */ + void addToken(HttpServletResponse response, String token); + + /** + * Clear JWT token from HTTP response (remove cookie) + * + * @param response HTTP servlet response + */ + void clearToken(HttpServletResponse response); + + /** + * Check if JWT authentication is enabled + * + * @return true if JWT is enabled, false otherwise + */ + boolean isJwtEnabled(); +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java new file mode 100644 index 000000000..7727dc1bc --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java @@ -0,0 +1,138 @@ +package stirling.software.proprietary.security.service; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.LocalDateTime; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBooleanProperty; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import jakarta.annotation.PostConstruct; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository; +import stirling.software.proprietary.security.model.JwtSigningKey; + +@Slf4j +@Service +@ConditionalOnBooleanProperty("v2") +public class KeyPairCleanupService { + + private final JwtSigningKeyRepository signingKeyRepository; + private final KeystoreService keystoreService; + private final ApplicationProperties.Security.Jwt jwtProperties; + + @Autowired + public KeyPairCleanupService( + JwtSigningKeyRepository signingKeyRepository, + KeystoreService keystoreService, + ApplicationProperties applicationProperties) { + this.signingKeyRepository = signingKeyRepository; + this.keystoreService = keystoreService; + this.jwtProperties = applicationProperties.getSecurity().getJwt(); + } + + @Transactional + @PostConstruct + @Scheduled(fixedDelay = 24, timeUnit = TimeUnit.HOURS) + public void cleanup() { + if (!jwtProperties.isEnableKeyCleanup() || !keystoreService.isKeystoreEnabled()) { + log.debug("Key cleanup is disabled"); + return; + } + + log.info("Removing keys older than {} day(s)", jwtProperties.getKeyRetentionDays()); + + try { + LocalDateTime cutoffDate = + LocalDateTime.now().minusDays(jwtProperties.getKeyRetentionDays()); + long totalKeysEligible = signingKeyRepository.countKeysEligibleForCleanup(cutoffDate); + + if (totalKeysEligible == 0) { + log.info("No keys eligible for cleanup"); + return; + } + + log.info("{} eligible keys found", totalKeysEligible); + + batchCleanup(cutoffDate); + } catch (Exception e) { + log.error("Error during scheduled key cleanup", e); + } + } + + private void batchCleanup(LocalDateTime cutoffDate) { + int batchSize = jwtProperties.getCleanupBatchSize(); + + while (true) { + Pageable pageable = PageRequest.of(0, batchSize); + List keysToCleanup = + signingKeyRepository.findKeysOlderThan(cutoffDate, pageable); + + if (keysToCleanup.isEmpty()) { + break; + } + + cleanupKeyBatch(keysToCleanup); + + if (keysToCleanup.size() < batchSize) { + break; + } + } + } + + private void cleanupKeyBatch(List keys) { + keys.forEach( + key -> { + try { + removePrivateKey(key.getKeyId()); + } catch (IOException e) { + log.warn("Failed to cleanup private key for keyId: {}", key.getKeyId(), e); + } + }); + + List keyIds = keys.stream().map(JwtSigningKey::getId).collect(Collectors.toList()); + + signingKeyRepository.deleteAllByIdInBatch(keyIds); + log.debug("Deleted {} signing keys from database", keyIds.size()); + } + + private void removePrivateKey(String keyId) throws IOException { + if (!keystoreService.isKeystoreEnabled()) { + return; + } + + Path privateKeyDirectory = Paths.get(InstallationPathConfig.getPrivateKeyPath()); + Path keyFile = privateKeyDirectory.resolve(keyId + KeystoreService.KEY_SUFFIX); + + if (Files.exists(keyFile)) { + Files.delete(keyFile); + log.debug("Deleted private key file: {}", keyFile); + } else { + log.debug("Private key file not found: {}", keyFile); + } + } + + public long getKeysEligibleForCleanup() { + if (!jwtProperties.isEnableKeyCleanup() || !keystoreService.isKeystoreEnabled()) { + return 0; + } + + LocalDateTime cutoffDate = + LocalDateTime.now().minusDays(jwtProperties.getKeyRetentionDays()); + return signingKeyRepository.countKeysEligibleForCleanup(cutoffDate); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreService.java new file mode 100644 index 000000000..3cb73aadf --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreService.java @@ -0,0 +1,230 @@ +package stirling.software.proprietary.security.service; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.X509EncodedKeySpec; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.Base64; +import java.util.Optional; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import jakarta.annotation.PostConstruct; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository; +import stirling.software.proprietary.security.model.JwtSigningKey; + +@Slf4j +@Service +public class KeystoreService implements KeystoreServiceInterface { + + public static final String KEY_SUFFIX = ".key"; + private final JwtSigningKeyRepository signingKeyRepository; + private final ApplicationProperties.Security.Jwt jwtProperties; + + private volatile KeyPair currentKeyPair; + private volatile String currentKeyId; + + @Autowired + public KeystoreService( + JwtSigningKeyRepository signingKeyRepository, + ApplicationProperties applicationProperties) { + this.signingKeyRepository = signingKeyRepository; + this.jwtProperties = applicationProperties.getSecurity().getJwt(); + } + + @PostConstruct + public void initializeKeystore() { + if (!isKeystoreEnabled()) { + return; + } + + try { + ensurePrivateKeyDirectoryExists(); + loadOrGenerateKeypair(); + } catch (Exception e) { + log.error("Failed to initialize keystore, using in-memory generation", e); + } + } + + @Override + public KeyPair getActiveKeyPair() { + if (!isKeystoreEnabled() || currentKeyPair == null) { + return generateRSAKeypair(); + } + return currentKeyPair; + } + + @Override + public Optional getKeyPairByKeyId(String keyId) { + if (!isKeystoreEnabled()) { + log.debug("Keystore is disabled, cannot lookup key by ID: {}", keyId); + return Optional.empty(); + } + + try { + log.debug("Looking up signing key in database for keyId: {}", keyId); + Optional signingKey = signingKeyRepository.findByKeyId(keyId); + if (signingKey.isEmpty()) { + log.warn("No signing key found in database for keyId: {}", keyId); + return Optional.empty(); + } + + log.debug("Found signing key in database, loading private key for keyId: {}", keyId); + PrivateKey privateKey = loadPrivateKey(keyId); + PublicKey publicKey = decodePublicKey(signingKey.get().getSigningKey()); + + log.debug("Successfully loaded key pair for keyId: {}", keyId); + return Optional.of(new KeyPair(publicKey, privateKey)); + } catch (Exception e) { + log.error("Failed to load keypair for keyId: {}", keyId, e); + return Optional.empty(); + } + } + + @Override + public String getActiveKeyId() { + return currentKeyId; + } + + @Override + public boolean isKeystoreEnabled() { + return jwtProperties.isEnableKeystore(); + } + + private void loadOrGenerateKeypair() { + Optional activeKey = + signingKeyRepository.findFirstByIsActiveTrueOrderByCreatedAtDesc(); + + if (activeKey.isPresent()) { + try { + currentKeyId = activeKey.get().getKeyId(); + PrivateKey privateKey = loadPrivateKey(currentKeyId); + PublicKey publicKey = decodePublicKey(activeKey.get().getSigningKey()); + currentKeyPair = new KeyPair(publicKey, privateKey); + + log.info("Loaded existing keypair: {}", currentKeyId); + } catch (Exception e) { + log.error("Failed to load existing keypair, generating new keypair", e); + generateAndStoreKeypair(); + } + } else { + generateAndStoreKeypair(); + } + } + + @Transactional + private void generateAndStoreKeypair() { + try { + KeyPair keyPair = generateRSAKeypair(); + String keyId = generateKeyId(); + + storePrivateKey(keyId, keyPair.getPrivate()); + + JwtSigningKey signingKey = + new JwtSigningKey(keyId, encodePublicKey(keyPair.getPublic()), "RS256"); + signingKeyRepository.save(signingKey); + currentKeyPair = keyPair; + currentKeyId = keyId; + + log.info("Generated and stored new keypair with keyId: {}", keyId); + } catch (IOException e) { + log.error("Failed to generate and store keypair", e); + throw new RuntimeException("Keypair generation failed", e); + } + } + + private KeyPair generateRSAKeypair() { + KeyPairGenerator keyPairGenerator; + + try { + keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Failed to initialize RSA key pair generator", e); + } + + return keyPairGenerator.generateKeyPair(); + } + + @Override + public KeyPair refreshKeyPairs() { + generateAndStoreKeypair(); + return currentKeyPair; + } + + private String generateKeyId() { + return "jwt-key-" + + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd-HHmmss")); + } + + private void ensurePrivateKeyDirectoryExists() throws IOException { + Path keyPath = Paths.get(InstallationPathConfig.getPrivateKeyPath()); + + if (!Files.exists(keyPath)) { + Files.createDirectories(keyPath); + } + } + + private void storePrivateKey(String keyId, PrivateKey privateKey) throws IOException { + Path keyFile = + Paths.get(InstallationPathConfig.getPrivateKeyPath()).resolve(keyId + KEY_SUFFIX); + String encodedKey = Base64.getEncoder().encodeToString(privateKey.getEncoded()); + Files.writeString(keyFile, encodedKey); + + // Set read/write to only the owner + try { + keyFile.toFile().setReadable(true, true); + keyFile.toFile().setWritable(true, true); + keyFile.toFile().setExecutable(false, false); + } catch (Exception e) { + log.warn("Failed to set permissions on private key file: {}", keyFile, e); + } + } + + private PrivateKey loadPrivateKey(String keyId) + throws IOException, NoSuchAlgorithmException, InvalidKeySpecException { + Path keyFile = + Paths.get(InstallationPathConfig.getPrivateKeyPath()).resolve(keyId + KEY_SUFFIX); + + if (!Files.exists(keyFile)) { + throw new IOException("Private key not found: " + keyFile); + } + + String encodedKey = Files.readString(keyFile); + byte[] keyBytes = Base64.getDecoder().decode(encodedKey); + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + + return keyFactory.generatePrivate(keySpec); + } + + private String encodePublicKey(PublicKey publicKey) { + return Base64.getEncoder().encodeToString(publicKey.getEncoded()); + } + + private PublicKey decodePublicKey(String encodedKey) + throws NoSuchAlgorithmException, InvalidKeySpecException { + byte[] keyBytes = Base64.getDecoder().decode(encodedKey); + X509EncodedKeySpec keySpec = new X509EncodedKeySpec(keyBytes); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + return keyFactory.generatePublic(keySpec); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreServiceInterface.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreServiceInterface.java new file mode 100644 index 000000000..dc0564980 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeystoreServiceInterface.java @@ -0,0 +1,17 @@ +package stirling.software.proprietary.security.service; + +import java.security.KeyPair; +import java.util.Optional; + +public interface KeystoreServiceInterface { + + KeyPair getActiveKeyPair(); + + Optional getKeyPairByKeyId(String keyId); + + String getActiveKeyId(); + + boolean isKeystoreEnabled(); + + KeyPair refreshKeyPairs(); +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java index 50c8027f6..6f213b25e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java @@ -15,7 +15,6 @@ import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.userdetails.UserDetails; @@ -61,19 +60,9 @@ public class UserService implements UserServiceInterface { private final ApplicationProperties.Security.OAUTH2 oAuth2; - @Transactional - public void migrateOauth2ToSSO() { - userRepository - .findByAuthenticationTypeIgnoreCase("OAUTH2") - .forEach( - user -> { - user.setAuthenticationType(AuthenticationType.SSO); - userRepository.save(user); - }); - } - // Handle OAUTH2 login and user auto creation. - public void processSSOPostLogin(String username, boolean autoCreateUser) + public void processSSOPostLogin( + String username, boolean autoCreateUser, AuthenticationType type) throws IllegalArgumentException, SQLException, UnsupportedProviderException { if (!isUsernameValid(username)) { return; @@ -83,7 +72,7 @@ public class UserService implements UserServiceInterface { return; } if (autoCreateUser) { - saveUser(username, AuthenticationType.SSO); + saveUser(username, type); } } @@ -100,10 +89,7 @@ public class UserService implements UserServiceInterface { } private Collection getAuthorities(User user) { - // Convert each Authority object into a SimpleGrantedAuthority object. - return user.getAuthorities().stream() - .map((Authority authority) -> new SimpleGrantedAuthority(authority.getAuthority())) - .toList(); + return user.getAuthorities(); } private String generateApiKey() { diff --git a/app/proprietary/src/main/resources/static/js/audit/dashboard.js b/app/proprietary/src/main/resources/static/js/audit/dashboard.js index 5cc670908..c0b93bd8e 100644 --- a/app/proprietary/src/main/resources/static/js/audit/dashboard.js +++ b/app/proprietary/src/main/resources/static/js/audit/dashboard.js @@ -230,7 +230,7 @@ function loadAuditData(targetPage, realPageSize) { document.getElementById('page-indicator').textContent = `Page ${requestedPage + 1} of ?`; } - fetch(url) + fetchWithCsrf(url) .then(response => { return response.json(); }) @@ -302,7 +302,7 @@ function loadStats(days) { showLoading('user-chart-loading'); showLoading('time-chart-loading'); - fetch(`/audit/stats?days=${days}`) + fetchWithCsrf(`/audit/stats?days=${days}`) .then(response => response.json()) .then(data => { document.getElementById('total-events').textContent = data.totalEvents; @@ -835,7 +835,7 @@ function hideLoading(id) { // Load event types from the server for filter dropdowns function loadEventTypes() { - fetch('/audit/types') + fetchWithCsrf('/audit/types') .then(response => response.json()) .then(types => { if (!types || types.length === 0) { diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java index a5abd6be5..756405c20 100644 --- a/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java @@ -7,16 +7,21 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import stirling.software.common.configuration.AppConfig; import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.service.JwtServiceInterface; import static org.mockito.Mockito.*; @ExtendWith(MockitoExtension.class) class CustomLogoutSuccessHandlerTest { - @Mock private ApplicationProperties applicationProperties; + @Mock private ApplicationProperties.Security securityProperties; + + @Mock private AppConfig appConfig; + + @Mock private JwtServiceInterface jwtService; @InjectMocks private CustomLogoutSuccessHandler customLogoutSuccessHandler; @@ -24,9 +29,12 @@ class CustomLogoutSuccessHandlerTest { void testSuccessfulLogout() throws IOException { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); - String logoutPath = "logout=true"; + String token = "token"; + String logoutPath = "/login?logout=true"; when(response.isCommitted()).thenReturn(false); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).clearToken(response); when(request.getContextPath()).thenReturn(""); when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath); @@ -35,12 +43,30 @@ class CustomLogoutSuccessHandlerTest { verify(response).sendRedirect(logoutPath); } + @Test + void testSuccessfulLogoutViaJWT() throws IOException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + String logoutPath = "/login?logout=true"; + String token = "token"; + + when(response.isCommitted()).thenReturn(false); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).clearToken(response); + when(request.getContextPath()).thenReturn(""); + when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath); + + customLogoutSuccessHandler.onLogoutSuccess(request, response, null); + + verify(response).sendRedirect(logoutPath); + verify(jwtService).clearToken(response); + } + @Test void testSuccessfulLogoutViaOAuth2() throws IOException { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken oAuth2AuthenticationToken = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -51,8 +77,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(oAuth2AuthenticationToken.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, oAuth2AuthenticationToken); @@ -67,7 +92,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -81,8 +105,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -98,7 +121,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -108,8 +130,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -124,7 +145,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -135,8 +155,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -151,7 +170,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -164,8 +182,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -180,7 +197,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -195,8 +211,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -211,7 +226,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -227,8 +241,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -243,7 +256,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -256,8 +268,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java new file mode 100644 index 000000000..08abd1965 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java @@ -0,0 +1,40 @@ +package stirling.software.proprietary.security; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.io.IOException; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; + +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class JwtAuthenticationEntryPointTest { + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private AuthenticationFailureException authException; + + @InjectMocks + private JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint; + + @Test + void testCommence() throws IOException { + String errorMessage = "Authentication failed"; + when(authException.getMessage()).thenReturn(errorMessage); + + jwtAuthenticationEntryPoint.commence(request, response, authException); + + verify(response).sendError(HttpServletResponse.SC_UNAUTHORIZED, errorMessage); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java new file mode 100644 index 000000000..c42dd7405 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java @@ -0,0 +1,221 @@ +package stirling.software.proprietary.security.filter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import stirling.software.proprietary.security.service.UserService; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class JwtAuthenticationFilterTest { + + @Mock + private JwtServiceInterface jwtService; + + @Mock + private CustomUserDetailsService userDetailsService; + + @Mock + private UserService userService; + + @Mock + private ApplicationProperties.Security securityProperties; + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + @Mock + private UserDetails userDetails; + + @Mock + private SecurityContext securityContext; + + @Mock + private AuthenticationEntryPoint authenticationEntryPoint; + + @InjectMocks + private JwtAuthenticationFilter jwtAuthenticationFilter; + + @Test + void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(false); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(jwtService, never()).extractToken(any()); + } + + @Test + void shouldNotFilterWhenPageIsLogin() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/login"); + when(request.getContextPath()).thenReturn("/login"); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void testDoFilterInternal() throws ServletException, IOException { + String token = "valid-jwt-token"; + String newToken = "new-jwt-token"; + String username = "testuser"; + Map claims = Map.of("sub", username, "authType", "WEB"); + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getContextPath()).thenReturn("/"); + when(request.getRequestURI()).thenReturn("/protected"); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).validateToken(token); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(userDetails.getAuthorities()).thenReturn(Collections.emptyList()); + when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); + + try (MockedStatic mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { + UsernamePasswordAuthenticationToken authToken = + new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities()); + + when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken); + mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); + when(jwtService.generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims))).thenReturn(newToken); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(jwtService).extractClaims(token); + verify(userDetailsService).loadUserByUsername(username); + verify(securityContext).setAuthentication(any(UsernamePasswordAuthenticationToken.class)); + verify(jwtService).generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims)); + verify(jwtService).addToken(response, newToken); + verify(filterChain).doFilter(request, response); + } + } + + @Test + void testDoFilterInternalWithMissingTokenForRootPath() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/"); + when(request.getMethod()).thenReturn("GET"); + when(jwtService.extractToken(request)).thenReturn(null); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(response).sendRedirect("/login"); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void validationFailsWithInvalidToken() throws ServletException, IOException { + String token = "invalid-jwt-token"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doThrow(new AuthenticationFailureException("Invalid token")).when(jwtService).validateToken(token); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(authenticationEntryPoint).commence(eq(request), eq(response), any(AuthenticationFailureException.class)); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void validationFailsWithExpiredToken() throws ServletException, IOException { + String token = "expired-jwt-token"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doThrow(new AuthenticationFailureException("The token has expired")).when(jwtService).validateToken(token); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(authenticationEntryPoint).commence(eq(request), eq(response), any()); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void exceptinonThrown_WhenUserNotFound() throws ServletException, IOException { + String token = "valid-jwt-token"; + String username = "nonexistentuser"; + Map claims = Map.of("sub", username, "authType", "WEB"); + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).validateToken(token); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(userDetailsService.loadUserByUsername(username)).thenReturn(null); + + try (MockedStatic mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { + when(securityContext.getAuthentication()).thenReturn(null); + mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); + + UsernameNotFoundException result = assertThrows(UsernameNotFoundException.class, () -> jwtAuthenticationFilter.doFilterInternal(request, response, filterChain)); + + assertEquals("User not found: " + username, result.getMessage()); + verify(userDetailsService).loadUserByUsername(username); + verify(filterChain, never()).doFilter(request, response); + } + } + + @Test + void testAuthenticationEntryPointCalledWithCorrectException() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(null); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(authenticationEntryPoint).commence(eq(request), eq(response), argThat(exception -> + exception.getMessage().equals("JWT is missing from the request") + )); + verify(filterChain, never()).doFilter(request, response); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java new file mode 100644 index 000000000..f08d33d8c --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java @@ -0,0 +1,228 @@ +package stirling.software.proprietary.security.saml2; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class JwtSaml2AuthenticationRequestRepositoryTest { + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + private Map tokenStore; + + @Mock + private JwtServiceInterface jwtService; + + @Mock + private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository; + + @BeforeEach + void setUp() { + tokenStore = new ConcurrentHashMap<>(); + jwtSaml2AuthenticationRequestRepository = new JwtSaml2AuthenticationRequestRepository( + tokenStore, jwtService, relyingPartyRegistrationRepository); + } + + @Test + void saveAuthenticationRequest() { + var authRequest = mock(Saml2PostAuthenticationRequest.class); + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + String token = "testToken"; + String id = "testId"; + String relayState = "testRelayState"; + String authnRequestUri = "example.com/authnRequest"; + Map claims = Map.of(); + String samlRequest = "testSamlRequest"; + String relyingPartyRegistrationId = "stirling-pdf"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(authRequest.getRelayState()).thenReturn(relayState); + when(authRequest.getId()).thenReturn(id); + when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri); + when(authRequest.getSamlRequest()).thenReturn(samlRequest); + when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId); + when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(authRequest, request, response); + + verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState); + verify(response).addHeader(SAML_REQUEST_TOKEN, relayState); + } + + @Test + void saveAuthenticationRequestWithNullRequest() { + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response); + + assertTrue(tokenStore.isEmpty()); + } + + @Test + void loadAuthenticationRequest() { + var request = mock(MockHttpServletRequest.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @ParameterizedTest + @NullAndEmptySource + void loadAuthenticationRequestWithInvalidRelayState(String relayState) { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNonExistentToken() { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNullRelyingPartyRegistration() { + var request = mock(MockHttpServletRequest.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(null); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void removeAuthenticationRequest() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @Test + void removeAuthenticationRequestWithNullRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn(null); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithNonExistentToken() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithOnlyRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + String relayState = "testRelayState"; + + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java new file mode 100644 index 000000000..4f8726335 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java @@ -0,0 +1,341 @@ +package stirling.software.proprietary.security.service; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.util.Collections; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.core.Authentication; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.User; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.contains; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class JwtServiceTest { + + @Mock + private ApplicationProperties.Security securityProperties; + + @Mock + private Authentication authentication; + + @Mock + private User userDetails; + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private KeystoreServiceInterface keystoreService; + + private JwtService jwtService; + private KeyPair testKeyPair; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException { + // Generate a test keypair + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + testKeyPair = keyPairGenerator.generateKeyPair(); + + jwtService = new JwtService(true, keystoreService); + } + + @Test + void testGenerateTokenWithAuthentication() { + String username = "testuser"; + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, Collections.emptyMap()); + + assertNotNull(token); + assertFalse(token.isEmpty()); + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testGenerateTokenWithUsernameAndClaims() { + String username = "testuser"; + Map claims = new HashMap<>(); + claims.put("role", "admin"); + claims.put("department", "IT"); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertNotNull(token); + assertFalse(token.isEmpty()); + assertEquals(username, jwtService.extractUsername(token)); + + Map extractedClaims = jwtService.extractClaims(token); + assertEquals("admin", extractedClaims.get("role")); + assertEquals("IT", extractedClaims.get("department")); + } + + @Test + void testValidateTokenSuccess() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn("testuser"); + + String token = jwtService.generateToken(authentication, new HashMap<>()); + + assertDoesNotThrow(() -> jwtService.validateToken(token)); + } + + @Test + void testValidateTokenWithInvalidToken() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + + assertThrows(AuthenticationFailureException.class, () -> { + jwtService.validateToken("invalid-token"); + }); + } + + @Test + void testValidateTokenWithMalformedToken() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + + AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { + jwtService.validateToken("malformed.token"); + }); + + assertTrue(exception.getMessage().contains("Invalid")); + } + + @Test + void testValidateTokenWithEmptyToken() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + + AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { + jwtService.validateToken(""); + }); + + assertTrue(exception.getMessage().contains("Claims are empty") || exception.getMessage().contains("Invalid")); + } + + @Test + void testExtractUsername() { + String username = "testuser"; + User user = mock(User.class); + Map claims = Map.of("sub", "testuser", "authType", "WEB"); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(user); + when(user.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testExtractUsernameWithInvalidToken() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + + assertThrows(AuthenticationFailureException.class, () -> jwtService.extractUsername("invalid-token")); + } + + @Test + void testExtractClaims() { + String username = "testuser"; + Map claims = Map.of("role", "admin", "department", "IT"); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + Map extractedClaims = jwtService.extractClaims(token); + + assertEquals("admin", extractedClaims.get("role")); + assertEquals("IT", extractedClaims.get("department")); + assertEquals(username, extractedClaims.get("sub")); + assertEquals("Stirling PDF", extractedClaims.get("iss")); + } + + @Test + void testExtractClaimsWithInvalidToken() { + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + + assertThrows(AuthenticationFailureException.class, () -> jwtService.extractClaims("invalid-token")); + } + + @Test + void testExtractTokenWithCookie() { + String token = "test-token"; + Cookie[] cookies = { new Cookie("stirling_jwt", token) }; + when(request.getCookies()).thenReturn(cookies); + + assertEquals(token, jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithNoCookies() { + when(request.getCookies()).thenReturn(null); + + assertNull(jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithWrongCookie() { + Cookie[] cookies = {new Cookie("OTHER_COOKIE", "value")}; + when(request.getCookies()).thenReturn(cookies); + + assertNull(jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithInvalidAuthorizationHeader() { + when(request.getCookies()).thenReturn(null); + + assertNull(jwtService.extractToken(request)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAddToken(boolean secureCookie) throws Exception { + String token = "test-token"; + + // Create new JwtService instance with the secureCookie parameter + JwtService testJwtService = createJwtServiceWithSecureCookie(secureCookie); + + testJwtService.addToken(response, token); + + verify(response).setHeader("Authorization", "Bearer " + token); + verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt=" + token)); + verify(response).addHeader(eq("Set-Cookie"), contains("HttpOnly")); + + if (secureCookie) { + verify(response).addHeader(eq("Set-Cookie"), contains("Secure")); + } else { + verify(response, org.mockito.Mockito.never()).addHeader(eq("Set-Cookie"), contains("Secure")); + } + } + + @Test + void testClearToken() { + jwtService.clearToken(response); + + verify(response).setHeader("Authorization", null); + verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt=")); + verify(response).addHeader(eq("Set-Cookie"), contains("Max-Age=0")); + } + + @Test + void testGenerateTokenWithKeyId() { + String username = "testuser"; + Map claims = new HashMap<>(); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertNotNull(token); + assertFalse(token.isEmpty()); + // Verify that the keystore service was called + verify(keystoreService).getActiveKeyPair(); + verify(keystoreService).getActiveKeyId(); + } + + @Test + void testTokenVerificationWithSpecificKeyId() throws NoSuchAlgorithmException { + String username = "testuser"; + Map claims = new HashMap<>(); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + // Generate token with key ID + String token = jwtService.generateToken(authentication, claims); + + // Mock extraction of key ID and verification (lenient to avoid unused stubbing) + lenient().when(keystoreService.getKeyPairByKeyId("test-key-id")).thenReturn(Optional.of(testKeyPair)); + + // Verify token can be validated + assertDoesNotThrow(() -> jwtService.validateToken(token)); + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testTokenVerificationFallsBackToActiveKeyWhenKeyIdNotFound() { + String username = "testuser"; + Map claims = new HashMap<>(); + + when(keystoreService.getActiveKeyPair()).thenReturn(testKeyPair); + when(keystoreService.getActiveKeyId()).thenReturn("test-key-id"); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + // Mock scenario where specific key ID is not found (lenient to avoid unused stubbing) + lenient().when(keystoreService.getKeyPairByKeyId("test-key-id")).thenReturn(Optional.empty()); + + // Should still work using active keypair + assertDoesNotThrow(() -> jwtService.validateToken(token)); + assertEquals(username, jwtService.extractUsername(token)); + + // Verify fallback to active keypair was used (called multiple times during token operations) + verify(keystoreService, atLeast(1)).getActiveKeyPair(); + } + + private JwtService createJwtServiceWithSecureCookie(boolean secureCookie) throws Exception { + // Use reflection to create JwtService with custom secureCookie value + JwtService testService = new JwtService(true, keystoreService); + + // Set the secureCookie field using reflection + java.lang.reflect.Field secureCookieField = JwtService.class.getDeclaredField("secureCookie"); + secureCookieField.setAccessible(true); + secureCookieField.set(testService, secureCookie); + + return testService; + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPairCleanupServiceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPairCleanupServiceTest.java new file mode 100644 index 000000000..ae07362c8 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPairCleanupServiceTest.java @@ -0,0 +1,248 @@ +package stirling.software.proprietary.security.service; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.data.domain.Pageable; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository; +import stirling.software.proprietary.security.model.JwtSigningKey; + +@ExtendWith(MockitoExtension.class) +class KeyPairCleanupServiceTest { + + @Mock + private JwtSigningKeyRepository signingKeyRepository; + + @Mock + private KeystoreService keystoreService; + + @Mock + private ApplicationProperties applicationProperties; + + @Mock + private ApplicationProperties.Security security; + + @Mock + private ApplicationProperties.Security.Jwt jwtConfig; + + @TempDir + private Path tempDir; + + private KeyPairCleanupService cleanupService; + + @BeforeEach + void setUp() { + lenient().when(applicationProperties.getSecurity()).thenReturn(security); + lenient().when(security.getJwt()).thenReturn(jwtConfig); + + lenient().when(jwtConfig.isEnableKeyCleanup()).thenReturn(true); + lenient().when(jwtConfig.getKeyRetentionDays()).thenReturn(7); + lenient().when(jwtConfig.getCleanupBatchSize()).thenReturn(100); + lenient().when(keystoreService.isKeystoreEnabled()).thenReturn(true); + + cleanupService = new KeyPairCleanupService(signingKeyRepository, keystoreService, applicationProperties); + } + + + @Test + void testCleanupDisabled_ShouldSkip() { + when(jwtConfig.isEnableKeyCleanup()).thenReturn(false); + + cleanupService.cleanup(); + + verify(signingKeyRepository, never()).countKeysEligibleForCleanup(any(LocalDateTime.class)); + verify(signingKeyRepository, never()).findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class)); + } + + @Test + void testCleanup_WhenKeystoreDisabled_ShouldSkip() { + when(keystoreService.isKeystoreEnabled()).thenReturn(false); + + cleanupService.cleanup(); + + verify(signingKeyRepository, never()).countKeysEligibleForCleanup(any(LocalDateTime.class)); + verify(signingKeyRepository, never()).findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class)); + } + + @Test + void testCleanup_WhenNoKeysEligible_ShouldExitEarly() { + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(0L); + + cleanupService.cleanup(); + + verify(signingKeyRepository).countKeysEligibleForCleanup(any(LocalDateTime.class)); + verify(signingKeyRepository, never()).findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class)); + } + + @Test + void testCleanupSuccessfully() throws IOException { + JwtSigningKey key1 = createTestKey("key-1", 1L); + JwtSigningKey key2 = createTestKey("key-2", 2L); + List keysToCleanup = Arrays.asList(key1, key2); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + + createTestKeyFile("key-1"); + createTestKeyFile("key-2"); + + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(2L); + when(signingKeyRepository.findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class))) + .thenReturn(keysToCleanup) + .thenReturn(Collections.emptyList()); + + cleanupService.cleanup(); + + verify(signingKeyRepository).countKeysEligibleForCleanup(any(LocalDateTime.class)); + verify(signingKeyRepository).findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class)); + verify(signingKeyRepository).deleteAllByIdInBatch(Arrays.asList(1L, 2L)); + + assertFalse(Files.exists(tempDir.resolve("key-1.key"))); + assertFalse(Files.exists(tempDir.resolve("key-2.key"))); + } + } + + @Test + void testCleanup_WithBatchProcessing_ShouldProcessMultipleBatches() throws IOException { + when(jwtConfig.getCleanupBatchSize()).thenReturn(2); + + JwtSigningKey key1 = createTestKey("key-1", 1L); + JwtSigningKey key2 = createTestKey("key-2", 2L); + JwtSigningKey key3 = createTestKey("key-3", 3L); + + List firstBatch = Arrays.asList(key1, key2); + List secondBatch = Arrays.asList(key3); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + + createTestKeyFile("key-1"); + createTestKeyFile("key-2"); + createTestKeyFile("key-3"); + + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(3L); + when(signingKeyRepository.findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class))) + .thenReturn(firstBatch) + .thenReturn(secondBatch) + .thenReturn(Collections.emptyList()); + + cleanupService.cleanup(); + + verify(signingKeyRepository, times(2)).deleteAllByIdInBatch(any()); + verify(signingKeyRepository).deleteAllByIdInBatch(Arrays.asList(1L, 2L)); + verify(signingKeyRepository).deleteAllByIdInBatch(Arrays.asList(3L)); + } + } + + @Test + void testCleanup() throws IOException { + JwtSigningKey key1 = createTestKey("key-1", 1L); + JwtSigningKey key2 = createTestKey("key-2", 2L); + List keysToCleanup = Arrays.asList(key1, key2); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + + createTestKeyFile("key-1"); + + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(2L); + when(signingKeyRepository.findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class))) + .thenReturn(keysToCleanup) + .thenReturn(Collections.emptyList()); + + cleanupService.cleanup(); + + verify(signingKeyRepository).deleteAllByIdInBatch(Arrays.asList(1L, 2L)); + assertFalse(Files.exists(tempDir.resolve("key-1.key"))); + } + } + + @Test + void testGetKeysEligibleForCleanup() { + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(5L); + + long result = cleanupService.getKeysEligibleForCleanup(); + + assertEquals(5L, result); + verify(signingKeyRepository).countKeysEligibleForCleanup(any(LocalDateTime.class)); + } + + @Test + void shouldReturnZero_WhenCleanupDisabled() { + when(jwtConfig.isEnableKeyCleanup()).thenReturn(false); + + long result = cleanupService.getKeysEligibleForCleanup(); + + assertEquals(0L, result); + verify(signingKeyRepository, never()).countKeysEligibleForCleanup(any(LocalDateTime.class)); + } + + @Test + void shouldReturnZero_WhenKeystoreDisabled() { + when(keystoreService.isKeystoreEnabled()).thenReturn(false); + + long result = cleanupService.getKeysEligibleForCleanup(); + + assertEquals(0L, result); + verify(signingKeyRepository, never()).countKeysEligibleForCleanup(any(LocalDateTime.class)); + } + + @Test + void testCleanup_WithRetentionDaysConfiguration_ShouldUseCorrectCutoffDate() { + when(jwtConfig.getKeyRetentionDays()).thenReturn(14); + when(signingKeyRepository.countKeysEligibleForCleanup(any(LocalDateTime.class))).thenReturn(0L); + + cleanupService.cleanup(); + + verify(signingKeyRepository).countKeysEligibleForCleanup(argThat((LocalDateTime cutoffDate) -> { + LocalDateTime expectedCutoff = LocalDateTime.now().minusDays(14); + return Math.abs(java.time.Duration.between(cutoffDate, expectedCutoff).toMinutes()) <= 1; + })); + } + + @Test + void testCleanupPrivateKeyFile_WhenKeystoreDisabled_ShouldSkipFileRemove() throws IOException { + when(keystoreService.isKeystoreEnabled()).thenReturn(false); + + cleanupService.cleanup(); + + verify(signingKeyRepository, never()).countKeysEligibleForCleanup(any(LocalDateTime.class)); + verify(signingKeyRepository, never()).findKeysOlderThan(any(LocalDateTime.class), any(Pageable.class)); + verify(signingKeyRepository, never()).deleteAllByIdInBatch(any()); + } + + private JwtSigningKey createTestKey(String keyId, Long id) { + JwtSigningKey key = new JwtSigningKey(); + key.setId(id); + key.setKeyId(keyId); + key.setSigningKey("test-public-key"); + key.setAlgorithm("RS256"); + key.setIsActive(false); + key.setCreatedAt(LocalDateTime.now().minusDays(10)); + return key; + } + + private void createTestKeyFile(String keyId) throws IOException { + Path keyFile = tempDir.resolve(keyId + ".key"); + Files.writeString(keyFile, "test-private-key-content"); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeystoreServiceInterfaceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeystoreServiceInterfaceTest.java new file mode 100644 index 000000000..2380eb00b --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeystoreServiceInterfaceTest.java @@ -0,0 +1,222 @@ +package stirling.software.proprietary.security.service; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.database.repository.JwtSigningKeyRepository; +import stirling.software.proprietary.security.model.JwtSigningKey; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KeystoreServiceInterfaceTest { + + @Mock + private JwtSigningKeyRepository repository; + + @Mock + private ApplicationProperties applicationProperties; + + @Mock + private ApplicationProperties.Security security; + + @Mock + private ApplicationProperties.Security.Jwt jwtConfig; + + @TempDir + Path tempDir; + + private KeystoreService keystoreService; + private KeyPair testKeyPair; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + testKeyPair = keyPairGenerator.generateKeyPair(); + + lenient().when(applicationProperties.getSecurity()).thenReturn(security); + lenient().when(security.getJwt()).thenReturn(jwtConfig); + lenient().when(jwtConfig.isEnableKeystore()).thenReturn(true); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testKeystoreEnabled(boolean keystoreEnabled) { + when(jwtConfig.isEnableKeystore()).thenReturn(keystoreEnabled); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + + assertEquals(keystoreEnabled, keystoreService.isKeystoreEnabled()); + } + } + + @Test + void testGetActiveKeyPairWhenKeystoreDisabled() { + when(jwtConfig.isEnableKeystore()).thenReturn(false); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + + KeyPair result = keystoreService.getActiveKeyPair(); + + assertNotNull(result); + assertNotNull(result.getPublic()); + assertNotNull(result.getPrivate()); + } + } + + @Test + void testGetActiveKeypairWhenNoActiveKeyExists() { + when(repository.findFirstByIsActiveTrueOrderByCreatedAtDesc()).thenReturn(Optional.empty()); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + keystoreService.initializeKeystore(); + + KeyPair result = keystoreService.getActiveKeyPair(); + + assertNotNull(result); + verify(repository).save(any(JwtSigningKey.class)); + } + } + + @Test + void testGetActiveKeyPairWithExistingKey() throws Exception { + String keyId = "test-key-2024-01-01-120000"; + String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); + + JwtSigningKey existingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256"); + when(repository.findFirstByIsActiveTrueOrderByCreatedAtDesc()).thenReturn(Optional.of(existingKey)); + + Path keyFile = tempDir.resolve(keyId + ".key"); + Files.writeString(keyFile, privateKeyBase64); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + keystoreService.initializeKeystore(); + + KeyPair result = keystoreService.getActiveKeyPair(); + + assertNotNull(result); + assertEquals(keyId, keystoreService.getActiveKeyId()); + } + } + + @Test + void testGetKeyPairByKeyId() throws Exception { + String keyId = "test-key-123"; + String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); + + JwtSigningKey signingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256"); + when(repository.findByKeyId(keyId)).thenReturn(Optional.of(signingKey)); + + Path keyFile = tempDir.resolve(keyId + ".key"); + Files.writeString(keyFile, privateKeyBase64); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + + Optional result = keystoreService.getKeyPairByKeyId(keyId); + + assertTrue(result.isPresent()); + assertNotNull(result.get().getPublic()); + assertNotNull(result.get().getPrivate()); + } + } + + @Test + void testGetKeyPairByKeyIdNotFound() { + String keyId = "non-existent-key"; + when(repository.findByKeyId(keyId)).thenReturn(Optional.empty()); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + + Optional result = keystoreService.getKeyPairByKeyId(keyId); + + assertFalse(result.isPresent()); + } + } + + @Test + void testGetKeyPairByKeyIdWhenKeystoreDisabled() { + when(jwtConfig.isEnableKeystore()).thenReturn(false); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + + Optional result = keystoreService.getKeyPairByKeyId("any-key"); + + assertFalse(result.isPresent()); + } + } + + @Test + void testInitializeKeystoreCreatesDirectory() throws IOException { + when(repository.findFirstByIsActiveTrueOrderByCreatedAtDesc()).thenReturn(Optional.empty()); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + keystoreService.initializeKeystore(); + + assertTrue(Files.exists(tempDir)); + assertTrue(Files.isDirectory(tempDir)); + } + } + + @Test + void testLoadExistingKeypairWithMissingPrivateKeyFile() { + String keyId = "test-key-missing-file"; + String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + + JwtSigningKey existingKey = new JwtSigningKey(keyId, publicKeyBase64, "RS256"); + when(repository.findFirstByIsActiveTrueOrderByCreatedAtDesc()).thenReturn(Optional.of(existingKey)); + + try (MockedStatic mockedStatic = mockStatic(InstallationPathConfig.class)) { + mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); + keystoreService = new KeystoreService(repository, applicationProperties); + keystoreService.initializeKeystore(); + + KeyPair result = keystoreService.getActiveKeyPair(); + assertNotNull(result); + + verify(repository).save(any(JwtSigningKey.class)); + } + } + +} diff --git a/exampleYmlFiles/test_cicd.yml b/exampleYmlFiles/test_cicd.yml index 31e24da48..086f862d5 100644 --- a/exampleYmlFiles/test_cicd.yml +++ b/exampleYmlFiles/test_cicd.yml @@ -20,6 +20,7 @@ services: environment: DISABLE_ADDITIONAL_FEATURES: "false" SECURITY_ENABLELOGIN: "true" + V2: "false" PUID: 1002 PGID: 1002 UMASK: "022"