mirror of
https://github.com/Frooodle/Stirling-PDF.git
synced 2026-04-22 23:08:53 +02:00
Add streaming to Engine orchestrator (#6094)
# Description of Changes Adds a streaming endpoint to the Java AI orchestrator (`/api/v1/ai/orchestrate/stream` in addition to the existing `/api/v1/ai/orchestrate`). This allows the caller to get updates of what stage of orchestration is being run at the time so UIs can give the user feedback. Also contains some dubious Gradle changes to suppress errors coming from Spotless, when it crashes in Google stuff. I'm not sure if that's appropriate to add, feel free to ask for changes in review.
This commit is contained in:
@@ -7,6 +7,8 @@ spotless {
|
||||
target 'src/**/java/**/*.java'
|
||||
targetExclude 'src/main/java/org/apache/**'
|
||||
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
|
||||
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
|
||||
suppressLintsFor { setStep('google-java-format') }
|
||||
|
||||
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
|
||||
trimTrailingWhitespace()
|
||||
|
||||
@@ -14,6 +14,8 @@ spotless {
|
||||
target 'src/**/java/**/*.java'
|
||||
targetExclude 'src/main/resources/static/**', 'src/main/java/org/apache/**'
|
||||
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
|
||||
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
|
||||
suppressLintsFor { setStep('google-java-format') }
|
||||
|
||||
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
|
||||
trimTrailingWhitespace()
|
||||
|
||||
@@ -16,6 +16,8 @@ spotless {
|
||||
target 'src/**/java/**/*.java'
|
||||
targetExclude 'src/main/java/org/apache/**'
|
||||
googleJavaFormat(googleJavaFormatVersion).aosp().reorderImports(false)
|
||||
// google-java-format 1.28.0 bundles Guava 32.x which crashes Spotless lint on JDK 24/25
|
||||
suppressLintsFor { setStep('google-java-format') }
|
||||
|
||||
importOrder("java", "javax", "org", "com", "net", "io", "jakarta", "lombok", "me", "stirling")
|
||||
trimTrailingWhitespace()
|
||||
|
||||
@@ -48,4 +48,12 @@ public class AsyncConfig {
|
||||
adapter.setTaskDecorator(new MDCContextTaskDecorator());
|
||||
return adapter;
|
||||
}
|
||||
|
||||
@Bean(name = "aiStreamExecutor")
|
||||
public Executor aiStreamExecutor() {
|
||||
TaskExecutorAdapter adapter =
|
||||
new TaskExecutorAdapter(Executors.newVirtualThreadPerTaskExecutor());
|
||||
adapter.setTaskDecorator(new MDCContextTaskDecorator());
|
||||
return adapter;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package stirling.software.proprietary.controller.api;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Executor;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -12,6 +15,7 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.server.ResponseStatusException;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import io.swagger.v3.oas.annotations.Hidden;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
@@ -19,7 +23,6 @@ import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowRequest;
|
||||
@@ -34,7 +37,6 @@ import tools.jackson.databind.ObjectMapper;
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("/api/v1/ai")
|
||||
@RequiredArgsConstructor
|
||||
@Hidden
|
||||
@Tag(name = "AI Engine", description = "Endpoints for AI-powered PDF workflows")
|
||||
public class AiEngineController {
|
||||
@@ -42,6 +44,18 @@ public class AiEngineController {
|
||||
private final AiEngineClient aiEngineClient;
|
||||
private final AiWorkflowService aiWorkflowService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final Executor aiStreamExecutor;
|
||||
|
||||
public AiEngineController(
|
||||
AiEngineClient aiEngineClient,
|
||||
AiWorkflowService aiWorkflowService,
|
||||
ObjectMapper objectMapper,
|
||||
@Qualifier("aiStreamExecutor") Executor aiStreamExecutor) {
|
||||
this.aiEngineClient = aiEngineClient;
|
||||
this.aiWorkflowService = aiWorkflowService;
|
||||
this.objectMapper = objectMapper;
|
||||
this.aiStreamExecutor = aiStreamExecutor;
|
||||
}
|
||||
|
||||
@GetMapping("/health")
|
||||
@Operation(
|
||||
@@ -62,6 +76,49 @@ public class AiEngineController {
|
||||
return ResponseEntity.ok(aiWorkflowService.orchestrate(request));
|
||||
}
|
||||
|
||||
@PostMapping(value = "/orchestrate/stream", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
|
||||
@Operation(
|
||||
summary = "Run an AI workflow with streaming progress",
|
||||
description =
|
||||
"Accepts a PDF upload and a user message, returns SSE events with progress"
|
||||
+ " updates followed by the final AI workflow result")
|
||||
public SseEmitter orchestrateStream(@Valid @ModelAttribute AiWorkflowRequest request) {
|
||||
SseEmitter emitter = new SseEmitter(180_000L);
|
||||
|
||||
emitter.onTimeout(
|
||||
() -> {
|
||||
log.warn("SSE emitter timed out for AI orchestration stream");
|
||||
emitter.complete();
|
||||
});
|
||||
emitter.onError(e -> log.warn("SSE emitter error for AI orchestration stream", e));
|
||||
|
||||
aiStreamExecutor.execute(() -> runOrchestrationStream(request, emitter));
|
||||
|
||||
return emitter;
|
||||
}
|
||||
|
||||
private void runOrchestrationStream(AiWorkflowRequest request, SseEmitter emitter) {
|
||||
try {
|
||||
AiWorkflowResponse result =
|
||||
aiWorkflowService.orchestrate(
|
||||
request, progress -> sendEvent(emitter, "progress", progress));
|
||||
sendEvent(emitter, "result", result);
|
||||
emitter.complete();
|
||||
} catch (Exception e) {
|
||||
log.error("AI orchestration stream failed", e);
|
||||
sendEvent(emitter, "error", Map.of("message", e.getMessage()));
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendEvent(SseEmitter emitter, String name, Object data) {
|
||||
try {
|
||||
emitter.send(SseEmitter.event().name(name).data(data, MediaType.APPLICATION_JSON));
|
||||
} catch (IOException e) {
|
||||
log.debug("Failed to send SSE event (client may have disconnected)", e);
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping(value = "/pdf/edit", consumes = MediaType.APPLICATION_JSON_VALUE)
|
||||
@Operation(
|
||||
summary = "Generate a PDF edit plan",
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package stirling.software.proprietary.model.api.ai;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
|
||||
/** Progress phases emitted during AI workflow orchestration. */
|
||||
public enum AiWorkflowPhase {
|
||||
ANALYZING("analyzing"),
|
||||
CALLING_ENGINE("calling_engine"),
|
||||
EXTRACTING_CONTENT("extracting_content"),
|
||||
PROCESSING("processing");
|
||||
|
||||
private final String value;
|
||||
|
||||
AiWorkflowPhase(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@JsonValue
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@JsonCreator
|
||||
public static AiWorkflowPhase fromValue(String value) {
|
||||
for (AiWorkflowPhase phase : values()) {
|
||||
if (phase.value.equals(value)) {
|
||||
return phase;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown AI workflow phase: " + value);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package stirling.software.proprietary.model.api.ai;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class AiWorkflowProgressEvent {
|
||||
private AiWorkflowPhase phase;
|
||||
private long timestamp;
|
||||
|
||||
public static AiWorkflowProgressEvent of(AiWorkflowPhase phase) {
|
||||
return new AiWorkflowProgressEvent(phase, System.currentTimeMillis());
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,8 @@ import stirling.software.common.util.ExceptionUtils;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowFileInput;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowFileRequest;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowOutcome;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowPhase;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowProgressEvent;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowRequest;
|
||||
import stirling.software.proprietary.model.api.ai.AiWorkflowResponse;
|
||||
import stirling.software.proprietary.service.PdfContentExtractor.LoadedFile;
|
||||
@@ -38,6 +40,13 @@ public class AiWorkflowService {
|
||||
private final PdfContentExtractor pdfContentExtractor;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface ProgressListener {
|
||||
void onProgress(AiWorkflowProgressEvent event);
|
||||
}
|
||||
|
||||
private static final ProgressListener NOOP_LISTENER = event -> {};
|
||||
|
||||
private sealed interface WorkflowState {
|
||||
record Pending(WorkflowTurnRequest request) implements WorkflowState {}
|
||||
|
||||
@@ -45,6 +54,11 @@ public class AiWorkflowService {
|
||||
}
|
||||
|
||||
public AiWorkflowResponse orchestrate(AiWorkflowRequest request) throws IOException {
|
||||
return orchestrate(request, NOOP_LISTENER);
|
||||
}
|
||||
|
||||
public AiWorkflowResponse orchestrate(AiWorkflowRequest request, ProgressListener listener)
|
||||
throws IOException {
|
||||
validateRequest(request);
|
||||
|
||||
Map<String, MultipartFile> filesByName = new LinkedHashMap<>();
|
||||
@@ -57,19 +71,24 @@ public class AiWorkflowService {
|
||||
initialRequest.setUserMessage(request.getUserMessage().trim());
|
||||
initialRequest.setFileNames(new ArrayList<>(filesByName.keySet()));
|
||||
|
||||
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.ANALYZING));
|
||||
|
||||
WorkflowState state = new WorkflowState.Pending(initialRequest);
|
||||
while (state instanceof WorkflowState.Pending pending) {
|
||||
state = advance(pending.request(), filesByName);
|
||||
state = advance(pending.request(), filesByName, listener);
|
||||
}
|
||||
return ((WorkflowState.Terminal) state).response();
|
||||
}
|
||||
|
||||
private WorkflowState advance(
|
||||
WorkflowTurnRequest request, Map<String, MultipartFile> filesByName)
|
||||
WorkflowTurnRequest request,
|
||||
Map<String, MultipartFile> filesByName,
|
||||
ProgressListener listener)
|
||||
throws IOException {
|
||||
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.CALLING_ENGINE));
|
||||
AiWorkflowResponse response = invokeOrchestrator(request);
|
||||
return switch (response.getOutcome()) {
|
||||
case NEED_CONTENT -> onNeedContent(response, filesByName, request);
|
||||
case NEED_CONTENT -> onNeedContent(response, filesByName, request, listener);
|
||||
case ANSWER,
|
||||
NOT_FOUND,
|
||||
PLAN,
|
||||
@@ -86,7 +105,8 @@ public class AiWorkflowService {
|
||||
private WorkflowState onNeedContent(
|
||||
AiWorkflowResponse response,
|
||||
Map<String, MultipartFile> filesByName,
|
||||
WorkflowTurnRequest request)
|
||||
WorkflowTurnRequest request,
|
||||
ProgressListener listener)
|
||||
throws IOException {
|
||||
if (!request.getArtifacts().isEmpty()) {
|
||||
return new WorkflowState.Terminal(
|
||||
@@ -119,6 +139,8 @@ public class AiWorkflowService {
|
||||
Collectors.toMap(
|
||||
AiWorkflowFileRequest::getFileName, r -> r));
|
||||
|
||||
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.EXTRACTING_CONTENT));
|
||||
|
||||
List<LoadedFile> loadedFiles = new ArrayList<>();
|
||||
try {
|
||||
for (String fileName : fileNamesToLoad) {
|
||||
@@ -133,6 +155,8 @@ public class AiWorkflowService {
|
||||
response.getMaxPages(),
|
||||
response.getMaxCharacters());
|
||||
|
||||
listener.onProgress(AiWorkflowProgressEvent.of(AiWorkflowPhase.PROCESSING));
|
||||
|
||||
WorkflowTurnRequest nextRequest = new WorkflowTurnRequest();
|
||||
nextRequest.setUserMessage(request.getUserMessage());
|
||||
nextRequest.setFileNames(request.getFileNames());
|
||||
|
||||
@@ -2079,6 +2079,13 @@ keywords = "Keywords: odd, even"
|
||||
numbers = "Numbers/ranges: 5, 10-20"
|
||||
progressions = "Progressions: 3n, 4n+1"
|
||||
|
||||
[chat.progress]
|
||||
thinking = "Thinking..."
|
||||
analyzing = "Analysing your request..."
|
||||
calling_engine = "AI is thinking..."
|
||||
extracting_content = "Extracting content from your documents..."
|
||||
processing = "Processing extracted content..."
|
||||
|
||||
[certSign]
|
||||
allSigned = "All participants have signed. Ready to finalize."
|
||||
awaitingSignatures = "Awaiting signatures"
|
||||
|
||||
@@ -12,3 +12,8 @@ export function setupApiInterceptors(client: AxiosInstance): void {
|
||||
(error) => Promise.reject(error),
|
||||
);
|
||||
}
|
||||
|
||||
/** Auth headers for raw fetch() calls (SSE streams, etc.). Proprietary overrides with JWT + XSRF. */
|
||||
export function getAuthHeaders(): Record<string, string> {
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -95,20 +95,29 @@ async function refreshAuthToken(client: AxiosInstance): Promise<string> {
|
||||
}
|
||||
}
|
||||
|
||||
/** Auth headers for raw fetch() calls (SSE streams, etc.). */
|
||||
export function getAuthHeaders(): Record<string, string> {
|
||||
const headers: Record<string, string> = {};
|
||||
const jwt = getJwtTokenFromStorage();
|
||||
if (jwt) {
|
||||
headers["Authorization"] = `Bearer ${jwt}`;
|
||||
}
|
||||
const xsrf = getXsrfToken();
|
||||
if (xsrf) {
|
||||
headers["X-XSRF-TOKEN"] = xsrf;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
export function setupApiInterceptors(client: AxiosInstance): void {
|
||||
// Install request interceptor to add JWT token
|
||||
client.interceptors.request.use(
|
||||
(config) => {
|
||||
const jwtToken = getJwtTokenFromStorage();
|
||||
const xsrfToken = getXsrfToken();
|
||||
|
||||
if (jwtToken && !config.headers.Authorization) {
|
||||
config.headers.Authorization = `Bearer ${jwtToken}`;
|
||||
console.debug("[API Client] Added JWT token from localStorage to Authorization header");
|
||||
}
|
||||
|
||||
if (xsrfToken && !config.headers["X-XSRF-TOKEN"]) {
|
||||
config.headers["X-XSRF-TOKEN"] = xsrfToken;
|
||||
const authHeaders = getAuthHeaders();
|
||||
for (const [key, value] of Object.entries(authHeaders)) {
|
||||
if (!config.headers[key]) {
|
||||
config.headers[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return config;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { createContext, useContext, useReducer, useCallback, type ReactNode } from "react";
|
||||
import { createContext, useContext, useReducer, useCallback, useRef, type ReactNode } from "react";
|
||||
import { useAllFiles } from "@app/contexts/FileContext";
|
||||
import { getAuthHeaders } from "@app/services/apiClientSetup";
|
||||
|
||||
export interface ChatMessage {
|
||||
id: string;
|
||||
@@ -8,6 +9,13 @@ export interface ChatMessage {
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export enum AiWorkflowPhase {
|
||||
ANALYZING = "analyzing",
|
||||
CALLING_ENGINE = "calling_engine",
|
||||
EXTRACTING_CONTENT = "extracting_content",
|
||||
PROCESSING = "processing",
|
||||
}
|
||||
|
||||
type AiWorkflowOutcome =
|
||||
| "answer"
|
||||
| "not_found"
|
||||
@@ -37,11 +45,13 @@ interface ChatState {
|
||||
messages: ChatMessage[];
|
||||
isOpen: boolean;
|
||||
isLoading: boolean;
|
||||
progressPhase: AiWorkflowPhase | null;
|
||||
}
|
||||
|
||||
type ChatAction =
|
||||
| { type: "ADD_MESSAGE"; message: ChatMessage }
|
||||
| { type: "SET_LOADING"; loading: boolean }
|
||||
| { type: "SET_PROGRESS"; phase: AiWorkflowPhase | null }
|
||||
| { type: "TOGGLE_OPEN" }
|
||||
| { type: "SET_OPEN"; open: boolean };
|
||||
|
||||
@@ -51,6 +61,8 @@ function chatReducer(state: ChatState, action: ChatAction): ChatState {
|
||||
return { ...state, messages: [...state.messages, action.message] };
|
||||
case "SET_LOADING":
|
||||
return { ...state, isLoading: action.loading };
|
||||
case "SET_PROGRESS":
|
||||
return { ...state, progressPhase: action.phase };
|
||||
case "TOGGLE_OPEN":
|
||||
return { ...state, isOpen: !state.isOpen };
|
||||
case "SET_OPEN":
|
||||
@@ -85,10 +97,67 @@ function formatWorkflowResponse(data: AiWorkflowResponse): string {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses an SSE text stream and invokes callbacks for each named event.
|
||||
*/
|
||||
async function consumeSSEStream(
|
||||
response: Response,
|
||||
handlers: {
|
||||
onProgress: (data: { phase: string; timestamp: number }) => void;
|
||||
onResult: (data: AiWorkflowResponse) => void;
|
||||
onError: (data: { message: string }) => void;
|
||||
},
|
||||
) {
|
||||
const reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let currentEvent = "";
|
||||
|
||||
for (;;) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// SSE frames are separated by double newlines
|
||||
let boundary = buffer.indexOf("\n\n");
|
||||
while (boundary !== -1) {
|
||||
const frame = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
|
||||
let dataPayload = "";
|
||||
for (const line of frame.split("\n")) {
|
||||
if (line.startsWith("event:")) {
|
||||
currentEvent = line.slice(6).trim();
|
||||
} else if (line.startsWith("data:")) {
|
||||
dataPayload += line.slice(5);
|
||||
}
|
||||
}
|
||||
|
||||
if (dataPayload) {
|
||||
try {
|
||||
const parsed = JSON.parse(dataPayload);
|
||||
if (currentEvent === "progress") {
|
||||
handlers.onProgress(parsed);
|
||||
} else if (currentEvent === "result") {
|
||||
handlers.onResult(parsed);
|
||||
} else if (currentEvent === "error") {
|
||||
handlers.onError(parsed);
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON frames
|
||||
}
|
||||
}
|
||||
currentEvent = "";
|
||||
boundary = buffer.indexOf("\n\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface ChatContextValue {
|
||||
messages: ChatMessage[];
|
||||
isOpen: boolean;
|
||||
isLoading: boolean;
|
||||
progressPhase: AiWorkflowPhase | null;
|
||||
toggleOpen: () => void;
|
||||
setOpen: (open: boolean) => void;
|
||||
sendMessage: (content: string) => Promise<void>;
|
||||
@@ -100,17 +169,24 @@ const initialState: ChatState = {
|
||||
messages: [],
|
||||
isOpen: false,
|
||||
isLoading: false,
|
||||
progressPhase: null,
|
||||
};
|
||||
|
||||
export function ChatProvider({ children }: { children: ReactNode }) {
|
||||
const [state, dispatch] = useReducer(chatReducer, initialState);
|
||||
const { files: activeFiles } = useAllFiles();
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
||||
const toggleOpen = useCallback(() => dispatch({ type: "TOGGLE_OPEN" }), []);
|
||||
const setOpen = useCallback((open: boolean) => dispatch({ type: "SET_OPEN", open }), []);
|
||||
|
||||
const sendMessage = useCallback(
|
||||
async (content: string) => {
|
||||
// Abort any in-flight request
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
|
||||
const userMessage: ChatMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: "user",
|
||||
@@ -119,6 +195,7 @@ export function ChatProvider({ children }: { children: ReactNode }) {
|
||||
};
|
||||
dispatch({ type: "ADD_MESSAGE", message: userMessage });
|
||||
dispatch({ type: "SET_LOADING", loading: true });
|
||||
dispatch({ type: "SET_PROGRESS", phase: null });
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
@@ -127,34 +204,73 @@ export function ChatProvider({ children }: { children: ReactNode }) {
|
||||
formData.append(`fileInputs[${i}].fileInput`, file);
|
||||
});
|
||||
|
||||
const response = await fetch("/api/v1/ai/orchestrate", {
|
||||
const response = await fetch("/api/v1/ai/orchestrate/stream", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
headers: getAuthHeaders(),
|
||||
credentials: "include",
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`AI engine request failed: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: AiWorkflowResponse = await response.json();
|
||||
const replyContent = formatWorkflowResponse(data);
|
||||
const assistantMessage: ChatMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: replyContent,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
dispatch({ type: "ADD_MESSAGE", message: assistantMessage });
|
||||
} catch {
|
||||
const errorMessage: ChatMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: "Failed to get a response. The AI engine may not be available yet.",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
dispatch({ type: "ADD_MESSAGE", message: errorMessage });
|
||||
let receivedResult = false;
|
||||
|
||||
await consumeSSEStream(response, {
|
||||
onProgress: (data) => {
|
||||
dispatch({ type: "SET_PROGRESS", phase: data.phase as AiWorkflowPhase });
|
||||
},
|
||||
onResult: (data) => {
|
||||
receivedResult = true;
|
||||
dispatch({ type: "SET_PROGRESS", phase: null });
|
||||
const replyContent = formatWorkflowResponse(data);
|
||||
dispatch({
|
||||
type: "ADD_MESSAGE",
|
||||
message: {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: replyContent,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
});
|
||||
},
|
||||
onError: (data) => {
|
||||
receivedResult = true;
|
||||
dispatch({ type: "SET_PROGRESS", phase: null });
|
||||
dispatch({
|
||||
type: "ADD_MESSAGE",
|
||||
message: {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: data.message || "Something went wrong.",
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
if (!receivedResult) {
|
||||
throw new Error("Stream ended without a result");
|
||||
}
|
||||
} catch (e) {
|
||||
if ((e as Error).name === "AbortError") return;
|
||||
dispatch({ type: "SET_PROGRESS", phase: null });
|
||||
dispatch({
|
||||
type: "ADD_MESSAGE",
|
||||
message: {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: "Failed to get a response. The AI engine may not be available yet.",
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
dispatch({ type: "SET_LOADING", loading: false });
|
||||
if (abortRef.current === controller) {
|
||||
abortRef.current = null;
|
||||
}
|
||||
}
|
||||
},
|
||||
[activeFiles],
|
||||
@@ -166,6 +282,7 @@ export function ChatProvider({ children }: { children: ReactNode }) {
|
||||
messages: state.messages,
|
||||
isOpen: state.isOpen,
|
||||
isLoading: state.isLoading,
|
||||
progressPhase: state.progressPhase,
|
||||
toggleOpen,
|
||||
setOpen,
|
||||
sendMessage,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useRef, useEffect, useState, type KeyboardEvent } from "react";
|
||||
import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition } from "@mantine/core";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { ActionIcon, ScrollArea, TextInput, Stack, Text, Paper, Box, Transition, Loader, Group } from "@mantine/core";
|
||||
import SendIcon from "@mui/icons-material/Send";
|
||||
import ChatBubbleOutlineIcon from "@mui/icons-material/ChatBubbleOutline";
|
||||
import CloseIcon from "@mui/icons-material/Close";
|
||||
@@ -19,7 +20,8 @@ function ChatMessageBubble({ role, content }: { role: "user" | "assistant"; cont
|
||||
}
|
||||
|
||||
export function ChatPanel() {
|
||||
const { messages, isOpen, isLoading, toggleOpen, sendMessage } = useChat();
|
||||
const { t } = useTranslation();
|
||||
const { messages, isOpen, isLoading, progressPhase, toggleOpen, sendMessage } = useChat();
|
||||
const [input, setInput] = useState("");
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
@@ -95,9 +97,12 @@ export function ChatPanel() {
|
||||
{isLoading && (
|
||||
<div className="chat-message chat-message-assistant">
|
||||
<Paper className="chat-bubble chat-bubble-assistant" p="xs" radius="md">
|
||||
<Text size="sm" c="dimmed">
|
||||
Thinking...
|
||||
</Text>
|
||||
<Group gap="xs" wrap="nowrap">
|
||||
<Loader size="xs" type="dots" />
|
||||
<Text size="sm" c="dimmed">
|
||||
{progressPhase ? t(`chat.progress.${progressPhase}`) : t("chat.progress.thinking")}
|
||||
</Text>
|
||||
</Group>
|
||||
</Paper>
|
||||
</div>
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user