diff --git a/frigate/api/chat.py b/frigate/api/chat.py index 34c318d85..939e399df 100644 --- a/frigate/api/chat.py +++ b/frigate/api/chat.py @@ -5,11 +5,11 @@ import json import logging import time from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, List, Optional import cv2 from fastapi import APIRouter, Body, Depends, Request -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel from frigate.api.auth import ( @@ -31,6 +31,24 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=[Tags.chat]) +def _chunk_content(content: str, chunk_size: int = 80) -> Generator[str, None, None]: + """Yield content in word-aware chunks for streaming.""" + if not content: + return + words = content.split(" ") + current: List[str] = [] + current_len = 0 + for w in words: + current.append(w) + current_len += len(w) + 1 + if current_len >= chunk_size: + yield " ".join(current) + " " + current = [] + current_len = 0 + if current: + yield " ".join(current) + + def _format_events_with_local_time( events_list: List[Dict[str, Any]], ) -> List[Dict[str, Any]]: @@ -387,7 +405,6 @@ async def _execute_tool_internal( @router.post( "/chat/completion", - response_model=ChatCompletionResponse, dependencies=[Depends(allow_any_authenticated())], summary="Chat completion with tool calling", description=( @@ -399,7 +416,7 @@ async def chat_completion( request: Request, body: ChatCompletionRequest = Body(...), allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), -) -> JSONResponse: +): """ Chat completion endpoint with tool calling support. 
@@ -554,11 +571,41 @@ Always be accurate with time calculations based on the current date provided.{ca logger.debug( f"Chat completion finished with final answer (iterations: {tool_iterations})" ) + final_content = response.get("content") or "" + + if body.stream: + async def stream_body() -> Any: + if tool_calls: + yield ( + json.dumps( + { + "type": "tool_calls", + "tool_calls": [ + tc.model_dump() for tc in tool_calls + ], + } + ).encode("utf-8") + + b"\n" + ) + # Stream content in word-sized chunks for smooth UX + for part in _chunk_content(final_content): + yield ( + json.dumps({"type": "content", "delta": part}) + .encode("utf-8") + + b"\n" + ) + yield json.dumps({"type": "done"}).encode("utf-8") + b"\n" + + return StreamingResponse( + stream_body(), + media_type="application/x-ndjson", + ) + return JSONResponse( content=ChatCompletionResponse( message=ChatMessageResponse( role="assistant", - content=response.get("content"), + content=final_content, tool_calls=None, ), finish_reason=response.get("finish_reason", "stop"), diff --git a/frigate/api/defs/request/chat_body.py b/frigate/api/defs/request/chat_body.py index fa3c3860a..3a67cd038 100644 --- a/frigate/api/defs/request/chat_body.py +++ b/frigate/api/defs/request/chat_body.py @@ -39,3 +39,7 @@ class ChatCompletionRequest(BaseModel): "user message as multimodal content. Use with get_live_context for detection info." 
), ) + stream: bool = Field( + default=False, + description="If true, stream the final assistant response in the body as newline-delimited JSON.", + ) diff --git a/web/src/pages/Chat.tsx b/web/src/pages/Chat.tsx index 74e396721..c562a982d 100644 --- a/web/src/pages/Chat.tsx +++ b/web/src/pages/Chat.tsx @@ -25,27 +25,123 @@ export default function ChatPage() { setMessages((prev) => [...prev, userMessage]); setIsLoading(true); - try { - const apiMessages = [...messages, userMessage].map((m) => ({ - role: m.role, - content: m.content, - })); - const { data } = await axios.post<{ - message: { role: string; content: string | null }; - tool_calls?: ToolCall[]; - }>("chat/completion", { messages: apiMessages }); + const apiMessages = [...messages, userMessage].map((m) => ({ + role: m.role, + content: m.content, + })); - const content = data.message?.content ?? ""; - setMessages((prev) => [ - ...prev, - { - role: "assistant", - content: content || " ", - toolCalls: data.tool_calls?.length ? data.tool_calls : undefined, - }, - ]); + try { + const baseURL = axios.defaults.baseURL ?? ""; + const url = `${baseURL}chat/completion`; + const headers: Record<string, string> = { + "Content-Type": "application/json", + ...(axios.defaults.headers.common as Record<string, string>), + }; + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify({ messages: apiMessages, stream: true }), + }); + + if (!res.ok) { + const errBody = await res.json().catch(() => ({})); + throw new Error( + (errBody as { error?: string }).error ?? 
res.statusText, + ); + } + + const reader = res.body?.getReader(); + const decoder = new TextDecoder(); + if (!reader) throw new Error("No response body"); + + const assistantMessage: ChatMessage = { + role: "assistant", + content: "", + toolCalls: undefined, + }; + setMessages((prev) => [...prev, assistantMessage]); + + let buffer = ""; + for (;;) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed) continue; + let data: { type: string; tool_calls?: ToolCall[]; delta?: string }; + try { + data = JSON.parse(trimmed) as { + type: string; + tool_calls?: ToolCall[]; + delta?: string; + }; + } catch { + continue; + } + if (data.type === "tool_calls" && data.tool_calls?.length) { + setMessages((prev) => { + const next = [...prev]; + const last = next[next.length - 1]; + if (last?.role === "assistant") + next[next.length - 1] = { + ...last, + toolCalls: data.tool_calls, + }; + return next; + }); + } else if (data.type === "content" && data.delta !== undefined) { + setMessages((prev) => { + const next = [...prev]; + const last = next[next.length - 1]; + if (last?.role === "assistant") + next[next.length - 1] = { + ...last, + content: last.content + data.delta, + }; + return next; + }); + } + } + } + if (buffer.trim()) { + try { + const data = JSON.parse(buffer.trim()) as { + type: string; + tool_calls?: ToolCall[]; + delta?: string; + }; + if (data.type === "content" && data.delta !== undefined) { + setMessages((prev) => { + const next = [...prev]; + const last = next[next.length - 1]; + if (last?.role === "assistant") + next[next.length - 1] = { + ...last, + content: last.content + data.delta, + }; + return next; + }); + } + } catch { + // ignore final malformed chunk + } + } + + setMessages((prev) => { + const next = [...prev]; + const last = next[next.length - 
1]; + if (last?.role === "assistant" && last.content === "") + next[next.length - 1] = { ...last, content: " " }; + return next; + }); } catch { setError(t("error")); + setMessages((prev) => + prev.filter((m) => !(m.role === "assistant" && m.content === "")), + ); } finally { setIsLoading(false); }