mirror of
https://github.com/Frooodle/Stirling-PDF.git
synced 2026-05-10 23:10:08 +02:00
# Description of Changes Adds the ability for the Edit agent to request the content of the document before it decides which parameters it needs. This makes it able to process requests like `Split the document after the page containing the "My Section" section`, allowing for document context-based requests for all[^1] tools. I had to make a few changes elsewhere to make this work, including: - Moving the requesting of content out of the Question Agent and into a common location - Added specific API docs for the Split param because the generic ones were not specific enough for the AI to be able to reliably perform the correct operation - Fixed an issue in the tool models generator which caused the Redact params to only be half-generated (causing Pydantic to crash when the AI tried to run Redact) - Added missing logging to a bunch of tools and hooked it up properly so it'll print to stderr - Made the limits for the max pages/chars to extract from PDFs configurable via env var [^1]: Many of the tools can't actually do anything useful with the context at this stage, but will just need the tool API to be extended with new features like page-specific operations to be automatically able to do smart operations without needing to change the Edit agent itself.
263 lines
9.4 KiB
Python
263 lines
9.4 KiB
Python
#!/usr/bin/env python3
|
|
"""Generate Python tool models from the Java backend's OpenAPI spec (SwaggerDoc.json).
|
|
|
|
Uses datamodel-code-generator to convert OpenAPI request schemas to Pydantic models.
|
|
Run via:
|
|
task engine:tool-models
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from collections.abc import Iterable
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from datamodel_code_generator import InputFileType, PythonVersion, generate
|
|
from datamodel_code_generator.enums import DataModelType
|
|
from datamodel_code_generator.format import Formatter
|
|
from referencing import Registry, Resource
|
|
from referencing.jsonschema import DRAFT202012
|
|
|
|
# Fields inherited from PDFFile base class — not tool parameters.
BASE_CLASS_FIELDS = frozenset({"fileInput", "fileId"})

# Engine package root; this script lives in <root>/scripts/ (see _FILE_HEADER),
# so parents[1] walks up from scripts/ to the root holding pyproject.toml.
_ENGINE_ROOT = Path(__file__).resolve().parents[1]

# Header written at the top of the generated tool_models.py so readers know it
# is machine-produced (E501 silenced because generated lines may run long).
_FILE_HEADER = (
    "# AUTO-GENERATED FILE. DO NOT EDIT.\n"
    "# Generated by scripts/generate_tool_models.py from Java OpenAPI spec (SwaggerDoc.json).\n"
    "# ruff: noqa: E501"
)
|
|
|
|
|
|
@dataclass
class ToolSpec:
    """One discovered tool endpoint plus the generated identifiers derived from it."""

    # POST endpoint path, e.g. "/api/v1/misc/compress-pdf".
    path: str
    # ToolEndpoint enum member name derived from the path, e.g. "COMPRESS_PDF".
    enum_name: str
    # Pydantic params model class name derived from the path, e.g. "CompressPdfParams".
    class_name: str
|
|
|
|
|
|
@dataclass
class DiscoveryResult:
    """Output of ToolDiscovery.discover()."""

    # All tool endpoints found in the spec, in sorted path order.
    tools: list[ToolSpec]
    # Single JSON Schema ($defs + anyOf of refs) covering every tool's params model.
    combined_schema: dict[str, Any]
|
|
|
|
|
|
class ToolDiscovery:
    """Discovers tool endpoints from an OpenAPI spec and builds a combined JSON Schema."""

    # Namespaces exposed to the LLM as callable tools. Largely matches ``InternalApiClient.java``.
    # Note: ``/api/v1/filter/`` is intentionally excluded because those APIs are for pipeline processing,
    # not tool execution.
    ALLOWED_PATH_PREFIXES = (
        "/api/v1/general/",
        "/api/v1/misc/",
        "/api/v1/security/",
        "/api/v1/convert/",
    )

    def __init__(self, spec: dict[str, Any]):
        """Wrap *spec* (parsed SwaggerDoc.json) with a $ref resolver for lookups."""
        # The referencing Registry/Resolver lets _resolve_ref follow "$ref"
        # pointers within the same document (registered under the "" URI).
        resource = Resource.from_contents(spec, default_specification=DRAFT202012)
        self.resolver = Registry().with_resource("", resource).resolver()
        self.spec = spec

    def discover(self) -> DiscoveryResult:
        """Scan every path in the spec and return the tool list plus a combined schema.

        Only non-templated POST endpoints under ALLOWED_PATH_PREFIXES that still
        have parameters after filtering (base-class and binary fields removed)
        become tools.
        """
        tools: list[ToolSpec] = []
        defs: dict[str, Any] = {}
        used_enum: set[str] = set()
        used_class: set[str] = set()

        # Sorted for deterministic generation; paths with "{" are templated
        # (path parameters) and cannot be called as simple tools.
        for path, path_item in sorted(self.spec.get("paths", {}).items()):
            if "{" in path or not any(path.startswith(p) for p in self.ALLOWED_PATH_PREFIXES):
                continue
            properties = self._get_request_properties(path_item)
            if not properties:
                continue
            clean_props = self._filter_properties(properties)
            if not clean_props:
                continue

            # Distinct paths can collapse to the same derived name; suffix with 2, 3, ...
            enum_name = _deduplicate(_path_to_enum_name(path), used_enum)
            class_name = _deduplicate(_path_to_class_name(path), used_class)

            defs[class_name] = {"type": "object", "properties": clean_props}
            tools.append(ToolSpec(path, enum_name, class_name))

        self._inline_component_refs(defs)

        # One schema for a single datamodel-code-generator run: every params
        # model lives in $defs and is reachable via the anyOf member refs.
        combined_schema: dict[str, Any] = {
            "$defs": defs,
            "anyOf": [{"$ref": f"#/$defs/{t.class_name}"} for t in tools],
        }
        return DiscoveryResult(tools=tools, combined_schema=combined_schema)

    def _inline_component_refs(self, defs: dict[str, Any]) -> None:
        """Pull every component transitively referenced from tool param schemas into ``defs``
        and rewrite the refs from ``#/components/schemas/X`` to ``#/$defs/X``.

        Without this, nested refs (e.g. ``list[RedactionArea]``) are unresolvable when the
        combined schema is handed to datamodel-code-generator, producing ``RootModel[Any]``
        shells that downstream JSON-schema strict-mode transformers reject.
        """
        schemas = self.spec.get("components", {}).get("schemas", {})
        # Worklist over schema fragments: each newly inlined component is queued
        # so refs inside it get rewritten too (transitive closure).
        queue: list[object] = list(defs.values())
        while queue:
            for name in _rewrite_refs(queue.pop()):
                # "name not in defs" prevents re-queueing (and infinite loops on cycles).
                if name not in defs and name in schemas:
                    defs[name] = schemas[name]
                    queue.append(schemas[name])

    def _resolve_ref(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Follow a single "$ref" indirection; plain schemas pass through unchanged."""
        if "$ref" in schema:
            return self.resolver.lookup(schema["$ref"]).contents
        return schema

    def _get_request_properties(self, path_item: dict[str, Any]) -> dict[str, Any] | None:
        """Return the request-body property map of the path's POST operation, or None.

        multipart/form-data is preferred over application/json when both exist.
        Returns None when there is no POST, no usable content schema, or the
        resolved schema has no "properties" key.
        """
        post = path_item.get("post")
        if not post:
            return None
        content = post.get("requestBody", {}).get("content", {})
        for media_type in ("multipart/form-data", "application/json"):
            if media_type in content:
                schema = content[media_type].get("schema")
                if schema:
                    return self._resolve_ref(schema).get("properties")
        return None

    def _filter_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
        """Remove base-class fields and binary upload fields, resolving any $refs."""
        clean: dict[str, Any] = {}
        for name, prop in properties.items():
            if name in BASE_CLASS_FIELDS:
                continue
            prop = self._resolve_ref(prop)
            # type=string + format=binary marks a file-upload field, which is
            # supplied by the runtime rather than the model.
            if prop.get("type") == "string" and prop.get("format") == "binary":
                continue
            clean[name] = prop
        return clean
