Full streaming support

This commit is contained in:
Nicolas Mowen
2026-02-16 17:40:11 -07:00
parent e09e9a0b7a
commit 858367c98a
6 changed files with 591 additions and 301 deletions

View File

@@ -25,6 +25,7 @@ from frigate.api.defs.response.chat_response import (
)
from frigate.api.defs.tags import Tags
from frigate.api.event import events
from frigate.genai.utils import build_assistant_message_for_conversation
logger = logging.getLogger(__name__)
@@ -403,6 +404,78 @@ async def _execute_tool_internal(
return {"error": f"Unknown tool: {tool_name}"}
async def _execute_pending_tools(
pending_tool_calls: List[Dict[str, Any]],
request: Request,
allowed_cameras: List[str],
) -> tuple[List[ToolCall], List[Dict[str, Any]]]:
"""
Execute a list of tool calls; return (ToolCall list for API response, tool result dicts for conversation).
"""
tool_calls_out: List[ToolCall] = []
tool_results: List[Dict[str, Any]] = []
for tool_call in pending_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call.get("arguments") or {}
tool_call_id = tool_call["id"]
logger.debug(
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
)
try:
tool_result = await _execute_tool_internal(
tool_name, tool_args, request, allowed_cameras
)
if tool_name == "search_objects" and isinstance(tool_result, list):
tool_result = _format_events_with_local_time(tool_result)
_keys = {
"id",
"camera",
"label",
"zones",
"start_time_local",
"end_time_local",
"sub_label",
"event_count",
}
tool_result = [
{k: evt[k] for k in _keys if k in evt}
for evt in tool_result
if isinstance(evt, dict)
]
result_content = (
json.dumps(tool_result)
if isinstance(tool_result, (dict, list))
else (tool_result if isinstance(tool_result, str) else str(tool_result))
)
tool_calls_out.append(
ToolCall(name=tool_name, arguments=tool_args, response=result_content)
)
tool_results.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": result_content,
}
)
except Exception as e:
logger.error(
f"Error executing tool {tool_name} (id: {tool_call_id}): {e}",
exc_info=True,
)
error_content = json.dumps({"error": f"Tool execution failed: {str(e)}"})
tool_calls_out.append(
ToolCall(name=tool_name, arguments=tool_args, response=error_content)
)
tool_results.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": error_content,
}
)
return (tool_calls_out, tool_results)
@router.post(
"/chat/completion",
dependencies=[Depends(allow_any_authenticated())],
@@ -527,6 +600,81 @@ Always be accurate with time calculations based on the current date provided.{ca
f"{len(tools)} tool(s) available, max_iterations={max_iterations}"
)
# True LLM streaming when client supports it and stream requested
if body.stream and hasattr(genai_client, "chat_with_tools_stream"):
stream_tool_calls: List[ToolCall] = []
stream_iterations = 0
async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations
while stream_iterations < max_iterations:
logger.debug(
f"Streaming LLM (iteration {stream_iterations + 1}/{max_iterations}) "
f"with {len(conversation)} message(s)"
)
async for event in genai_client.chat_with_tools_stream(
messages=conversation,
tools=tools if tools else None,
tool_choice="auto",
):
kind, value = event
if kind == "content_delta":
yield (
json.dumps({"type": "content", "delta": value}).encode(
"utf-8"
)
+ b"\n"
)
elif kind == "message":
msg = value
if msg.get("finish_reason") == "error":
yield (
json.dumps(
{
"type": "error",
"error": "An error occurred while processing your request.",
}
).encode("utf-8")
+ b"\n"
)
return
pending = msg.get("tool_calls")
if pending:
stream_iterations += 1
conversation.append(
build_assistant_message_for_conversation(
msg.get("content"), pending
)
)
executed_calls, tool_results = await _execute_pending_tools(
pending, request, allowed_cameras
)
stream_tool_calls.extend(executed_calls)
conversation.extend(tool_results)
yield (
json.dumps(
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in stream_tool_calls
],
}
).encode("utf-8")
+ b"\n"
)
break
else:
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return
else:
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
return StreamingResponse(
stream_body_llm(),
media_type="application/x-ndjson",
headers={"X-Accel-Buffering": "no"},
)
try:
while tool_iterations < max_iterations:
logger.debug(
@@ -548,23 +696,11 @@ Always be accurate with time calculations based on the current date provided.{ca
status_code=500,
)
assistant_message = {
"role": "assistant",
"content": response.get("content"),
}
if response.get("tool_calls"):
assistant_message["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": json.dumps(tc["arguments"]),
},
}
for tc in response["tool_calls"]
]
conversation.append(assistant_message)
conversation.append(
build_assistant_message_for_conversation(
response.get("content"), response.get("tool_calls")
)
)
pending_tool_calls = response.get("tool_calls")
if not pending_tool_calls:
@@ -574,6 +710,7 @@ Always be accurate with time calculations based on the current date provided.{ca
final_content = response.get("content") or ""
if body.stream:
async def stream_body() -> Any:
if tool_calls:
yield (
@@ -590,8 +727,9 @@ Always be accurate with time calculations based on the current date provided.{ca
# Stream content in word-sized chunks for smooth UX
for part in _chunk_content(final_content):
yield (
json.dumps({"type": "content", "delta": part})
.encode("utf-8")
json.dumps({"type": "content", "delta": part}).encode(
"utf-8"
)
+ b"\n"
)
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
@@ -614,123 +752,15 @@ Always be accurate with time calculations based on the current date provided.{ca
).model_dump(),
)
# Execute tools
tool_iterations += 1
logger.debug(
f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): "
f"{len(pending_tool_calls)} tool(s) to execute"
)
tool_results = []
for tool_call in pending_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["arguments"]
tool_call_id = tool_call["id"]
logger.debug(
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
)
try:
tool_result = await _execute_tool_internal(
tool_name, tool_args, request, allowed_cameras
)
# Add local time fields to search_objects results so the LLM doesn't hallucinate timestamps
if tool_name == "search_objects" and isinstance(tool_result, list):
tool_result = _format_events_with_local_time(tool_result)
_keys = {
"id",
"camera",
"label",
"zones",
"start_time_local",
"end_time_local",
"sub_label",
"event_count",
}
tool_result = [
{k: evt[k] for k in _keys if k in evt}
for evt in tool_result
if isinstance(evt, dict)
]
if isinstance(tool_result, dict):
result_content = json.dumps(tool_result)
result_summary = tool_result
if isinstance(tool_result, dict) and isinstance(
tool_result.get("content"), list
):
result_count = len(tool_result.get("content", []))
result_summary = {
"count": result_count,
"sample": tool_result.get("content", [])[:2]
if result_count > 0
else [],
}
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result: {json.dumps(result_summary, indent=2)}"
)
elif isinstance(tool_result, list):
result_content = json.dumps(tool_result)
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result: {len(tool_result)} item(s)"
)
elif isinstance(tool_result, str):
result_content = tool_result
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result length: {len(result_content)} characters"
)
else:
result_content = str(tool_result)
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result type: {type(tool_result).__name__}"
)
tool_calls.append(
ToolCall(
name=tool_name,
arguments=tool_args or {},
response=result_content,
)
)
tool_results.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": result_content,
}
)
except Exception as e:
logger.error(
f"Error executing tool {tool_name} (id: {tool_call_id}): {e}",
exc_info=True,
)
error_content = json.dumps(
{"error": f"Tool execution failed: {str(e)}"}
)
tool_calls.append(
ToolCall(
name=tool_name,
arguments=tool_args or {},
response=error_content,
)
)
tool_results.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": error_content,
}
)
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation."
)
executed_calls, tool_results = await _execute_pending_tools(
pending_tool_calls, request, allowed_cameras
)
tool_calls.extend(executed_calls)
conversation.extend(tool_results)
logger.debug(
f"Added {len(tool_results)} tool result(s) to conversation. "