Add streaming to Engine orchestrator (#6094)

# Description of Changes
Adds a streaming endpoint to the Java AI orchestrator
(`/api/v1/ai/orchestrate/stream`, in addition to the existing
`/api/v1/ai/orchestrate`). This allows the caller to receive updates about
which stage of orchestration is currently running, so UIs can give the
user feedback.

Also contains some dubious Gradle changes to suppress errors coming from
Spotless when it crashes inside google-java-format internals. I'm not sure
whether that's appropriate to add — feel free to ask for changes in review.
This commit is contained in:
James Brunton
2026-04-17 10:01:08 +01:00
committed by GitHub
parent 97ca85d878
commit 688f7f2013
13 changed files with 326 additions and 40 deletions

View File

@@ -7,6 +7,8 @@ spotless {
target 'src/**/java/**/*.java'
targetExclude 'src/main/java/org/apache/**'
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
suppressLintsFor { setStep('google-java-format') }
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
trimTrailingWhitespace()

View File

@@ -14,6 +14,8 @@ spotless {
target 'src/**/java/**/*.java'
targetExclude 'src/main/resources/static/**', 'src/main/java/org/apache/**'
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
suppressLintsFor { setStep('google-java-format') }
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
trimTrailingWhitespace()

View File

@@ -16,6 +16,8 @@ spotless {
target 'src/**/java/**/*.java'
targetExclude 'src/main/java/org/apache/**'
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
suppressLintsFor { setStep('google-java-format') }
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
trimTrailingWhitespace()

View File

@@ -48,4 +48,12 @@ public class AsyncConfig {
adapter.setTaskDecorator(new MDCContextTaskDecorator());
return adapter;
}
@Bean(name = "aiStreamExecutor")
public Executor aiStreamExecutor() {
    // One virtual thread per task: SSE orchestration streams are long-lived and
    // spend most of their time blocked on I/O, so virtual threads keep them cheap.
    TaskExecutorAdapter virtualThreadAdapter =
            new TaskExecutorAdapter(Executors.newVirtualThreadPerTaskExecutor());
    // MDCContextTaskDecorator presumably propagates the logging MDC onto each
    // worker task — mirrors the executor bean defined above in this config.
    virtualThreadAdapter.setTaskDecorator(new MDCContextTaskDecorator());
    return virtualThreadAdapter;
}
}

View File

@@ -1,7 +1,10 @@
package stirling.software.proprietary.controller.api;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.Executor;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
@@ -12,6 +15,7 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import io.swagger.v3.oas.annotations.Hidden;
import io.swagger.v3.oas.annotations.Operation;
@@ -19,7 +23,6 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.api.ai.AiWorkflowRequest;
@@ -34,7 +37,6 @@ import tools.jackson.databind.ObjectMapper;
@Slf4j
@RestController
@RequestMapping("/api/v1/ai")
@RequiredArgsConstructor
@Hidden
@Tag(name = "AI Engine", description = "Endpoints for AI-powered PDF workflows")
public class AiEngineController {
@@ -42,6 +44,18 @@ public class AiEngineController {
private final AiEngineClient aiEngineClient;
private final AiWorkflowService aiWorkflowService;
private final ObjectMapper objectMapper;
private final Executor aiStreamExecutor;
/**
 * Explicit constructor — presumably replacing Lombok's {@code @RequiredArgsConstructor}
 * so the stream executor can be selected by bean name via {@code @Qualifier}
 * (constructor annotations cannot be expressed through Lombok here). TODO confirm
 * the Lombok annotation was removed from the class.
 *
 * @param aiEngineClient client for the AI engine service
 * @param aiWorkflowService orchestration workflow logic
 * @param objectMapper JSON mapper used by the controller
 * @param aiStreamExecutor executor that runs streaming orchestration off the request thread
 */
public AiEngineController(
AiEngineClient aiEngineClient,
AiWorkflowService aiWorkflowService,
ObjectMapper objectMapper,
@Qualifier("aiStreamExecutor") Executor aiStreamExecutor) {
this.aiEngineClient = aiEngineClient;
this.aiWorkflowService = aiWorkflowService;
this.objectMapper = objectMapper;
this.aiStreamExecutor = aiStreamExecutor;
}
@GetMapping("/health")
@Operation(
@@ -62,6 +76,49 @@ public class AiEngineController {
return ResponseEntity.ok(aiWorkflowService.orchestrate(request));
}
@PostMapping(value = "/orchestrate/stream", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
@Operation(
        summary = "Run an AI workflow with streaming progress",
        description =
                "Accepts a PDF upload and a user message, returns SSE events with progress"
                        + " updates followed by the final AI workflow result")
public SseEmitter orchestrateStream(@Valid @ModelAttribute AiWorkflowRequest request) {
    // Generous timeout: orchestration may involve several engine round-trips.
    final long timeoutMillis = 180_000L;
    final SseEmitter emitter = new SseEmitter(timeoutMillis);
    emitter.onTimeout(
            () -> {
                log.warn("SSE emitter timed out for AI orchestration stream");
                emitter.complete();
            });
    emitter.onError(ex -> log.warn("SSE emitter error for AI orchestration stream", ex));
    // Hand the blocking orchestration off to the dedicated executor so this
    // request thread can return the emitter immediately.
    aiStreamExecutor.execute(() -> runOrchestrationStream(request, emitter));
    return emitter;
}
/**
 * Runs the workflow on the stream executor, forwarding progress and the final
 * result (or an error) to the SSE emitter.
 *
 * @param request the validated workflow request
 * @param emitter the emitter to publish "progress", "result" or "error" events to
 */
private void runOrchestrationStream(AiWorkflowRequest request, SseEmitter emitter) {
    try {
        AiWorkflowResponse result =
                aiWorkflowService.orchestrate(
                        request, progress -> sendEvent(emitter, "progress", progress));
        sendEvent(emitter, "result", result);
        emitter.complete();
    } catch (Exception e) {
        log.error("AI orchestration stream failed", e);
        // Map.of rejects null values and would throw NPE here when getMessage()
        // is null (e.g. a bare NullPointerException), masking the real failure
        // and skipping the error event — fall back to the exception type name.
        String message = e.getMessage() != null ? e.getMessage() : e.getClass().getSimpleName();
        sendEvent(emitter, "error", Map.of("message", message));
        emitter.completeWithError(e);
    }
}
/**
 * Sends one named SSE event with a JSON payload; a failed send (client gone)
 * is logged at debug and otherwise ignored.
 *
 * NOTE(review): only IOException is swallowed — sending on an already-completed
 * emitter would throw IllegalStateException and propagate; confirm that is intended.
 */
private void sendEvent(SseEmitter emitter, String name, Object data) {
    SseEmitter.SseEventBuilder event =
            SseEmitter.event().name(name).data(data, MediaType.APPLICATION_JSON);
    try {
        emitter.send(event);
    } catch (IOException ioe) {
        log.debug("Failed to send SSE event (client may have disconnected)", ioe);
    }
}
@PostMapping(value = "/pdf/edit", consumes = MediaType.APPLICATION_JSON_VALUE)
@Operation(
summary = "Generate a PDF edit plan",

View File

@@ -0,0 +1,33 @@
package stirling.software.proprietary.model.api.ai;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
/** Progress phases emitted during AI workflow orchestration. */
public enum AiWorkflowPhase {
    ANALYZING("analyzing"),
    CALLING_ENGINE("calling_engine"),
    EXTRACTING_CONTENT("extracting_content"),
    PROCESSING("processing");

    // Lowercase wire value used in JSON payloads (and, in the UI, as an i18n key).
    private final String value;

    AiWorkflowPhase(String wireValue) {
        this.value = wireValue;
    }

    /** Serialization hook: emit the wire value rather than the enum constant name. */
    @JsonValue
    public String getValue() {
        return value;
    }

    /**
     * Deserialization hook: resolve a wire value back to its phase.
     *
     * @throws IllegalArgumentException if no phase matches the given value
     */
    @JsonCreator
    public static AiWorkflowPhase fromValue(String value) {
        for (AiWorkflowPhase candidate : values()) {
            if (candidate.value.equals(value)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Unknown AI workflow phase: " + value);
    }
}

View File

@@ -0,0 +1,15 @@
package stirling.software.proprietary.model.api.ai;
import lombok.AllArgsConstructor;
import lombok.Data;
/** A single progress update emitted while an AI workflow is being orchestrated. */
@Data
@AllArgsConstructor
public class AiWorkflowProgressEvent {
// Current orchestration phase; serialized via AiWorkflowPhase's @JsonValue wire string.
private AiWorkflowPhase phase;
// Epoch millis at which the phase was entered (System.currentTimeMillis() in of()).
private long timestamp;
/** Convenience factory: wraps the given phase with the current timestamp. */
public static AiWorkflowProgressEvent of(AiWorkflowPhase phase) {
return new AiWorkflowProgressEvent(phase, System.currentTimeMillis());
}
}

View File

@@ -20,6 +20,8 @@ import stirling.software.common.util.ExceptionUtils;
import stirling.software.proprietary.model.api.ai.AiWorkflowFileInput;
import stirling.software.proprietary.model.api.ai.AiWorkflowFileRequest;
import stirling.software.proprietary.model.api.ai.AiWorkflowOutcome;
import stirling.software.proprietary.model.api.ai.AiWorkflowPhase;
import stirling.software.proprietary.model.api.ai.AiWorkflowProgressEvent;
import stirling.software.proprietary.model.api.ai.AiWorkflowRequest;
import stirling.software.proprietary.model.api.ai.AiWorkflowResponse;
import stirling.software.proprietary.service.PdfContentExtractor.LoadedFile;
@@ -38,6 +40,13 @@ public class AiWorkflowService {
private final PdfContentExtractor pdfContentExtractor;
private final ObjectMapper objectMapper;
/** Callback invoked each time orchestration moves to a new {@link AiWorkflowPhase}. */
@FunctionalInterface
public interface ProgressListener {
void onProgress(AiWorkflowProgressEvent event);
}
// Listener used by the non-streaming orchestrate(request) overload: drops all events.
private static final ProgressListener NOOP_LISTENER = event -> {};
private sealed interface WorkflowState {
record Pending(WorkflowTurnRequest request) implements WorkflowState {}
@@ -45,6 +54,11 @@ public class AiWorkflowService {
}
/**
 * Runs the AI workflow without progress reporting; delegates to the
 * listener-accepting overload with a no-op listener.
 */
public AiWorkflowResponse orchestrate(AiWorkflowRequest request) throws IOException {
return orchestrate(request, NOOP_LISTENER);
}
public AiWorkflowResponse orchestrate(AiWorkflowRequest request, ProgressListener listener)
throws IOException {
validateRequest(request);
Map<String, MultipartFile> filesByName = new LinkedHashMap<>();
@@ -57,19 +71,24 @@ public class AiWorkflowService {
initialRequest.setUserMessage(request.getUserMessage().trim());
initialRequest.setFileNames(new ArrayList<>(filesByName.keySet()));
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.ANALYZING));
WorkflowState state = new WorkflowState.Pending(initialRequest);
while (state instanceof WorkflowState.Pending pending) {
state = advance(pending.request(), filesByName);
state = advance(pending.request(), filesByName, listener);
}
return ((WorkflowState.Terminal) state).response();
}
private WorkflowState advance(
WorkflowTurnRequest request, Map<String, MultipartFile> filesByName)
WorkflowTurnRequest request,
Map<String, MultipartFile> filesByName,
ProgressListener listener)
throws IOException {
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.CALLING_ENGINE));
AiWorkflowResponse response = invokeOrchestrator(request);
return switch (response.getOutcome()) {
case NEED_CONTENT -> onNeedContent(response, filesByName, request);
case NEED_CONTENT -> onNeedContent(response, filesByName, request, listener);
case ANSWER,
NOT_FOUND,
PLAN,
@@ -86,7 +105,8 @@ public class AiWorkflowService {
private WorkflowState onNeedContent(
AiWorkflowResponse response,
Map<String, MultipartFile> filesByName,
WorkflowTurnRequest request)
WorkflowTurnRequest request,
ProgressListener listener)
throws IOException {
if (!request.getArtifacts().isEmpty()) {
return new WorkflowState.Terminal(
@@ -119,6 +139,8 @@ public class AiWorkflowService {
Collectors.toMap(
AiWorkflowFileRequest::getFileName, r -> r));
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.EXTRACTING_CONTENT));
List<LoadedFile> loadedFiles = new ArrayList<>();
try {
for (String fileName : fileNamesToLoad) {
@@ -133,6 +155,8 @@ public class AiWorkflowService {
response.getMaxPages(),
response.getMaxCharacters());
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.PROCESSING));
WorkflowTurnRequest nextRequest = new WorkflowTurnRequest();
nextRequest.setUserMessage(request.getUserMessage());
nextRequest.setFileNames(request.getFileNames());

View File

@@ -2079,6 +2079,13 @@ keywords = "Keywords: odd, even"
numbers = "Numbers/ranges: 5, 10-20"
progressions = "Progressions: 3n, 4n+1"
[chat.progress]
thinking = "Thinking..."
analyzing = "Analysing your request..."
calling_engine = "AI is thinking..."
extracting_content = "Extracting content from your documents..."
processing = "Processing extracted content..."
[certSign]
allSigned = "All participants have signed. Ready to finalize."
awaitingSignatures = "Awaiting signatures"

View File

@@ -12,3 +12,8 @@ export function setupApiInterceptors(client: AxiosInstance): void {
(error) => Promise.reject(error),
);
}
/**
 * Auth headers for raw fetch() calls (SSE streams, etc.).
 * The base build has no auth, so this returns an empty object; the proprietary
 * build's module overrides this with a JWT + XSRF implementation.
 */
export function getAuthHeaders(): Record<string, string> {
return {};
}

View File

@@ -95,20 +95,29 @@ async function refreshAuthToken(client: AxiosInstance): Promise<string> {
}
}
/**
 * Auth headers for raw fetch() calls (SSE streams, etc.).
 * Includes a header only when its backing token is actually present.
 */
export function getAuthHeaders(): Record<string, string> {
  const jwt = getJwtTokenFromStorage();
  const xsrf = getXsrfToken();
  return {
    ...(jwt ? { Authorization: `Bearer ${jwt}` } : {}),
    ...(xsrf ? { "X-XSRF-TOKEN": xsrf } : {}),
  };
}
export function setupApiInterceptors(client: AxiosInstance): void {
// Install request interceptor to add JWT token
client.interceptors.request.use(
(config) => {
const jwtToken = getJwtTokenFromStorage();
const xsrfToken = getXsrfToken();
if (jwtToken && !config.headers.Authorization) {
config.headers.Authorization = `Bearer ${jwtToken}`;
console.debug("[API Client] Added JWT token from localStorage to Authorization header");
}
if (xsrfToken && !config.headers["X-XSRF-TOKEN"]) {
config.headers["X-XSRF-TOKEN"] = xsrfToken;
const authHeaders = getAuthHeaders();
for (const [key, value] of Object.entries(authHeaders)) {
if (!config.headers[key]) {
config.headers[key] = value;
}
}
return config;

View File

@@ -1,5 +1,6 @@
import { createContext, useContext, useReducer, useCallback, type ReactNode } from "react";
import { createContext, useContext, useReducer, useCallback, useRef, type ReactNode } from "react";
import { useAllFiles } from "@app/contexts/FileContext";
import { getAuthHeaders } from "@app/services/apiClientSetup";
export interface ChatMessage {
id: string;
@@ -8,6 +9,13 @@ export interface ChatMessage {
timestamp: number;
}
/**
 * Orchestration progress phases streamed over SSE.
 * Values mirror the backend AiWorkflowPhase wire strings and are used as
 * i18n keys under `chat.progress.*`.
 */
export enum AiWorkflowPhase {
ANALYZING = "analyzing",
CALLING_ENGINE = "calling_engine",
EXTRACTING_CONTENT = "extracting_content",
PROCESSING = "processing",
}
type AiWorkflowOutcome =
| "answer"
| "not_found"
@@ -37,11 +45,13 @@ interface ChatState {
messages: ChatMessage[];
isOpen: boolean;
isLoading: boolean;
progressPhase: AiWorkflowPhase | null;
}
type ChatAction =
| { type: "ADD_MESSAGE"; message: ChatMessage }
| { type: "SET_LOADING"; loading: boolean }
| { type: "SET_PROGRESS"; phase: AiWorkflowPhase | null }
| { type: "TOGGLE_OPEN" }
| { type: "SET_OPEN"; open: boolean };
@@ -51,6 +61,8 @@ function chatReducer(state: ChatState, action: ChatAction): ChatState {
return { ...state, messages: [...state.messages, action.message] };
case "SET_LOADING":
return { ...state, isLoading: action.loading };
case "SET_PROGRESS":
return { ...state, progressPhase: action.phase };
case "TOGGLE_OPEN":
return { ...state, isOpen: !state.isOpen };
case "SET_OPEN":
@@ -85,10 +97,67 @@ function formatWorkflowResponse(data: AiWorkflowResponse): string {
}
}
/**
 * Parses an SSE text stream and invokes callbacks for each named event.
 *
 * Frames are separated by a blank line (LF/LF or CRLF/CRLF). Per the SSE spec,
 * multiple `data:` lines within one frame are joined with "\n" (not simply
 * concatenated), and at most one leading space after "data:" is stripped.
 * Frames whose payload is not valid JSON are skipped, as is an incomplete
 * trailing frame with no terminating blank line (spec behavior).
 */
async function consumeSSEStream(
  response: Response,
  handlers: {
    onProgress: (data: { phase: string; timestamp: number }) => void;
    onResult: (data: AiWorkflowResponse) => void;
    onError: (data: { message: string }) => void;
  },
) {
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  let currentEvent = "";

  const dispatchFrame = (frame: string) => {
    const dataLines: string[] = [];
    for (const rawLine of frame.split("\n")) {
      // Tolerate CRLF line endings by dropping a trailing carriage return.
      const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine;
      if (line.startsWith("event:")) {
        currentEvent = line.slice(6).trim();
      } else if (line.startsWith("data:")) {
        let payload = line.slice(5);
        // The SSE spec strips at most one leading space after the colon.
        if (payload.startsWith(" ")) {
          payload = payload.slice(1);
        }
        dataLines.push(payload);
      }
    }
    if (dataLines.length > 0) {
      try {
        // Spec: multiple data lines are joined with newlines.
        const parsed = JSON.parse(dataLines.join("\n"));
        if (currentEvent === "progress") {
          handlers.onProgress(parsed);
        } else if (currentEvent === "result") {
          handlers.onResult(parsed);
        } else if (currentEvent === "error") {
          handlers.onError(parsed);
        }
      } catch {
        // Skip malformed JSON frames
      }
    }
    currentEvent = "";
  };

  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    // Dispatch every complete frame currently buffered. Matching \r?\n\r?\n
    // (rather than indexOf("\n\n")) also recognizes CRLF-delimited streams,
    // which contain no literal "\n\n" substring.
    let match = buffer.match(/\r?\n\r?\n/);
    while (match && match.index !== undefined) {
      dispatchFrame(buffer.slice(0, match.index));
      buffer = buffer.slice(match.index + match[0].length);
      match = buffer.match(/\r?\n\r?\n/);
    }
  }
}
interface ChatContextValue {
messages: ChatMessage[];
isOpen: boolean;
isLoading: boolean;
progressPhase: AiWorkflowPhase | null;
toggleOpen: () => void;
setOpen: (open: boolean) => void;
sendMessage: (content: string) => Promise<void>;
@@ -100,17 +169,24 @@ const initialState: ChatState = {
messages: [],
isOpen: false,
isLoading: false,
progressPhase: null,
};
export function ChatProvider({ children }: { children: ReactNode }) {
const [state, dispatch] = useReducer(chatReducer, initialState);
const { files: activeFiles } = useAllFiles();
const abortRef = useRef<AbortController | null>(null);
const toggleOpen = useCallback(() => dispatch({ type: "TOGGLE_OPEN" }), []);
const setOpen = useCallback((open: boolean) => dispatch({ type: "SET_OPEN", open }), []);
const sendMessage = useCallback(
async (content: string) => {
// Abort any in-flight request
abortRef.current?.abort();
const controller = new AbortController();
abortRef.current = controller;
const userMessage: ChatMessage = {
id: crypto.randomUUID(),
role: "user",
@@ -119,6 +195,7 @@ export function ChatProvider({ children }: { children: ReactNode }) {
};
dispatch({ type: "ADD_MESSAGE", message: userMessage });
dispatch({ type: "SET_LOADING", loading: true });
dispatch({ type: "SET_PROGRESS", phase: null });
try {
const formData = new FormData();
@@ -127,34 +204,73 @@ export function ChatProvider({ children }: { children: ReactNode }) {
formData.append(`fileInputs[${i}].fileInput`, file);
});
const response = await fetch("/api/v1/ai/orchestrate", {
const response = await fetch("/api/v1/ai/orchestrate/stream", {
method: "POST",
body: formData,
headers: getAuthHeaders(),
credentials: "include",
signal: controller.signal,
});
if (!response.ok) {
throw new Error(`AI engine request failed: ${response.status}`);
}
const data: AiWorkflowResponse = await response.json();
const replyContent = formatWorkflowResponse(data);
const assistantMessage: ChatMessage = {
id: crypto.randomUUID(),
role: "assistant",
content: replyContent,
timestamp: Date.now(),
};
dispatch({ type: "ADD_MESSAGE", message: assistantMessage });
} catch {
const errorMessage: ChatMessage = {
id: crypto.randomUUID(),
role: "assistant",
content: "Failed to get a response. The AI engine may not be available yet.",
timestamp: Date.now(),
};
dispatch({ type: "ADD_MESSAGE", message: errorMessage });
let receivedResult = false;
await consumeSSEStream(response, {
onProgress: (data) => {
dispatch({ type: "SET_PROGRESS", phase: data.phase as AiWorkflowPhase });
},
onResult: (data) => {
receivedResult = true;
dispatch({ type: "SET_PROGRESS", phase: null });
const replyContent = formatWorkflowResponse(data);
dispatch({
type: "ADD_MESSAGE",
message: {
id: crypto.randomUUID(),
role: "assistant",
content: replyContent,
timestamp: Date.now(),
},
});
},
onError: (data) => {
receivedResult = true;
dispatch({ type: "SET_PROGRESS", phase: null });
dispatch({
type: "ADD_MESSAGE",
message: {
id: crypto.randomUUID(),
role: "assistant",
content: data.message || "Something went wrong.",
timestamp: Date.now(),
},
});
},
});
if (!receivedResult) {
throw new Error("Stream ended without a result");
}
} catch (e) {
if ((e as Error).name === "AbortError") return;
dispatch({ type: "SET_PROGRESS", phase: null });
dispatch({
type: "ADD_MESSAGE",
message: {
id: crypto.randomUUID(),
role: "assistant",
content: "Failed to get a response. The AI engine may not be available yet.",
timestamp: Date.now(),
},
});
} finally {
dispatch({ type: "SET_LOADING", loading: false });
if (abortRef.current === controller) {
abortRef.current = null;
}
}
},
[activeFiles],
@@ -166,6 +282,7 @@ export function ChatProvider({ children }: { children: ReactNode }) {
messages: state.messages,
isOpen: state.isOpen,
isLoading: state.isLoading,
progressPhase: state.progressPhase,
toggleOpen,
setOpen,
sendMessage,

View File

@@ -1,5 +1,6 @@
import { useRef, useEffect, useState, type KeyboardEvent } from "react";
import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition } from "@mantine/core";
import { useTranslation } from "react-i18next";
import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition, Loader, Group } from "@mantine/core";
import SendIcon from "@mui/icons-material/Send";
import ChatBubbleOutlineIcon from "@mui/icons-material/ChatBubbleOutline";
import CloseIcon from "@mui/icons-material/Close";
@@ -19,7 +20,8 @@ function ChatMessageBubble({ role, content }: { role: "user" | "assistant"; cont
}
export function ChatPanel() {
const { messages, isOpen, isLoading, toggleOpen, sendMessage } = useChat();
const { t } = useTranslation();
const { messages, isOpen, isLoading, progressPhase, toggleOpen, sendMessage } = useChat();
const [input, setInput] = useState("");
const scrollRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
@@ -95,9 +97,12 @@ export function ChatPanel() {
{isLoading && (
<div className="chat-message chat-message-assistant">
<Paper className="chat-bubble chat-bubble-assistant" p="xs" radius="md">
<Text size="sm" c="dimmed">
Thinking...
</Text>
<Group gap="xs" wrap="nowrap">
<Loader size="xs" type="dots" />
<Text size="sm" c="dimmed">
{progressPhase ? t(`chat.progress.${progressPhase}`) : t("chat.progress.thinking")}
</Text>
</Group>
</Paper>
</div>
)}