|
|
|
|
|
|
_COMPONENT_REF_PREFIX = "#/components/schemas/"
|
|
|
|
|
|
def _rewrite_refs(obj: object) -> Iterable[str]:
|
|
"""Rewrite ``#/components/schemas/X`` refs to ``#/$defs/X`` in place, yielding each
|
|
component name encountered so the caller can pull referenced schemas into ``$defs``.
|
|
"""
|
|
if isinstance(obj, dict):
|
|
ref = obj.get("$ref")
|
|
if isinstance(ref, str) and ref.startswith(_COMPONENT_REF_PREFIX):
|
|
name = ref.removeprefix(_COMPONENT_REF_PREFIX)
|
|
obj["$ref"] = "#/$defs/" + name
|
|
yield name
|
|
for value in obj.values():
|
|
yield from _rewrite_refs(value)
|
|
elif isinstance(obj, list):
|
|
for value in obj:
|
|
yield from _rewrite_refs(value)
|
|
|
|
|
|
def _tool_name_segments(path: str) -> str:
|
|
"""Extract a descriptive name from the endpoint path.
|
|
|
|
Converters use two segments (e.g. /api/v1/convert/cbr/pdf → cbr-to-pdf).
|
|
Other tools use the last segment (e.g. /api/v1/misc/compress-pdf → compress-pdf).
|
|
"""
|
|
parts = path.rstrip("/").split("/")
|
|
if "/api/v1/convert/" in path and len(parts) >= 6:
|
|
return f"{parts[-2]}-to-{parts[-1]}"
|
|
return parts[-1]
|
|
|
|
|
|
def _path_to_enum_name(path: str) -> str:
    """Derive the ToolEndpoint enum member name from *path* (e.g. COMPRESS_PDF)."""
    name = _tool_name_segments(path)
    return name.upper().replace("-", "_")
