wip - making saml auth work

This commit is contained in:
Dario Ghunney Ware 2025-02-03 15:16:49 +00:00
parent dcc2194add
commit 42fd885ac1
15 changed files with 132 additions and 137 deletions

View File

@ -6,6 +6,7 @@ import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -14,6 +15,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler; import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import com.coveo.saml.SamlClient; import com.coveo.saml.SamlClient;
import com.coveo.saml.SamlException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
@ -47,9 +49,8 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
} else if (authentication instanceof OAuth2AuthenticationToken) { } else if (authentication instanceof OAuth2AuthenticationToken) {
// Handle OAuth2 logout redirection // Handle OAuth2 logout redirection
getRedirect_oauth2(request, response, authentication); getRedirect_oauth2(request, response, authentication);
} } else if (authentication instanceof UsernamePasswordAuthenticationToken) {
// Handle Username/Password logout // Handle Username/Password logout
else if (authentication instanceof UsernamePasswordAuthenticationToken) {
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH); getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
} else { } else {
// Handle unknown authentication types // Handle unknown authentication types
@ -88,27 +89,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
certificates.add(certificate); certificates.add(certificate);
// Construct URLs required for SAML configuration // Construct URLs required for SAML configuration
String serverUrl = SamlClient samlClient = getSamlClient(registrationId, samlConf, certificates);
SPDFApplication.getStaticBaseUrl() + ":" + SPDFApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
SamlClient samlClient =
new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
// Read private key for service provider // Read private key for service provider
Resource privateKeyResource = samlConf.getPrivateKey(); Resource privateKeyResource = samlConf.getPrivateKey();
@ -125,7 +106,6 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
} }
} }
// Redirect for OAuth2 authentication logout
private void getRedirect_oauth2( private void getRedirect_oauth2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication) HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException { throws IOException {
@ -162,12 +142,45 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
response.sendRedirect(redirectUrl); response.sendRedirect(redirectUrl);
} }
default -> { default -> {
log.info("Redirecting to default logout URL: {}", redirectUrl); String logoutUrl = oauth.getLogoutUrl();
response.sendRedirect(redirectUrl);
if (StringUtils.isNotBlank(logoutUrl)) {
log.info("Redirecting to logout URL: {}", logoutUrl);
response.sendRedirect(logoutUrl);
} else {
log.info("Redirecting to default logout URL: {}", redirectUrl);
response.sendRedirect(redirectUrl);
}
} }
} }
} }
// Redirect for OAuth2 authentication logout
private static SamlClient getSamlClient(
String registrationId, SAML2 samlConf, List<X509Certificate> certificates)
throws SamlException {
String serverUrl =
SPDFApplication.getStaticBaseUrl() + ":" + SPDFApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
return new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
}
/** /**
* Handles different error scenarios during logout. Will return a <code>String</code> containing * Handles different error scenarios during logout. Will return a <code>String</code> containing
* the error request parameter. * the error request parameter.

View File

@ -369,18 +369,15 @@ public class UserService implements UserServiceInterface {
public void invalidateUserSessions(String username) { public void invalidateUserSessions(String username) {
String usernameP = ""; String usernameP = "";
for (Object principal : sessionRegistry.getAllPrincipals()) { for (Object principal : sessionRegistry.getAllPrincipals()) {
for (SessionInformation sessionsInformation : for (SessionInformation sessionsInformation :
sessionRegistry.getAllSessions(principal, false)) { sessionRegistry.getAllSessions(principal, false)) {
if (principal instanceof UserDetails) { if (principal instanceof UserDetails userDetails) {
UserDetails userDetails = (UserDetails) principal;
usernameP = userDetails.getUsername(); usernameP = userDetails.getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User oAuth2User) {
OAuth2User oAuth2User = (OAuth2User) principal;
usernameP = oAuth2User.getName(); usernameP = oAuth2User.getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) { } else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) {
CustomSaml2AuthenticatedPrincipal saml2User =
(CustomSaml2AuthenticatedPrincipal) principal;
usernameP = saml2User.getName(); usernameP = saml2User.getName();
} else if (principal instanceof String) { } else if (principal instanceof String) {
usernameP = (String) principal; usernameP = (String) principal;
@ -394,6 +391,7 @@ public class UserService implements UserServiceInterface {
public String getCurrentUsername() { public String getCurrentUsername() {
Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal(); Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal();
if (principal instanceof UserDetails) { if (principal instanceof UserDetails) {
return ((UserDetails) principal).getUsername(); return ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) { } else if (principal instanceof OAuth2User) {
@ -402,8 +400,6 @@ public class UserService implements UserServiceInterface {
applicationProperties.getSecurity().getOauth2().getUseAsUsername()); applicationProperties.getSecurity().getOauth2().getUseAsUsername());
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) { } else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
return ((CustomSaml2AuthenticatedPrincipal) principal).getName(); return ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) {
return (String) principal;
} else { } else {
return principal.toString(); return principal.toString();
} }

View File

@ -93,7 +93,7 @@ public class OAuth2Configuration {
.clientId(keycloak.getClientId()) .clientId(keycloak.getClientId())
.clientSecret(keycloak.getClientSecret()) .clientSecret(keycloak.getClientSecret())
.scope(keycloak.getScopes()) .scope(keycloak.getScopes())
.userNameAttributeName(keycloak.getUseAsUsername().name()) .userNameAttributeName(keycloak.getUseAsUsername().getName())
.clientName(keycloak.getClientName()) .clientName(keycloak.getClientName())
.build()) .build())
: Optional.empty(); : Optional.empty();
@ -124,7 +124,7 @@ public class OAuth2Configuration {
.authorizationUri(google.getAuthorizationUri()) .authorizationUri(google.getAuthorizationUri())
.tokenUri(google.getTokenUri()) .tokenUri(google.getTokenUri())
.userInfoUri(google.getUserInfoUri()) .userInfoUri(google.getUserInfoUri())
.userNameAttributeName(google.getUseAsUsername().name()) .userNameAttributeName(google.getUseAsUsername().getName())
.clientName(google.getClientName()) .clientName(google.getClientName())
.redirectUri(REDIRECT_URI_PATH + google.getName()) .redirectUri(REDIRECT_URI_PATH + google.getName())
.authorizationGrantType(AUTHORIZATION_CODE) .authorizationGrantType(AUTHORIZATION_CODE)
@ -157,7 +157,7 @@ public class OAuth2Configuration {
.authorizationUri(github.getAuthorizationUri()) .authorizationUri(github.getAuthorizationUri())
.tokenUri(github.getTokenUri()) .tokenUri(github.getTokenUri())
.userInfoUri(github.getUserInfoUri()) .userInfoUri(github.getUserInfoUri())
.userNameAttributeName(github.getUseAsUsername().name()) .userNameAttributeName(github.getUseAsUsername().getName())
.clientName(github.getClientName()) .clientName(github.getClientName())
.redirectUri(REDIRECT_URI_PATH + github.getName()) .redirectUri(REDIRECT_URI_PATH + github.getName())
.authorizationGrantType(AUTHORIZATION_CODE) .authorizationGrantType(AUTHORIZATION_CODE)
@ -184,7 +184,8 @@ public class OAuth2Configuration {
oauth.getClientId(), oauth.getClientId(),
oauth.getClientSecret(), oauth.getClientSecret(),
oauth.getScopes(), oauth.getScopes(),
UsernameAttribute.valueOf(oauth.getUseAsUsername()), UsernameAttribute.valueOf(oauth.getUseAsUsername().toUpperCase()),
oauth.getLogoutUrl(),
null, null,
null, null,
null); null);

View File

@ -21,6 +21,8 @@ public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthentica
HttpServletResponse response, HttpServletResponse response,
AuthenticationException exception) AuthenticationException exception)
throws IOException { throws IOException {
log.error("Authentication error", exception);
if (exception instanceof Saml2AuthenticationException) { if (exception instanceof Saml2AuthenticationException) {
Saml2Error error = ((Saml2AuthenticationException) exception).getSaml2Error(); Saml2Error error = ((Saml2AuthenticationException) exception).getSaml2Error();
getRedirectStrategy() getRedirectStrategy()
@ -32,6 +34,5 @@ public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthentica
response, response,
"/login?errorOAuth=not_authentication_provider_found"); "/login?errorOAuth=not_authentication_provider_found");
} }
log.error("Authentication error", exception);
} }
} }

View File

@ -110,13 +110,11 @@ public class CustomSaml2AuthenticationSuccessHandler
userService.processSSOPostLogin(username, saml2.getAutoCreateUser()); userService.processSSOPostLogin(username, saml2.getAutoCreateUser());
log.debug("Successfully processed authentication for user: {}", username); log.debug("Successfully processed authentication for user: {}", username);
response.sendRedirect(contextPath + "/"); response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) {
log.debug( log.debug(
"Invalid username detected for user: {}, redirecting to logout", "Invalid username detected for user: {}, redirecting to logout",
username); username);
response.sendRedirect(contextPath + "/logout?invalidUsername=true"); response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
} }
} }
} else { } else {

View File

@ -20,7 +20,7 @@ import stirling.software.SPDF.model.User;
public class CustomSaml2ResponseAuthenticationConverter public class CustomSaml2ResponseAuthenticationConverter
implements Converter<ResponseToken, Saml2Authentication> { implements Converter<ResponseToken, Saml2Authentication> {
private UserService userService; private final UserService userService;
public CustomSaml2ResponseAuthenticationConverter(UserService userService) { public CustomSaml2ResponseAuthenticationConverter(UserService userService) {
this.userService = userService; this.userService = userService;
@ -60,10 +60,10 @@ public class CustomSaml2ResponseAuthenticationConverter
Map<String, List<Object>> attributes = extractAttributes(assertion); Map<String, List<Object>> attributes = extractAttributes(assertion);
// Debug log with actual values // Debug log with actual values
log.debug("Extracted SAML Attributes: " + attributes); log.debug("Extracted SAML Attributes: {}", attributes);
// Try to get username/identifier in order of preference // Try to get username/identifier in order of preference
String userIdentifier = null; String userIdentifier;
if (hasAttribute(attributes, "username")) { if (hasAttribute(attributes, "username")) {
userIdentifier = getFirstAttributeValue(attributes, "username"); userIdentifier = getFirstAttributeValue(attributes, "username");
} else if (hasAttribute(attributes, "emailaddress")) { } else if (hasAttribute(attributes, "emailaddress")) {
@ -83,10 +83,8 @@ public class CustomSaml2ResponseAuthenticationConverter
SimpleGrantedAuthority simpleGrantedAuthority = new SimpleGrantedAuthority("ROLE_USER"); SimpleGrantedAuthority simpleGrantedAuthority = new SimpleGrantedAuthority("ROLE_USER");
if (userOpt.isPresent()) { if (userOpt.isPresent()) {
User user = userOpt.get(); User user = userOpt.get();
if (user != null) { simpleGrantedAuthority =
simpleGrantedAuthority = new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
}
} }
List<String> sessionIndexes = new ArrayList<>(); List<String> sessionIndexes = new ArrayList<>();

View File

@ -2,8 +2,8 @@ package stirling.software.SPDF.config.security.saml2;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collections; import java.util.Collections;
import java.util.UUID; import java.util.UUID;
import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -49,15 +49,23 @@ public class SAML2Configuration {
RelyingPartyRegistration rp = RelyingPartyRegistration rp =
RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId()) RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId())
.signingX509Credentials(c -> c.add(signingCredential)) .signingX509Credentials(c -> c.add(signingCredential))
.entityId(samlConf.getIdpIssuer())
.singleLogoutServiceBinding(Saml2MessageBinding.POST)
.singleLogoutServiceLocation(samlConf.getIdpSingleLogoutUrl())
.authnRequestsSigned(true)
.assertingPartyMetadata( .assertingPartyMetadata(
metadata -> metadata ->
metadata.entityId(samlConf.getIdpIssuer()) metadata.entityId(samlConf.getIdpIssuer())
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.verificationX509Credentials( .verificationX509Credentials(
c -> c.add(verificationCredential)) c -> c.add(verificationCredential))
.singleSignOnServiceBinding( .singleSignOnServiceBinding(
Saml2MessageBinding.POST) Saml2MessageBinding.POST)
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.singleLogoutServiceBinding(
Saml2MessageBinding.POST)
.singleLogoutServiceLocation(
samlConf.getIdpSingleLogoutUrl())
.wantAuthnRequestsSigned(true)) .wantAuthnRequestsSigned(true))
.build(); .build();
return new InMemoryRelyingPartyRegistrationRepository(rp); return new InMemoryRelyingPartyRegistrationRepository(rp);
@ -73,9 +81,9 @@ public class SAML2Configuration {
customizer -> { customizer -> {
log.debug("Customizing SAML Authentication request"); log.debug("Customizing SAML Authentication request");
AuthnRequest authnRequest = customizer.getAuthnRequest(); AuthnRequest authnRequest = customizer.getAuthnRequest();
log.debug("AuthnRequest ID: {}", authnRequest.getID()); log.debug("AuthnRequest ID: {}", authnRequest.getID());
if (authnRequest.getID() == null) { if (authnRequest.getID() == null) {
authnRequest.setID("ARQ" + UUID.randomUUID()); // fixme: SubjectConfirmationData@InResponseTo authnRequest.setID("ARQ" + UUID.randomUUID());
} }
log.debug("AuthnRequest new ID after set: {}", authnRequest.getID()); log.debug("AuthnRequest new ID after set: {}", authnRequest.getID());
log.debug("AuthnRequest IssueInstant: {}", authnRequest.getIssueInstant()); log.debug("AuthnRequest IssueInstant: {}", authnRequest.getIssueInstant());
@ -94,12 +102,11 @@ public class SAML2Configuration {
// Log headers // Log headers
Collections.list(request.getHeaderNames()) Collections.list(request.getHeaderNames())
.forEach( .forEach(
headerName -> { headerName ->
log.debug( log.debug(
"Header - {}: {}", "Header - {}: {}",
headerName, headerName,
request.getHeader(headerName)); request.getHeader(headerName)));
});
// Log SAML specific parameters // Log SAML specific parameters
log.debug("SAML Request Parameters:"); log.debug("SAML Request Parameters:");
log.debug("SAMLRequest: {}", request.getParameter("SAMLRequest")); log.debug("SAMLRequest: {}", request.getParameter("SAMLRequest"));

View File

@ -124,7 +124,7 @@ public class UserController {
return new RedirectView("/change-creds?messageType=notAuthenticated", true); return new RedirectView("/change-creds?messageType=notAuthenticated", true);
} }
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName()); Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) { if (userOpt.isEmpty()) {
return new RedirectView("/change-creds?messageType=userNotFound", true); return new RedirectView("/change-creds?messageType=userNotFound", true);
} }
User user = userOpt.get(); User user = userOpt.get();
@ -152,7 +152,7 @@ public class UserController {
return new RedirectView("/account?messageType=notAuthenticated", true); return new RedirectView("/account?messageType=notAuthenticated", true);
} }
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName()); Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) { if (userOpt.isEmpty()) {
return new RedirectView("/account?messageType=userNotFound", true); return new RedirectView("/account?messageType=userNotFound", true);
} }
User user = userOpt.get(); User user = userOpt.get();
@ -174,7 +174,7 @@ public class UserController {
for (Map.Entry<String, String[]> entry : paramMap.entrySet()) { for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
updates.put(entry.getKey(), entry.getValue()[0]); updates.put(entry.getKey(), entry.getValue()[0]);
} }
log.debug("Processed updates: " + updates); log.debug("Processed updates: {}", updates);
// Assuming you have a method in userService to update the settings for a user // Assuming you have a method in userService to update the settings for a user
userService.updateUserSettings(principal.getName(), updates); userService.updateUserSettings(principal.getName(), updates);
// Redirect to a page of your choice after updating // Redirect to a page of your choice after updating
@ -197,7 +197,7 @@ public class UserController {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username); Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (userOpt.isPresent()) { if (userOpt.isPresent()) {
User user = userOpt.get(); User user = userOpt.get();
if (user != null && user.getUsername().equalsIgnoreCase(username)) { if (user.getUsername().equalsIgnoreCase(username)) {
return new RedirectView("/addUsers?messageType=usernameExists", true); return new RedirectView("/addUsers?messageType=usernameExists", true);
} }
} }
@ -274,7 +274,7 @@ public class UserController {
Authentication authentication) Authentication authentication)
throws SQLException, UnsupportedProviderException { throws SQLException, UnsupportedProviderException {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username); Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (!userOpt.isPresent()) { if (userOpt.isEmpty()) {
return new RedirectView("/addUsers?messageType=userNotFound", true); return new RedirectView("/addUsers?messageType=userNotFound", true);
} }
if (!userService.usernameExistsIgnoreCase(username)) { if (!userService.usernameExistsIgnoreCase(username)) {
@ -293,7 +293,7 @@ public class UserController {
List<Object> principals = sessionRegistry.getAllPrincipals(); List<Object> principals = sessionRegistry.getAllPrincipals();
String userNameP = ""; String userNameP = "";
for (Object principal : principals) { for (Object principal : principals) {
List<SessionInformation> sessionsInformations = List<SessionInformation> sessionsInformation =
sessionRegistry.getAllSessions(principal, false); sessionRegistry.getAllSessions(principal, false);
if (principal instanceof UserDetails) { if (principal instanceof UserDetails) {
userNameP = ((UserDetails) principal).getUsername(); userNameP = ((UserDetails) principal).getUsername();
@ -305,8 +305,8 @@ public class UserController {
userNameP = (String) principal; userNameP = (String) principal;
} }
if (userNameP.equalsIgnoreCase(username)) { if (userNameP.equalsIgnoreCase(username)) {
for (SessionInformation sessionsInformation : sessionsInformations) { for (SessionInformation sessionInfo : sessionsInformation) {
sessionRegistry.expireSession(sessionsInformation.getSessionId()); sessionRegistry.expireSession(sessionInfo.getSessionId());
} }
} }
} }

View File

@ -111,7 +111,6 @@ public class AccountWebController {
} }
// Remove any null keys/values from the providerList // Remove any null keys/values from the providerList
// providerList might be empty on browser side? Button not showing up
providerList providerList
.entrySet() .entrySet()
.removeIf(entry -> entry.getKey() == null || entry.getValue() == null); .removeIf(entry -> entry.getKey() == null || entry.getValue() == null);
@ -216,13 +215,11 @@ public class AccountWebController {
.plus(maxInactiveInterval, ChronoUnit.SECONDS); .plus(maxInactiveInterval, ChronoUnit.SECONDS);
if (now.isAfter(expirationTime)) { if (now.isAfter(expirationTime)) {
sessionPersistentRegistry.expireSession(sessionEntity.getSessionId()); sessionPersistentRegistry.expireSession(sessionEntity.getSessionId());
hasActiveSession = false;
} else { } else {
hasActiveSession = !sessionEntity.isExpired(); hasActiveSession = !sessionEntity.isExpired();
} }
lastRequest = sessionEntity.getLastRequest(); lastRequest = sessionEntity.getLastRequest();
} else { } else {
hasActiveSession = false;
// No session, set default last request time // No session, set default last request time
lastRequest = new Date(0); lastRequest = new Date(0);
} }
@ -259,53 +256,41 @@ public class AccountWebController {
}) })
.collect(Collectors.toList()); .collect(Collectors.toList());
String messageType = request.getParameter("messageType"); String messageType = request.getParameter("messageType");
String deleteMessage = null;
String deleteMessage;
if (messageType != null) { if (messageType != null) {
switch (messageType) { deleteMessage =
case "deleteCurrentUser": switch (messageType) {
deleteMessage = "deleteCurrentUserMessage"; case "deleteCurrentUser" -> "deleteCurrentUserMessage";
break; case "deleteUsernameExists" -> "deleteUsernameExistsMessage";
case "deleteUsernameExists": default -> null;
deleteMessage = "deleteUsernameExistsMessage"; };
break;
default:
break;
}
model.addAttribute("deleteMessage", deleteMessage); model.addAttribute("deleteMessage", deleteMessage);
String addMessage = null;
switch (messageType) { String addMessage;
case "usernameExists": addMessage =
addMessage = "usernameExistsMessage"; switch (messageType) {
break; case "usernameExists" -> "usernameExistsMessage";
case "invalidUsername": case "invalidUsername" -> "invalidUsernameMessage";
addMessage = "invalidUsernameMessage"; case "invalidPassword" -> "invalidPasswordMessage";
break; default -> null;
case "invalidPassword": };
addMessage = "invalidPasswordMessage";
break;
default:
break;
}
model.addAttribute("addMessage", addMessage); model.addAttribute("addMessage", addMessage);
} }
String changeMessage = null;
String changeMessage;
if (messageType != null) { if (messageType != null) {
switch (messageType) { changeMessage =
case "userNotFound": switch (messageType) {
changeMessage = "userNotFoundMessage"; case "userNotFound" -> "userNotFoundMessage";
break; case "downgradeCurrentUser" -> "downgradeCurrentUserMessage";
case "downgradeCurrentUser": case "disabledCurrentUser" -> "disabledCurrentUserMessage";
changeMessage = "downgradeCurrentUserMessage"; default -> messageType;
break; };
case "disabledCurrentUser":
changeMessage = "disabledCurrentUserMessage";
break;
default:
changeMessage = messageType;
break;
}
model.addAttribute("changeMessage", changeMessage); model.addAttribute("changeMessage", changeMessage);
} }
model.addAttribute("users", sortedUsers); model.addAttribute("users", sortedUsers);
model.addAttribute("currentUsername", authentication.getName()); model.addAttribute("currentUsername", authentication.getName());
model.addAttribute("roleDetails", roleDetails); model.addAttribute("roleDetails", roleDetails);
@ -326,39 +311,35 @@ public class AccountWebController {
if (authentication != null && authentication.isAuthenticated()) { if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal(); Object principal = authentication.getPrincipal();
String username = null; String username = null;
// Retrieve username and other attributes and add login attributes to the model
if (principal instanceof UserDetails userDetails) { if (principal instanceof UserDetails userDetails) {
// Retrieve username and other attributes
username = userDetails.getUsername(); username = userDetails.getUsername();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", false); model.addAttribute("oAuth2Login", false);
} }
if (principal instanceof OAuth2User userDetails) { if (principal instanceof OAuth2User userDetails) {
// Retrieve username and other attributes
username = userDetails.getName(); username = userDetails.getName();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true); model.addAttribute("oAuth2Login", true);
} }
if (principal instanceof CustomSaml2AuthenticatedPrincipal userDetails) { if (principal instanceof CustomSaml2AuthenticatedPrincipal userDetails) {
// Retrieve username and other attributes
username = userDetails.getName(); username = userDetails.getName();
// Add oAuth2 Login attributes to the model model.addAttribute("saml2Login", true);
model.addAttribute("oAuth2Login", true);
} }
if (username != null) { if (username != null) {
// Fetch user details from the database, assuming findByUsername method exists // Fetch user details from the database
Optional<User> user = userRepository.findByUsernameIgnoreCaseWithSettings(username); Optional<User> user = userRepository.findByUsernameIgnoreCaseWithSettings(username);
if (user.isEmpty()) { if (user.isEmpty()) {
return "redirect:/error"; return "redirect:/error";
} }
// Convert settings map to JSON string // Convert settings map to JSON string
ObjectMapper objectMapper = new ObjectMapper(); ObjectMapper objectMapper = new ObjectMapper();
String settingsJson; String settingsJson;
try { try {
settingsJson = objectMapper.writeValueAsString(user.get().getSettings()); settingsJson = objectMapper.writeValueAsString(user.get().getSettings());
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
// Handle JSON conversion error log.error("Error converting settings map", e);
log.error("exception", e);
return "redirect:/error"; return "redirect:/error";
} }
@ -372,7 +353,7 @@ public class AccountWebController {
case "invalidUsername" -> messageType = "invalidUsernameMessage"; case "invalidUsername" -> messageType = "invalidUsernameMessage";
} }
} }
// Add attributes to the model
model.addAttribute("username", username); model.addAttribute("username", username);
model.addAttribute("messageType", messageType); model.addAttribute("messageType", messageType);
model.addAttribute("role", user.get().getRolesAsString()); model.addAttribute("role", user.get().getRolesAsString());
@ -395,19 +376,12 @@ public class AccountWebController {
} }
if (authentication != null && authentication.isAuthenticated()) { if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal(); Object principal = authentication.getPrincipal();
if (principal instanceof UserDetails) { if (principal instanceof UserDetails userDetails) {
// Cast the principal object to UserDetails
UserDetails userDetails = (UserDetails) principal;
// Retrieve username and other attributes
String username = userDetails.getUsername(); String username = userDetails.getUsername();
// Fetch user details from the database // Fetch user details from the database
Optional<User> user = Optional<User> user = userRepository.findByUsernameIgnoreCase(username);
userRepository if (user.isEmpty()) {
.findByUsernameIgnoreCase( // Assuming findByUsername method exists // Handle error appropriately, example redirection in case of error
username);
if (!user.isPresent()) {
// Handle error appropriately
// Example redirection in case of error
return "redirect:/error"; return "redirect:/error";
} }
String messageType = request.getParameter("messageType"); String messageType = request.getParameter("messageType");
@ -430,7 +404,7 @@ public class AccountWebController {
} }
model.addAttribute("messageType", messageType); model.addAttribute("messageType", messageType);
} }
// Add attributes to the model
model.addAttribute("username", username); model.addAttribute("username", username);
} }
} else { } else {

View File

@ -227,6 +227,7 @@ public class ApplicationProperties {
private Collection<String> scopes = new ArrayList<>(); private Collection<String> scopes = new ArrayList<>();
private String provider; private String provider;
private Client client = new Client(); private Client client = new Client();
private String logoutUrl;
public void setScopes(String scopes) { public void setScopes(String scopes) {
List<String> scopesList = List<String> scopesList =

View File

@ -28,6 +28,7 @@ public class GitHubProvider extends Provider {
clientSecret, clientSecret,
scopes, scopes,
useAsUsername != null ? useAsUsername : UsernameAttribute.LOGIN, useAsUsername != null ? useAsUsername : UsernameAttribute.LOGIN,
null,
AUTHORIZATION_URI, AUTHORIZATION_URI,
TOKEN_URI, TOKEN_URI,
USER_INFO_URI); USER_INFO_URI);

View File

@ -29,6 +29,7 @@ public class GoogleProvider extends Provider {
clientSecret, clientSecret,
scopes, scopes,
useAsUsername, useAsUsername,
null,
AUTHORIZATION_URI, AUTHORIZATION_URI,
TOKEN_URI, TOKEN_URI,
USER_INFO_URI); USER_INFO_URI);

View File

@ -28,6 +28,7 @@ public class KeycloakProvider extends Provider {
useAsUsername, useAsUsername,
null, null,
null, null,
null,
null); null);
} }

View File

@ -25,6 +25,7 @@ public class Provider {
private String clientSecret; private String clientSecret;
private Collection<String> scopes; private Collection<String> scopes;
private UsernameAttribute useAsUsername; private UsernameAttribute useAsUsername;
private String logoutUrl;
private String authorizationUri; private String authorizationUri;
private String tokenUri; private String tokenUri;
private String userInfoUri; private String userInfoUri;
@ -37,6 +38,7 @@ public class Provider {
String clientSecret, String clientSecret,
Collection<String> scopes, Collection<String> scopes,
UsernameAttribute useAsUsername, UsernameAttribute useAsUsername,
String logoutUrl,
String authorizationUri, String authorizationUri,
String tokenUri, String tokenUri,
String userInfoUri) { String userInfoUri) {
@ -48,6 +50,7 @@ public class Provider {
this.scopes = scopes == null ? new ArrayList<>() : scopes; this.scopes = scopes == null ? new ArrayList<>() : scopes;
this.useAsUsername = this.useAsUsername =
useAsUsername != null ? validateUsernameAttribute(useAsUsername) : EMAIL; useAsUsername != null ? validateUsernameAttribute(useAsUsername) : EMAIL;
this.logoutUrl = logoutUrl;
this.authorizationUri = authorizationUri; this.authorizationUri = authorizationUri;
this.tokenUri = tokenUri; this.tokenUri = tokenUri;
this.userInfoUri = userInfoUri; this.userInfoUri = userInfoUri;

View File

@ -34,7 +34,7 @@
</th:block> </th:block>
<!-- Change Username Form --> <!-- Change Username Form -->
<th:block th:if="${!oAuth2Login}"> <th:block th:if="not ${oAuth2Login} or not ${saml2Login}">
<h4 th:text="#{account.changeUsername}">Change Username?</h4> <h4 th:text="#{account.changeUsername}">Change Username?</h4>
<form id="formsavechangeusername" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-username'}" method="post"> <form id="formsavechangeusername" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-username'}" method="post">
<div class="mb-3"> <div class="mb-3">
@ -53,7 +53,7 @@
</th:block> </th:block>
<!-- Change Password Form --> <!-- Change Password Form -->
<th:block th:if="${!oAuth2Login}"> <th:block th:if="not ${oAuth2Login} or not ${saml2Login}">
<h4 th:text="#{account.changePassword}">Change Password?</h4> <h4 th:text="#{account.changePassword}">Change Password?</h4>
<form id="formsavechangepassword" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-password'}" method="post"> <form id="formsavechangepassword" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-password'}" method="post">
<div class="mb-3"> <div class="mb-3">