diff --git a/app/core/src/main/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessor.java b/app/core/src/main/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessor.java index f3a926de8..a0bb5b66b 100644 --- a/app/core/src/main/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessor.java +++ b/app/core/src/main/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessor.java @@ -18,10 +18,10 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.http.*; -import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.stereotype.Service; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RequestCallback; import org.springframework.web.client.RestTemplate; import org.springframework.web.multipart.MultipartFile; @@ -287,17 +287,19 @@ public class PipelineProcessor { // Set up headers, including API key HttpHeaders headers = new HttpHeaders(); String apiKey = getApiKeyForUser(); - headers.add("X-API-KEY", apiKey); - headers.setContentType(MediaType.MULTIPART_FORM_DATA); + if (apiKey != null && !apiKey.isEmpty()) { + headers.add("X-API-KEY", apiKey); + } + // Let the message converter set the multipart boundary/content type + HttpEntity> entity = new HttpEntity<>(body, headers); + + RequestCallback requestCallback = + restTemplate.httpEntityCallback(entity, Resource.class /* response type hint */); return restTemplate.execute( url, HttpMethod.POST, - request -> { - request.getHeaders().putAll(headers); - new FormHttpMessageConverter() - .write(body, MediaType.MULTIPART_FORM_DATA, request); - }, + requestCallback, response -> { try { TempFile tempFile = tempFileManager.createManagedTempFile("pipeline"); diff --git a/app/core/src/test/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessorTest.java b/app/core/src/test/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessorTest.java index 0811b7de6..d58770f45 100644 --- a/app/core/src/test/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessorTest.java +++ b/app/core/src/test/java/stirling/software/SPDF/controller/api/pipeline/PipelineProcessorTest.java @@ -4,6 +4,8 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.List; @@ -13,12 +15,16 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.MockedConstruction; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import jakarta.servlet.ServletContext; @@ -126,6 +132,90 @@ class PipelineProcessorTest { Files.deleteIfExists(tempPath); } + @Test + void sendWebRequestDoesNotForceContentType() throws Exception { + MultiValueMap body = new LinkedMultiValueMap<>(); + body.add( + "fileInput", + new ByteArrayResource("data".getBytes(StandardCharsets.UTF_8)) { + @Override + public String getFilename() { + return "input.pdf"; + } + }); + + Path tempPath = Files.createTempFile("pipeline-test", ".tmp"); + var tempFile = mock(stirling.software.common.util.TempFile.class); + when(tempFile.getPath()).thenReturn(tempPath); + when(tempFile.getFile()).thenReturn(tempPath.toFile()); + when(tempFileManager.createManagedTempFile("pipeline")).thenReturn(tempFile); + + var capturedHeaders = new org.springframework.http.HttpHeaders[1]; + + try (MockedConstruction ignored = + mockConstruction( + org.springframework.web.client.RestTemplate.class, + (mock, context) -> { + when(mock.httpEntityCallback(any(), eq(Resource.class))) + .thenAnswer( + invocation -> { + var entity = invocation.getArgument(0); + capturedHeaders[0] = + ((org.springframework.http.HttpEntity) + entity) + .getHeaders(); + return (org.springframework.web.client + .RequestCallback) + request -> {}; + }); + + when(mock.execute( + anyString(), + eq(org.springframework.http.HttpMethod.POST), + any(), + any())) + .thenAnswer( + invocation -> { + @SuppressWarnings("unchecked") + var extractor = + (org.springframework.web.client + .ResponseExtractor< + ResponseEntity>) + invocation.getArgument(3); + ClientHttpResponse response = + mock(ClientHttpResponse.class); + when(response.getBody()) + .thenReturn( + new ByteArrayInputStream( + "ok" + .getBytes( + StandardCharsets + .UTF_8))); + var headers = + new org.springframework.http.HttpHeaders(); + headers.add( + org.springframework.http.HttpHeaders + .CONTENT_DISPOSITION, + "attachment; filename=\"out.pdf\""); + when(response.getHeaders()).thenReturn(headers); + lenient() + .when(response.getStatusCode()) + .thenReturn(HttpStatus.OK); + return extractor.extractData(response); + }); + })) { + ResponseEntity response = + pipelineProcessor.sendWebRequest("http://localhost/api", body); + + assertNotNull(response); + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertNotNull(response.getBody()); + assertNull(capturedHeaders[0].getContentType()); + } finally { + Files.deleteIfExists(tempPath); + } + } + private static class MyFileByteArrayResource extends ByteArrayResource { public MyFileByteArrayResource() { super("data".getBytes());