|
|
|
|
|
|
def _path_to_class_name(path: str) -> str:
    """Derive the Pydantic params class name from *path* (e.g. CompressPdfParams)."""
    words = _tool_name_segments(path).split("-")
    camel = "".join(word.capitalize() for word in words)
    return camel + "Params"
|
|
|
|
|
|
def _deduplicate(name: str, used: set[str]) -> str:
|
|
"""Return name, appending 2, 3, ... if already in used. Adds result to used."""
|
|
candidate = name
|
|
n = 2
|
|
while candidate in used:
|
|
candidate = f"{name}{n}"
|
|
n += 1
|
|
used.add(candidate)
|
|
return candidate
|
|
|
|
|
|
def generate_models_code(combined_schema: dict[str, Any]) -> str:
    """Run datamodel-code-generator once on the combined schema.

    Returns the generated Pydantic model source as a string (empty string when
    the generator returns nothing).
    """
    # sort_keys makes the generator input — and therefore the output — deterministic.
    code = generate(
        input_=json.dumps(combined_schema, sort_keys=True),
        input_file_type=InputFileType.JsonSchema,
        output_model_type=DataModelType.PydanticV2BaseModel,
        target_python_version=PythonVersion.PY_313,
        snake_case_field=True,  # Java-style camelCase fields → snake_case attributes
        base_class="stirling.models.base.ApiModel",  # project base instead of BaseModel
        field_constraints=True,
        no_alias=True,
        set_default_enum_member=True,
        additional_imports=["enum.StrEnum"],  # needed by the ToolEndpoint enum appended later
        enable_version_header=False,
        custom_file_header=_FILE_HEADER,
        formatters=[Formatter.RUFF_FORMAT, Formatter.RUFF_CHECK],
        settings_path=_ENGINE_ROOT / "pyproject.toml",
    )
    # Normalise a possible None result to "" before stringifying.
    return str(code or "")
|
|
|
|
|
|
def write_output(out_path: Path, tools: list[ToolSpec], models_code: str) -> None:
    """Write the final module: generated models followed by the ParamToolModel
    type-alias union, the ToolEndpoint enum, and the OPERATIONS mapping.
    """
    # Union of all params models, one member per line.
    union_lines = ["type ParamToolModel = ("]
    for index, tool in enumerate(tools):
        lead = " " if index == 0 else " | "
        union_lines.append(f"{lead}{tool.class_name}")
    union_lines.append(")")
    union_lines.append("type ParamToolModelType = type[ParamToolModel]")

    # StrEnum of endpoint paths keyed by the derived member names.
    enum_lines = ["class ToolEndpoint(StrEnum):"]
    enum_lines.extend(f' {tool.enum_name} = "{tool.path}"' for tool in tools)

    # Endpoint → params-model lookup table.
    ops_lines = ["OPERATIONS: dict[ToolEndpoint, ParamToolModelType] = {"]
    ops_lines.extend(f" ToolEndpoint.{tool.enum_name}: {tool.class_name}," for tool in tools)
    ops_lines.append("}")

    # Blank-line separators between sections; trailing "" yields a final newline.
    sections = [models_code, "\n", *union_lines, "\n", *enum_lines, "\n", *ops_lines, ""]
    out_path.write_text("\n".join(sections), encoding="utf-8")
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments, read the spec, and generate tool_models.py.

    Raises SystemExit with a helpful message when the spec file is missing.
    """
    parser = argparse.ArgumentParser(description="Generate Python tool models from Java OpenAPI spec")
    parser.add_argument("--spec", required=True, help="Path to SwaggerDoc.json")
    parser.add_argument("--output", required=True, help="Path to output tool_models.py")
    args = parser.parse_args()

    spec_path = Path(args.spec)
    if not spec_path.exists():
        raise SystemExit(f"OpenAPI spec not found at {spec_path}\nRun 'task engine:tool-models' to generate it.")
    output_path = Path(args.output)

    # Fix: read as UTF-8 explicitly — JSON is UTF-8 (RFC 8259), and the platform
    # default text encoding (e.g. cp1252 on Windows) can corrupt non-ASCII content.
    with open(spec_path, encoding="utf-8") as f:
        spec = json.load(f)

    result = ToolDiscovery(spec).discover()
    models_code = generate_models_code(result.combined_schema)
    write_output(output_path, result.tools, models_code)

    print(f"Generated {len(result.tools)} tool models from {spec_path.name}")
    for tool in result.tools:
        print(f" {tool.enum_name}: {tool.path} → {tool.class_name}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|