wip - making saml auth work

This commit is contained in:
Dario Ghunney Ware 2025-02-03 15:16:49 +00:00
parent dcc2194add
commit 42fd885ac1
15 changed files with 132 additions and 137 deletions

View File

@ -6,6 +6,7 @@ import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
@ -14,6 +15,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import com.coveo.saml.SamlClient;
import com.coveo.saml.SamlException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
@ -47,9 +49,8 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
} else if (authentication instanceof OAuth2AuthenticationToken) {
// Handle OAuth2 logout redirection
getRedirect_oauth2(request, response, authentication);
}
// Handle Username/Password logout
else if (authentication instanceof UsernamePasswordAuthenticationToken) {
} else if (authentication instanceof UsernamePasswordAuthenticationToken) {
// Handle Username/Password logout
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
} else {
// Handle unknown authentication types
@ -88,27 +89,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
certificates.add(certificate);
// Construct URLs required for SAML configuration
String serverUrl =
SPDFApplication.getStaticBaseUrl() + ":" + SPDFApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
SamlClient samlClient =
new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
SamlClient samlClient = getSamlClient(registrationId, samlConf, certificates);
// Read private key for service provider
Resource privateKeyResource = samlConf.getPrivateKey();
@ -125,7 +106,6 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
}
}
// Redirect for OAuth2 authentication logout
private void getRedirect_oauth2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
@ -162,12 +142,45 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
response.sendRedirect(redirectUrl);
}
default -> {
log.info("Redirecting to default logout URL: {}", redirectUrl);
response.sendRedirect(redirectUrl);
String logoutUrl = oauth.getLogoutUrl();
if (StringUtils.isNotBlank(logoutUrl)) {
log.info("Redirecting to logout URL: {}", logoutUrl);
response.sendRedirect(logoutUrl);
} else {
log.info("Redirecting to default logout URL: {}", redirectUrl);
response.sendRedirect(redirectUrl);
}
}
}
}
// Redirect for OAuth2 authentication logout
private static SamlClient getSamlClient(
String registrationId, SAML2 samlConf, List<X509Certificate> certificates)
throws SamlException {
String serverUrl =
SPDFApplication.getStaticBaseUrl() + ":" + SPDFApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
return new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
}
/**
* Handles different error scenarios during logout. Will return a <code>String</code> containing
* the error request parameter.

View File

@ -369,18 +369,15 @@ public class UserService implements UserServiceInterface {
public void invalidateUserSessions(String username) {
String usernameP = "";
for (Object principal : sessionRegistry.getAllPrincipals()) {
for (SessionInformation sessionsInformation :
sessionRegistry.getAllSessions(principal, false)) {
if (principal instanceof UserDetails) {
UserDetails userDetails = (UserDetails) principal;
if (principal instanceof UserDetails userDetails) {
usernameP = userDetails.getUsername();
} else if (principal instanceof OAuth2User) {
OAuth2User oAuth2User = (OAuth2User) principal;
} else if (principal instanceof OAuth2User oAuth2User) {
usernameP = oAuth2User.getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
CustomSaml2AuthenticatedPrincipal saml2User =
(CustomSaml2AuthenticatedPrincipal) principal;
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) {
usernameP = saml2User.getName();
} else if (principal instanceof String) {
usernameP = (String) principal;
@ -394,6 +391,7 @@ public class UserService implements UserServiceInterface {
public String getCurrentUsername() {
Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal();
if (principal instanceof UserDetails) {
return ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
@ -402,8 +400,6 @@ public class UserService implements UserServiceInterface {
applicationProperties.getSecurity().getOauth2().getUseAsUsername());
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
return ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) {
return (String) principal;
} else {
return principal.toString();
}

View File

@ -93,7 +93,7 @@ public class OAuth2Configuration {
.clientId(keycloak.getClientId())
.clientSecret(keycloak.getClientSecret())
.scope(keycloak.getScopes())
.userNameAttributeName(keycloak.getUseAsUsername().name())
.userNameAttributeName(keycloak.getUseAsUsername().getName())
.clientName(keycloak.getClientName())
.build())
: Optional.empty();
@ -124,7 +124,7 @@ public class OAuth2Configuration {
.authorizationUri(google.getAuthorizationUri())
.tokenUri(google.getTokenUri())
.userInfoUri(google.getUserInfoUri())
.userNameAttributeName(google.getUseAsUsername().name())
.userNameAttributeName(google.getUseAsUsername().getName())
.clientName(google.getClientName())
.redirectUri(REDIRECT_URI_PATH + google.getName())
.authorizationGrantType(AUTHORIZATION_CODE)
@ -157,7 +157,7 @@ public class OAuth2Configuration {
.authorizationUri(github.getAuthorizationUri())
.tokenUri(github.getTokenUri())
.userInfoUri(github.getUserInfoUri())
.userNameAttributeName(github.getUseAsUsername().name())
.userNameAttributeName(github.getUseAsUsername().getName())
.clientName(github.getClientName())
.redirectUri(REDIRECT_URI_PATH + github.getName())
.authorizationGrantType(AUTHORIZATION_CODE)
@ -184,7 +184,8 @@ public class OAuth2Configuration {
oauth.getClientId(),
oauth.getClientSecret(),
oauth.getScopes(),
UsernameAttribute.valueOf(oauth.getUseAsUsername()),
UsernameAttribute.valueOf(oauth.getUseAsUsername().toUpperCase()),
oauth.getLogoutUrl(),
null,
null,
null);

View File

@ -21,6 +21,8 @@ public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthentica
HttpServletResponse response,
AuthenticationException exception)
throws IOException {
log.error("Authentication error", exception);
if (exception instanceof Saml2AuthenticationException) {
Saml2Error error = ((Saml2AuthenticationException) exception).getSaml2Error();
getRedirectStrategy()
@ -32,6 +34,5 @@ public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthentica
response,
"/login?errorOAuth=not_authentication_provider_found");
}
log.error("Authentication error", exception);
}
}

View File

@ -110,13 +110,11 @@ public class CustomSaml2AuthenticationSuccessHandler
userService.processSSOPostLogin(username, saml2.getAutoCreateUser());
log.debug("Successfully processed authentication for user: {}", username);
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) {
log.debug(
"Invalid username detected for user: {}, redirecting to logout",
username);
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
} else {

View File

@ -20,7 +20,7 @@ import stirling.software.SPDF.model.User;
public class CustomSaml2ResponseAuthenticationConverter
implements Converter<ResponseToken, Saml2Authentication> {
private UserService userService;
private final UserService userService;
public CustomSaml2ResponseAuthenticationConverter(UserService userService) {
this.userService = userService;
@ -60,10 +60,10 @@ public class CustomSaml2ResponseAuthenticationConverter
Map<String, List<Object>> attributes = extractAttributes(assertion);
// Debug log with actual values
log.debug("Extracted SAML Attributes: " + attributes);
log.debug("Extracted SAML Attributes: {}", attributes);
// Try to get username/identifier in order of preference
String userIdentifier = null;
String userIdentifier;
if (hasAttribute(attributes, "username")) {
userIdentifier = getFirstAttributeValue(attributes, "username");
} else if (hasAttribute(attributes, "emailaddress")) {
@ -83,10 +83,8 @@ public class CustomSaml2ResponseAuthenticationConverter
SimpleGrantedAuthority simpleGrantedAuthority = new SimpleGrantedAuthority("ROLE_USER");
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null) {
simpleGrantedAuthority =
new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
}
simpleGrantedAuthority =
new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
}
List<String> sessionIndexes = new ArrayList<>();

View File

@ -2,8 +2,8 @@ package stirling.software.SPDF.config.security.saml2;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.UUID;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
@ -49,15 +49,23 @@ public class SAML2Configuration {
RelyingPartyRegistration rp =
RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId())
.signingX509Credentials(c -> c.add(signingCredential))
.entityId(samlConf.getIdpIssuer())
.singleLogoutServiceBinding(Saml2MessageBinding.POST)
.singleLogoutServiceLocation(samlConf.getIdpSingleLogoutUrl())
.authnRequestsSigned(true)
.assertingPartyMetadata(
metadata ->
metadata.entityId(samlConf.getIdpIssuer())
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.verificationX509Credentials(
c -> c.add(verificationCredential))
.singleSignOnServiceBinding(
Saml2MessageBinding.POST)
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.singleLogoutServiceBinding(
Saml2MessageBinding.POST)
.singleLogoutServiceLocation(
samlConf.getIdpSingleLogoutUrl())
.wantAuthnRequestsSigned(true))
.build();
return new InMemoryRelyingPartyRegistrationRepository(rp);
@ -73,9 +81,9 @@ public class SAML2Configuration {
customizer -> {
log.debug("Customizing SAML Authentication request");
AuthnRequest authnRequest = customizer.getAuthnRequest();
log.debug("AuthnRequest ID: {}", authnRequest.getID());
log.debug("AuthnRequest ID: {}", authnRequest.getID());
if (authnRequest.getID() == null) {
authnRequest.setID("ARQ" + UUID.randomUUID()); // fixme: SubjectConfirmationData@InResponseTo
authnRequest.setID("ARQ" + UUID.randomUUID());
}
log.debug("AuthnRequest new ID after set: {}", authnRequest.getID());
log.debug("AuthnRequest IssueInstant: {}", authnRequest.getIssueInstant());
@ -94,12 +102,11 @@ public class SAML2Configuration {
// Log headers
Collections.list(request.getHeaderNames())
.forEach(
headerName -> {
log.debug(
"Header - {}: {}",
headerName,
request.getHeader(headerName));
});
headerName ->
log.debug(
"Header - {}: {}",
headerName,
request.getHeader(headerName)));
// Log SAML specific parameters
log.debug("SAML Request Parameters:");
log.debug("SAMLRequest: {}", request.getParameter("SAMLRequest"));

View File

@ -124,7 +124,7 @@ public class UserController {
return new RedirectView("/change-creds?messageType=notAuthenticated", true);
}
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) {
if (userOpt.isEmpty()) {
return new RedirectView("/change-creds?messageType=userNotFound", true);
}
User user = userOpt.get();
@ -152,7 +152,7 @@ public class UserController {
return new RedirectView("/account?messageType=notAuthenticated", true);
}
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) {
if (userOpt.isEmpty()) {
return new RedirectView("/account?messageType=userNotFound", true);
}
User user = userOpt.get();
@ -174,7 +174,7 @@ public class UserController {
for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
updates.put(entry.getKey(), entry.getValue()[0]);
}
log.debug("Processed updates: " + updates);
log.debug("Processed updates: {}", updates);
// Assuming you have a method in userService to update the settings for a user
userService.updateUserSettings(principal.getName(), updates);
// Redirect to a page of your choice after updating
@ -197,7 +197,7 @@ public class UserController {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null && user.getUsername().equalsIgnoreCase(username)) {
if (user.getUsername().equalsIgnoreCase(username)) {
return new RedirectView("/addUsers?messageType=usernameExists", true);
}
}
@ -274,7 +274,7 @@ public class UserController {
Authentication authentication)
throws SQLException, UnsupportedProviderException {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (!userOpt.isPresent()) {
if (userOpt.isEmpty()) {
return new RedirectView("/addUsers?messageType=userNotFound", true);
}
if (!userService.usernameExistsIgnoreCase(username)) {
@ -293,7 +293,7 @@ public class UserController {
List<Object> principals = sessionRegistry.getAllPrincipals();
String userNameP = "";
for (Object principal : principals) {
List<SessionInformation> sessionsInformations =
List<SessionInformation> sessionsInformation =
sessionRegistry.getAllSessions(principal, false);
if (principal instanceof UserDetails) {
userNameP = ((UserDetails) principal).getUsername();
@ -305,8 +305,8 @@ public class UserController {
userNameP = (String) principal;
}
if (userNameP.equalsIgnoreCase(username)) {
for (SessionInformation sessionsInformation : sessionsInformations) {
sessionRegistry.expireSession(sessionsInformation.getSessionId());
for (SessionInformation sessionInfo : sessionsInformation) {
sessionRegistry.expireSession(sessionInfo.getSessionId());
}
}
}

View File

@ -111,7 +111,6 @@ public class AccountWebController {
}
// Remove any null keys/values from the providerList
// providerList might be empty on browser side? Button not showing up
providerList
.entrySet()
.removeIf(entry -> entry.getKey() == null || entry.getValue() == null);
@ -216,13 +215,11 @@ public class AccountWebController {
.plus(maxInactiveInterval, ChronoUnit.SECONDS);
if (now.isAfter(expirationTime)) {
sessionPersistentRegistry.expireSession(sessionEntity.getSessionId());
hasActiveSession = false;
} else {
hasActiveSession = !sessionEntity.isExpired();
}
lastRequest = sessionEntity.getLastRequest();
} else {
hasActiveSession = false;
// No session, set default last request time
lastRequest = new Date(0);
}
@ -259,53 +256,41 @@ public class AccountWebController {
})
.collect(Collectors.toList());
String messageType = request.getParameter("messageType");
String deleteMessage = null;
String deleteMessage;
if (messageType != null) {
switch (messageType) {
case "deleteCurrentUser":
deleteMessage = "deleteCurrentUserMessage";
break;
case "deleteUsernameExists":
deleteMessage = "deleteUsernameExistsMessage";
break;
default:
break;
}
deleteMessage =
switch (messageType) {
case "deleteCurrentUser" -> "deleteCurrentUserMessage";
case "deleteUsernameExists" -> "deleteUsernameExistsMessage";
default -> null;
};
model.addAttribute("deleteMessage", deleteMessage);
String addMessage = null;
switch (messageType) {
case "usernameExists":
addMessage = "usernameExistsMessage";
break;
case "invalidUsername":
addMessage = "invalidUsernameMessage";
break;
case "invalidPassword":
addMessage = "invalidPasswordMessage";
break;
default:
break;
}
String addMessage;
addMessage =
switch (messageType) {
case "usernameExists" -> "usernameExistsMessage";
case "invalidUsername" -> "invalidUsernameMessage";
case "invalidPassword" -> "invalidPasswordMessage";
default -> null;
};
model.addAttribute("addMessage", addMessage);
}
String changeMessage = null;
String changeMessage;
if (messageType != null) {
switch (messageType) {
case "userNotFound":
changeMessage = "userNotFoundMessage";
break;
case "downgradeCurrentUser":
changeMessage = "downgradeCurrentUserMessage";
break;
case "disabledCurrentUser":
changeMessage = "disabledCurrentUserMessage";
break;
default:
changeMessage = messageType;
break;
}
changeMessage =
switch (messageType) {
case "userNotFound" -> "userNotFoundMessage";
case "downgradeCurrentUser" -> "downgradeCurrentUserMessage";
case "disabledCurrentUser" -> "disabledCurrentUserMessage";
default -> messageType;
};
model.addAttribute("changeMessage", changeMessage);
}
model.addAttribute("users", sortedUsers);
model.addAttribute("currentUsername", authentication.getName());
model.addAttribute("roleDetails", roleDetails);
@ -326,39 +311,35 @@ public class AccountWebController {
if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal();
String username = null;
// Retrieve username and other attributes and add login attributes to the model
if (principal instanceof UserDetails userDetails) {
// Retrieve username and other attributes
username = userDetails.getUsername();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", false);
}
if (principal instanceof OAuth2User userDetails) {
// Retrieve username and other attributes
username = userDetails.getName();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true);
}
if (principal instanceof CustomSaml2AuthenticatedPrincipal userDetails) {
// Retrieve username and other attributes
username = userDetails.getName();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true);
model.addAttribute("saml2Login", true);
}
if (username != null) {
// Fetch user details from the database, assuming findByUsername method exists
// Fetch user details from the database
Optional<User> user = userRepository.findByUsernameIgnoreCaseWithSettings(username);
if (user.isEmpty()) {
return "redirect:/error";
}
// Convert settings map to JSON string
ObjectMapper objectMapper = new ObjectMapper();
String settingsJson;
try {
settingsJson = objectMapper.writeValueAsString(user.get().getSettings());
} catch (JsonProcessingException e) {
// Handle JSON conversion error
log.error("exception", e);
log.error("Error converting settings map", e);
return "redirect:/error";
}
@ -372,7 +353,7 @@ public class AccountWebController {
case "invalidUsername" -> messageType = "invalidUsernameMessage";
}
}
// Add attributes to the model
model.addAttribute("username", username);
model.addAttribute("messageType", messageType);
model.addAttribute("role", user.get().getRolesAsString());
@ -395,19 +376,12 @@ public class AccountWebController {
}
if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal();
if (principal instanceof UserDetails) {
// Cast the principal object to UserDetails
UserDetails userDetails = (UserDetails) principal;
// Retrieve username and other attributes
if (principal instanceof UserDetails userDetails) {
String username = userDetails.getUsername();
// Fetch user details from the database
Optional<User> user =
userRepository
.findByUsernameIgnoreCase( // Assuming findByUsername method exists
username);
if (!user.isPresent()) {
// Handle error appropriately
// Example redirection in case of error
Optional<User> user = userRepository.findByUsernameIgnoreCase(username);
if (user.isEmpty()) {
// Handle error appropriately, example redirection in case of error
return "redirect:/error";
}
String messageType = request.getParameter("messageType");
@ -430,7 +404,7 @@ public class AccountWebController {
}
model.addAttribute("messageType", messageType);
}
// Add attributes to the model
model.addAttribute("username", username);
}
} else {

View File

@ -227,6 +227,7 @@ public class ApplicationProperties {
private Collection<String> scopes = new ArrayList<>();
private String provider;
private Client client = new Client();
private String logoutUrl;
public void setScopes(String scopes) {
List<String> scopesList =

View File

@ -28,6 +28,7 @@ public class GitHubProvider extends Provider {
clientSecret,
scopes,
useAsUsername != null ? useAsUsername : UsernameAttribute.LOGIN,
null,
AUTHORIZATION_URI,
TOKEN_URI,
USER_INFO_URI);

View File

@ -29,6 +29,7 @@ public class GoogleProvider extends Provider {
clientSecret,
scopes,
useAsUsername,
null,
AUTHORIZATION_URI,
TOKEN_URI,
USER_INFO_URI);

View File

@ -28,6 +28,7 @@ public class KeycloakProvider extends Provider {
useAsUsername,
null,
null,
null,
null);
}

View File

@ -25,6 +25,7 @@ public class Provider {
private String clientSecret;
private Collection<String> scopes;
private UsernameAttribute useAsUsername;
private String logoutUrl;
private String authorizationUri;
private String tokenUri;
private String userInfoUri;
@ -37,6 +38,7 @@ public class Provider {
String clientSecret,
Collection<String> scopes,
UsernameAttribute useAsUsername,
String logoutUrl,
String authorizationUri,
String tokenUri,
String userInfoUri) {
@ -48,6 +50,7 @@ public class Provider {
this.scopes = scopes == null ? new ArrayList<>() : scopes;
this.useAsUsername =
useAsUsername != null ? validateUsernameAttribute(useAsUsername) : EMAIL;
this.logoutUrl = logoutUrl;
this.authorizationUri = authorizationUri;
this.tokenUri = tokenUri;
this.userInfoUri = userInfoUri;

View File

@ -34,7 +34,7 @@
</th:block>
<!-- Change Username Form -->
<th:block th:if="${!oAuth2Login}">
<th:block th:if="not ${oAuth2Login} or not ${saml2Login}">
<h4 th:text="#{account.changeUsername}">Change Username?</h4>
<form id="formsavechangeusername" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-username'}" method="post">
<div class="mb-3">
@ -53,7 +53,7 @@
</th:block>
<!-- Change Password Form -->
<th:block th:if="${!oAuth2Login}">
<th:block th:if="not ${oAuth2Login} or not ${saml2Login}">
<h4 th:text="#{account.changePassword}">Change Password?</h4>
<form id="formsavechangepassword" class="bg-card mt-4 mb-4" th:action="@{'/api/v1/user/change-password'}" method="post">
<div class="mb-3">