From 688f7f2013bf3560d06e81350e8b1f01491e18c6 Mon Sep 17 00:00:00 2001 From: James Brunton Date: Fri, 17 Apr 2026 10:01:08 +0100 Subject: [PATCH] Add streaming to Engine orchestrator (#6094) # Description of Changes Adds a streaming endpoint to the Java AI orchestrator (`/api/v1/ai/orchestrate/stream` in addition to the existing `/api/v1/ai/orchestrate`). This allows the caller to get updates of what stage of orchestration is being run at the time so UIs can give the user feedback. Also contains some dubious Gradle changes to suppress errors coming from Spotless, when it crashes in Google stuff. I'm not sure if that's appropriate to add, feel free to ask for changes in review. --- app/common/build.gradle | 2 + app/core/build.gradle | 2 + app/proprietary/build.gradle | 2 + .../proprietary/config/AsyncConfig.java | 8 + .../controller/api/AiEngineController.java | 61 ++++++- .../model/api/ai/AiWorkflowPhase.java | 33 ++++ .../model/api/ai/AiWorkflowProgressEvent.java | 15 ++ .../service/AiWorkflowService.java | 32 +++- .../public/locales/en-GB/translation.toml | 7 + frontend/src/core/services/apiClientSetup.ts | 5 + .../proprietary/services/apiClientSetup.ts | 29 ++-- .../components/chat/ChatContext.tsx | 155 +++++++++++++++--- .../prototypes/components/chat/ChatPanel.tsx | 15 +- 13 files changed, 326 insertions(+), 40 deletions(-) create mode 100644 app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowPhase.java create mode 100644 app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowProgressEvent.java diff --git a/app/common/build.gradle b/app/common/build.gradle index 3cd6bae741..41d17273f5 100644 --- a/app/common/build.gradle +++ b/app/common/build.gradle @@ -7,6 +7,8 @@ spotless { target 'src/**/java/**/*.java' targetExclude 'src/main/java/org/apache/**' googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false) + // google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25 + suppressLintsFor { setStep('google-java-format') } importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling") trimTrailingWhitespace() diff --git a/app/core/build.gradle b/app/core/build.gradle index 277792314e..dbb45669ab 100644 --- a/app/core/build.gradle +++ b/app/core/build.gradle @@ -14,6 +14,8 @@ spotless { target 'src/**/java/**/*.java' targetExclude 'src/main/resources/static/**', 'src/main/java/org/apache/**' googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false) + // google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25 + suppressLintsFor { setStep('google-java-format') } importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling") trimTrailingWhitespace() diff --git a/app/proprietary/build.gradle b/app/proprietary/build.gradle index fb716338d0..141734a909 100644 --- a/app/proprietary/build.gradle +++ b/app/proprietary/build.gradle @@ -16,6 +16,8 @@ spotless { target 'src/**/java/**/*.java' targetExclude 'src/main/java/org/apache/**' googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false) + // google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25 + suppressLintsFor { setStep('google-java-format') } importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling") trimTrailingWhitespace() diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/config/AsyncConfig.java b/app/proprietary/src/main/java/stirling/software/proprietary/config/AsyncConfig.java index aa79f9b05e..5c0bf509ae 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/config/AsyncConfig.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/config/AsyncConfig.java @@ -48,4 +48,12 @@ public class AsyncConfig { adapter.setTaskDecorator(new MDCContextTaskDecorator()); return adapter; } + + @Bean(name = "aiStreamExecutor") + public Executor aiStreamExecutor() { + TaskExecutorAdapter adapter = + new TaskExecutorAdapter(Executors.newVirtualThreadPerTaskExecutor()); + adapter.setTaskDecorator(new MDCContextTaskDecorator()); + return adapter; + } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/controller/api/AiEngineController.java b/app/proprietary/src/main/java/stirling/software/proprietary/controller/api/AiEngineController.java index 3af1637579..3a93730876 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/controller/api/AiEngineController.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/controller/api/AiEngineController.java @@ -1,7 +1,10 @@ package stirling.software.proprietary.controller.api; import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executor; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -12,6 +15,7 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import io.swagger.v3.oas.annotations.Hidden; import io.swagger.v3.oas.annotations.Operation; @@ -19,7 +23,6 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import stirling.software.proprietary.model.api.ai.AiWorkflowRequest; @@ -34,7 +37,6 @@ import tools.jackson.databind.ObjectMapper; @Slf4j @RestController @RequestMapping("/api/v1/ai") -@RequiredArgsConstructor @Hidden @Tag(name = "AI Engine", description = "Endpoints for AI-powered PDF workflows") public class AiEngineController { @@ -42,6 +44,18 @@ public class AiEngineController { private final AiEngineClient aiEngineClient; private final AiWorkflowService aiWorkflowService; private final ObjectMapper objectMapper; + private final Executor aiStreamExecutor; + + public AiEngineController( + AiEngineClient aiEngineClient, + AiWorkflowService aiWorkflowService, + ObjectMapper objectMapper, + @Qualifier("aiStreamExecutor") Executor aiStreamExecutor) { + this.aiEngineClient = aiEngineClient; + this.aiWorkflowService = aiWorkflowService; + this.objectMapper = objectMapper; + this.aiStreamExecutor = aiStreamExecutor; + } @GetMapping("/health") @Operation( @@ -62,6 +76,49 @@ public class AiEngineController { return ResponseEntity.ok(aiWorkflowService.orchestrate(request)); } + @PostMapping(value = "/orchestrate/stream", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + @Operation( + summary = "Run an AI workflow with streaming progress", + description = + "Accepts a PDF upload and a user message, returns SSE events with progress" + + " updates followed by the final AI workflow result") + public SseEmitter orchestrateStream(@Valid @ModelAttribute AiWorkflowRequest request) { + SseEmitter emitter = new SseEmitter(180_000L); + + emitter.onTimeout( + () -> { + log.warn("SSE emitter timed out for AI orchestration stream"); + emitter.complete(); + }); + emitter.onError(e -> log.warn("SSE emitter error for AI orchestration stream", e)); + + aiStreamExecutor.execute(() -> runOrchestrationStream(request, emitter)); + + return emitter; + } + + private void runOrchestrationStream(AiWorkflowRequest request, SseEmitter emitter) { + try { + AiWorkflowResponse result = + aiWorkflowService.orchestrate( + request, progress -> sendEvent(emitter, "progress", progress)); + sendEvent(emitter, "result", result); + emitter.complete(); + } catch (Exception e) { + log.error("AI orchestration stream failed", e); + sendEvent(emitter, "error", Map.of("message", e.getMessage())); + emitter.completeWithError(e); + } + } + + private void sendEvent(SseEmitter emitter, String name, Object data) { + try { + emitter.send(SseEmitter.event().name(name).data(data, MediaType.APPLICATION_JSON)); + } catch (IOException e) { + log.debug("Failed to send SSE event (client may have disconnected)", e); + } + } + @PostMapping(value = "/pdf/edit", consumes = MediaType.APPLICATION_JSON_VALUE) @Operation( summary = "Generate a PDF edit plan", diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowPhase.java b/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowPhase.java new file mode 100644 index 0000000000..bb1759fb3c --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowPhase.java @@ -0,0 +1,33 @@ +package stirling.software.proprietary.model.api.ai; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +/** Progress phases emitted during AI workflow orchestration. */ +public enum AiWorkflowPhase { + ANALYZING("analyzing"), + CALLING_ENGINE("calling_engine"), + EXTRACTING_CONTENT("extracting_content"), + PROCESSING("processing"); + + private final String value; + + AiWorkflowPhase(String value) { + this.value = value; + } + + @JsonValue + public String getValue() { + return value; + } + + @JsonCreator + public static AiWorkflowPhase fromValue(String value) { + for (AiWorkflowPhase phase : values()) { + if (phase.value.equals(value)) { + return phase; + } + } + throw new IllegalArgumentException("Unknown AI workflow phase: " + value); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowProgressEvent.java b/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowProgressEvent.java new file mode 100644 index 0000000000..c063e14f6e --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/model/api/ai/AiWorkflowProgressEvent.java @@ -0,0 +1,15 @@ +package stirling.software.proprietary.model.api.ai; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class AiWorkflowProgressEvent { + private AiWorkflowPhase phase; + private long timestamp; + + public static AiWorkflowProgressEvent of(AiWorkflowPhase phase) { + return new AiWorkflowProgressEvent(phase, System.currentTimeMillis()); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/service/AiWorkflowService.java b/app/proprietary/src/main/java/stirling/software/proprietary/service/AiWorkflowService.java index 817d2e4a29..4a966f6390 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/service/AiWorkflowService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/service/AiWorkflowService.java @@ -20,6 +20,8 @@ import stirling.software.common.util.ExceptionUtils; import stirling.software.proprietary.model.api.ai.AiWorkflowFileInput; import stirling.software.proprietary.model.api.ai.AiWorkflowFileRequest; import stirling.software.proprietary.model.api.ai.AiWorkflowOutcome; +import stirling.software.proprietary.model.api.ai.AiWorkflowPhase; +import stirling.software.proprietary.model.api.ai.AiWorkflowProgressEvent; import stirling.software.proprietary.model.api.ai.AiWorkflowRequest; import stirling.software.proprietary.model.api.ai.AiWorkflowResponse; import stirling.software.proprietary.service.PdfContentExtractor.LoadedFile; @@ -38,6 +40,13 @@ public class AiWorkflowService { private final PdfContentExtractor pdfContentExtractor; private final ObjectMapper objectMapper; + @FunctionalInterface + public interface ProgressListener { + void onProgress(AiWorkflowProgressEvent event); + } + + private static final ProgressListener NOOP_LISTENER = event -> {}; + private sealed interface WorkflowState { record Pending(WorkflowTurnRequest request) implements WorkflowState {} @@ -45,6 +54,11 @@ public class AiWorkflowService { } public AiWorkflowResponse orchestrate(AiWorkflowRequest request) throws IOException { + return orchestrate(request, NOOP_LISTENER); + } + + public AiWorkflowResponse orchestrate(AiWorkflowRequest request, ProgressListener listener) + throws IOException { validateRequest(request); Map filesByName = new LinkedHashMap<>(); @@ -57,19 +71,24 @@ public class AiWorkflowService { initialRequest.setUserMessage(request.getUserMessage().trim()); initialRequest.setFileNames(new ArrayList<>(filesByName.keySet())); + listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.ANALYZING)); + WorkflowState state = new WorkflowState.Pending(initialRequest); while (state instanceof WorkflowState.Pending pending) { - state = advance(pending.request(), filesByName); + state = advance(pending.request(), filesByName, listener); } return ((WorkflowState.Terminal) state).response(); } private WorkflowState advance( - WorkflowTurnRequest request, Map filesByName) + WorkflowTurnRequest request, + Map filesByName, + ProgressListener listener) throws IOException { + listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.CALLING_ENGINE)); AiWorkflowResponse response = invokeOrchestrator(request); return switch (response.getOutcome()) { - case NEED_CONTENT -> onNeedContent(response, filesByName, request); + case NEED_CONTENT -> onNeedContent(response, filesByName, request, listener); case ANSWER, NOT_FOUND, PLAN, @@ -86,7 +105,8 @@ public class AiWorkflowService { private WorkflowState onNeedContent( AiWorkflowResponse response, Map filesByName, - WorkflowTurnRequest request) + WorkflowTurnRequest request, + ProgressListener listener) throws IOException { if (!request.getArtifacts().isEmpty()) { return new WorkflowState.Terminal( @@ -119,6 +139,8 @@ public class AiWorkflowService { Collectors.toMap( AiWorkflowFileRequest::getFileName, r -> r)); + listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.EXTRACTING_CONTENT)); + List loadedFiles = new ArrayList<>(); try { for (String fileName : fileNamesToLoad) { @@ -133,6 +155,8 @@ public class AiWorkflowService { response.getMaxPages(), response.getMaxCharacters()); + listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.PROCESSING)); + WorkflowTurnRequest nextRequest = new WorkflowTurnRequest(); nextRequest.setUserMessage(request.getUserMessage()); nextRequest.setFileNames(request.getFileNames()); diff --git a/frontend/public/locales/en-GB/translation.toml b/frontend/public/locales/en-GB/translation.toml index 94248c215e..aef0edf518 100644 --- a/frontend/public/locales/en-GB/translation.toml +++ b/frontend/public/locales/en-GB/translation.toml @@ -2079,6 +2079,13 @@ keywords = "Keywords: odd, even" numbers = "Numbers/ranges: 5, 10-20" progressions = "Progressions: 3n, 4n+1" +[chat.progress] +thinking = "Thinking..." +analyzing = "Analysing your request..." +calling_engine = "AI is thinking..." +extracting_content = "Extracting content from your documents..." +processing = "Processing extracted content..." + [certSign] allSigned = "All participants have signed. Ready to finalize." awaitingSignatures = "Awaiting signatures" diff --git a/frontend/src/core/services/apiClientSetup.ts b/frontend/src/core/services/apiClientSetup.ts index ff19a76fd2..326978717b 100644 --- a/frontend/src/core/services/apiClientSetup.ts +++ b/frontend/src/core/services/apiClientSetup.ts @@ -12,3 +12,8 @@ export function setupApiInterceptors(client: AxiosInstance): void { (error) => Promise.reject(error), ); } + +/** Auth headers for raw fetch() calls (SSE streams, etc.). Proprietary overrides with JWT + XSRF. */ +export function getAuthHeaders(): Record { + return {}; +} diff --git a/frontend/src/proprietary/services/apiClientSetup.ts b/frontend/src/proprietary/services/apiClientSetup.ts index a5f0d5cbbd..79c4c38f9c 100644 --- a/frontend/src/proprietary/services/apiClientSetup.ts +++ b/frontend/src/proprietary/services/apiClientSetup.ts @@ -95,20 +95,29 @@ async function refreshAuthToken(client: AxiosInstance): Promise { } } +/** Auth headers for raw fetch() calls (SSE streams, etc.). */ +export function getAuthHeaders(): Record { + const headers: Record = {}; + const jwt = getJwtTokenFromStorage(); + if (jwt) { + headers["Authorization"] = `Bearer ${jwt}`; + } + const xsrf = getXsrfToken(); + if (xsrf) { + headers["X-XSRF-TOKEN"] = xsrf; + } + return headers; +} + export function setupApiInterceptors(client: AxiosInstance): void { // Install request interceptor to add JWT token client.interceptors.request.use( (config) => { - const jwtToken = getJwtTokenFromStorage(); - const xsrfToken = getXsrfToken(); - - if (jwtToken && !config.headers.Authorization) { - config.headers.Authorization = `Bearer ${jwtToken}`; - console.debug("[API Client] Added JWT token from localStorage to Authorization header"); - } - - if (xsrfToken && !config.headers["X-XSRF-TOKEN"]) { - config.headers["X-XSRF-TOKEN"] = xsrfToken; + const authHeaders = getAuthHeaders(); + for (const [key, value] of Object.entries(authHeaders)) { + if (!config.headers[key]) { + config.headers[key] = value; + } } return config; diff --git a/frontend/src/prototypes/components/chat/ChatContext.tsx b/frontend/src/prototypes/components/chat/ChatContext.tsx index 00b320d5cc..54bd7fee56 100644 --- a/frontend/src/prototypes/components/chat/ChatContext.tsx +++ b/frontend/src/prototypes/components/chat/ChatContext.tsx @@ -1,5 +1,6 @@ -import { createContext, useContext, useReducer, useCallback, type ReactNode } from "react"; +import { createContext, useContext, useReducer, useCallback, useRef, type ReactNode } from "react"; import { useAllFiles } from "@app/contexts/FileContext"; +import { getAuthHeaders } from "@app/services/apiClientSetup"; export interface ChatMessage { id: string; @@ -8,6 +9,13 @@ export interface ChatMessage { timestamp: number; } +export enum AiWorkflowPhase { + ANALYZING = "analyzing", + CALLING_ENGINE = "calling_engine", + EXTRACTING_CONTENT = "extracting_content", + PROCESSING = "processing", +} + type AiWorkflowOutcome = | "answer" | "not_found" @@ -37,11 +45,13 @@ interface ChatState { messages: ChatMessage[]; isOpen: boolean; isLoading: boolean; + progressPhase: AiWorkflowPhase | null; } type ChatAction = | { type: "ADD_MESSAGE"; message: ChatMessage } | { type: "SET_LOADING"; loading: boolean } + | { type: "SET_PROGRESS"; phase: AiWorkflowPhase | null } | { type: "TOGGLE_OPEN" } | { type: "SET_OPEN"; open: boolean }; @@ -51,6 +61,8 @@ function chatReducer(state: ChatState, action: ChatAction): ChatState { return { ...state, messages: [...state.messages, action.message] }; case "SET_LOADING": return { ...state, isLoading: action.loading }; + case "SET_PROGRESS": + return { ...state, progressPhase: action.phase }; case "TOGGLE_OPEN": return { ...state, isOpen: !state.isOpen }; case "SET_OPEN": @@ -85,10 +97,67 @@ function formatWorkflowResponse(data: AiWorkflowResponse): string { } } +/** + * Parses an SSE text stream and invokes callbacks for each named event. + */ +async function consumeSSEStream( + response: Response, + handlers: { + onProgress: (data: { phase: string; timestamp: number }) => void; + onResult: (data: AiWorkflowResponse) => void; + onError: (data: { message: string }) => void; + }, +) { + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + let currentEvent = ""; + + for (;;) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + + // SSE frames are separated by double newlines + let boundary = buffer.indexOf("\n\n"); + while (boundary !== -1) { + const frame = buffer.slice(0, boundary); + buffer = buffer.slice(boundary + 2); + + let dataPayload = ""; + for (const line of frame.split("\n")) { + if (line.startsWith("event:")) { + currentEvent = line.slice(6).trim(); + } else if (line.startsWith("data:")) { + dataPayload += line.slice(5); + } + } + + if (dataPayload) { + try { + const parsed = JSON.parse(dataPayload); + if (currentEvent === "progress") { + handlers.onProgress(parsed); + } else if (currentEvent === "result") { + handlers.onResult(parsed); + } else if (currentEvent === "error") { + handlers.onError(parsed); + } + } catch { + // Skip malformed JSON frames + } + } + currentEvent = ""; + boundary = buffer.indexOf("\n\n"); + } + } +} + interface ChatContextValue { messages: ChatMessage[]; isOpen: boolean; isLoading: boolean; + progressPhase: AiWorkflowPhase | null; toggleOpen: () => void; setOpen: (open: boolean) => void; sendMessage: (content: string) => Promise; @@ -100,17 +169,24 @@ const initialState: ChatState = { messages: [], isOpen: false, isLoading: false, + progressPhase: null, }; export function ChatProvider({ children }: { children: ReactNode }) { const [state, dispatch] = useReducer(chatReducer, initialState); const { files: activeFiles } = useAllFiles(); + const abortRef = useRef(null); const toggleOpen = useCallback(() => dispatch({ type: "TOGGLE_OPEN" }), []); const setOpen = useCallback((open: boolean) => dispatch({ type: "SET_OPEN", open }), []); const sendMessage = useCallback( async (content: string) => { + // Abort any in-flight request + abortRef.current?.abort(); + const controller = new AbortController(); + abortRef.current = controller; + const userMessage: ChatMessage = { id: crypto.randomUUID(), role: "user", @@ -119,6 +195,7 @@ export function ChatProvider({ children }: { children: ReactNode }) { }; dispatch({ type: "ADD_MESSAGE", message: userMessage }); dispatch({ type: "SET_LOADING", loading: true }); + dispatch({ type: "SET_PROGRESS", phase: null }); try { const formData = new FormData(); @@ -127,34 +204,73 @@ export function ChatProvider({ children }: { children: ReactNode }) { formData.append(`fileInputs[${i}].fileInput`, file); }); - const response = await fetch("/api/v1/ai/orchestrate", { + const response = await fetch("/api/v1/ai/orchestrate/stream", { method: "POST", body: formData, + headers: getAuthHeaders(), + credentials: "include", + signal: controller.signal, }); if (!response.ok) { throw new Error(`AI engine request failed: ${response.status}`); } - const data: AiWorkflowResponse = await response.json(); - const replyContent = formatWorkflowResponse(data); - const assistantMessage: ChatMessage = { - id: crypto.randomUUID(), - role: "assistant", - content: replyContent, - timestamp: Date.now(), - }; - dispatch({ type: "ADD_MESSAGE", message: assistantMessage }); - } catch { - const errorMessage: ChatMessage = { - id: crypto.randomUUID(), - role: "assistant", - content: "Failed to get a response. The AI engine may not be available yet.", - timestamp: Date.now(), - }; - dispatch({ type: "ADD_MESSAGE", message: errorMessage }); + let receivedResult = false; + + await consumeSSEStream(response, { + onProgress: (data) => { + dispatch({ type: "SET_PROGRESS", phase: data.phase as AiWorkflowPhase }); + }, + onResult: (data) => { + receivedResult = true; + dispatch({ type: "SET_PROGRESS", phase: null }); + const replyContent = formatWorkflowResponse(data); + dispatch({ + type: "ADD_MESSAGE", + message: { + id: crypto.randomUUID(), + role: "assistant", + content: replyContent, + timestamp: Date.now(), + }, + }); + }, + onError: (data) => { + receivedResult = true; + dispatch({ type: "SET_PROGRESS", phase: null }); + dispatch({ + type: "ADD_MESSAGE", + message: { + id: crypto.randomUUID(), + role: "assistant", + content: data.message || "Something went wrong.", + timestamp: Date.now(), + }, + }); + }, + }); + + if (!receivedResult) { + throw new Error("Stream ended without a result"); + } + } catch (e) { + if ((e as Error).name === "AbortError") return; + dispatch({ type: "SET_PROGRESS", phase: null }); + dispatch({ + type: "ADD_MESSAGE", + message: { + id: crypto.randomUUID(), + role: "assistant", + content: "Failed to get a response. The AI engine may not be available yet.", + timestamp: Date.now(), + }, + }); } finally { dispatch({ type: "SET_LOADING", loading: false }); + if (abortRef.current === controller) { + abortRef.current = null; + } } }, [activeFiles], @@ -166,6 +282,7 @@ export function ChatProvider({ children }: { children: ReactNode }) { messages: state.messages, isOpen: state.isOpen, isLoading: state.isLoading, + progressPhase: state.progressPhase, toggleOpen, setOpen, sendMessage, diff --git a/frontend/src/prototypes/components/chat/ChatPanel.tsx b/frontend/src/prototypes/components/chat/ChatPanel.tsx index d94c5aabe8..0a360ad6db 100644 --- a/frontend/src/prototypes/components/chat/ChatPanel.tsx +++ b/frontend/src/prototypes/components/chat/ChatPanel.tsx @@ -1,5 +1,6 @@ import { useRef, useEffect, useState, type KeyboardEvent } from "react"; -import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition } from "@mantine/core"; +import { useTranslation } from "react-i18next"; +import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition, Loader, Group } from "@mantine/core"; import SendIcon from "@mui/icons-material/Send"; import ChatBubbleOutlineIcon from "@mui/icons-material/ChatBubbleOutline"; import CloseIcon from "@mui/icons-material/Close"; @@ -19,7 +20,8 @@ function ChatMessageBubble({ role, content }: { role: "user" | "assistant"; cont } export function ChatPanel() { - const { messages, isOpen, isLoading, toggleOpen, sendMessage } = useChat(); + const { t } = useTranslation(); + const { messages, isOpen, isLoading, progressPhase, toggleOpen, sendMessage } = useChat(); const [input, setInput] = useState(""); const scrollRef = useRef(null); const inputRef = useRef(null); @@ -95,9 +97,12 @@ export function ChatPanel() { {isLoading && (
- - Thinking... - + + + + {progressPhase ? t(`chat.progress.${progressPhase}`) : t("chat.progress.thinking")} + +
)}