Files
Stirling-PDF/engine/scripts/generate_tool_models.py
James Brunton 3e94157137 Add document context for edit agent (#6152)
# Description of Changes
Adds the ability for the Edit agent to request the content of the
document before it decides which parameters it needs. This makes it able
to process requests like `Split the document after the page containing
the "My Section" section`, allowing for document context-based requests
for all[^1] tools.

I had to make a few changes elsewhere to make this work, including:
- Moved the request for document content out of the Question Agent and into a
common location
- Added specific API docs for the Split param because the generic ones
were not specific enough for the AI to be able to reliably perform the
correct operation
- Fixed an issue in the tool models generator which caused the Redact
params to only be half-generated (causing Pydantic to crash when the AI
tried to run Redact)
- Added missing logging to a bunch of tools and hooked it up properly so
it'll print to stderr
- Made the limits for the max pages/chars to extract from PDFs
configurable via env var

[^1]: Many of the tools can't actually do anything useful with the
context at this stage, but will just need the tool API to be extended
with new features like page-specific operations to be automatically able
to do smart operations without needing to change the Edit agent itself.
2026-04-23 13:19:27 +00:00

263 lines
9.4 KiB
Python

#!/usr/bin/env python3
"""Generate Python tool models from the Java backend's OpenAPI spec (SwaggerDoc.json).
Uses datamodel-code-generator to convert OpenAPI request schemas to Pydantic models.
Run via:
task engine:tool-models
"""
from __future__ import annotations
import argparse
import json
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from datamodel_code_generator import InputFileType, PythonVersion, generate
from datamodel_code_generator.enums import DataModelType
from datamodel_code_generator.format import Formatter
from referencing import Registry, Resource
from referencing.jsonschema import DRAFT202012
# Fields inherited from PDFFile base class — not tool parameters.
BASE_CLASS_FIELDS = frozenset({"fileInput", "fileId"})

# Two directories up from this script — presumably the engine/ package root (TODO confirm layout).
_ENGINE_ROOT = Path(__file__).resolve().parents[1]

# Header prepended to the generated module. E501 is suppressed because generated
# lines (long unions, enum values) routinely exceed the line-length limit.
_FILE_HEADER = (
    "# AUTO-GENERATED FILE. DO NOT EDIT.\n"
    "# Generated by scripts/generate_tool_models.py from Java OpenAPI spec (SwaggerDoc.json).\n"
    "# ruff: noqa: E501"
)
@dataclass
class ToolSpec:
    """A single discovered tool endpoint and the generated names derived from its path."""

    # Endpoint path, e.g. "/api/v1/misc/compress-pdf".
    path: str
    # UPPER_SNAKE name used as the ToolEndpoint enum member.
    enum_name: str
    # PascalCase name of the generated Pydantic params model, e.g. "CompressPdfParams".
    class_name: str
@dataclass
class DiscoveryResult:
    """Everything ToolDiscovery.discover() produces in one pass."""

    # One ToolSpec per accepted endpoint, in sorted-path order.
    tools: list[ToolSpec]
    # Combined JSON Schema: per-tool param schemas under $defs, unioned via anyOf.
    combined_schema: dict[str, Any]
class ToolDiscovery:
    """Discovers tool endpoints from an OpenAPI spec and builds a combined JSON Schema."""

    # Namespaces exposed to the LLM as callable tools. Largely matches ``InternalApiClient.java``.
    # ``/api/v1/filter/`` is deliberately absent: those APIs serve pipeline processing,
    # not tool execution.
    ALLOWED_PATH_PREFIXES = (
        "/api/v1/general/",
        "/api/v1/misc/",
        "/api/v1/security/",
        "/api/v1/convert/",
    )

    def __init__(self, spec: dict[str, Any]):
        self.spec = spec
        root = Resource.from_contents(spec, default_specification=DRAFT202012)
        self.resolver = Registry().with_resource("", root).resolver()

    def discover(self) -> DiscoveryResult:
        """Scan every POST endpoint under the allowed prefixes and collect their
        request-param schemas into one combined JSON Schema document.
        """
        tools: list[ToolSpec] = []
        defs: dict[str, Any] = {}
        taken_enum_names: set[str] = set()
        taken_class_names: set[str] = set()
        for endpoint, path_item in sorted(self.spec.get("paths", {}).items()):
            # Skip parameterized paths and anything outside the tool namespaces.
            if "{" in endpoint or not endpoint.startswith(self.ALLOWED_PATH_PREFIXES):
                continue
            raw_props = self._get_request_properties(path_item)
            if not raw_props:
                continue
            params = self._filter_properties(raw_props)
            if not params:
                continue
            enum_name = _deduplicate(_path_to_enum_name(endpoint), taken_enum_names)
            class_name = _deduplicate(_path_to_class_name(endpoint), taken_class_names)
            defs[class_name] = {"type": "object", "properties": params}
            tools.append(ToolSpec(endpoint, enum_name, class_name))
        self._inline_component_refs(defs)
        combined: dict[str, Any] = {
            "$defs": defs,
            "anyOf": [{"$ref": f"#/$defs/{tool.class_name}"} for tool in tools],
        }
        return DiscoveryResult(tools=tools, combined_schema=combined)

    def _inline_component_refs(self, defs: dict[str, Any]) -> None:
        """Copy every component schema transitively referenced by the tool param
        schemas into ``defs``, rewriting refs from ``#/components/schemas/X`` to
        ``#/$defs/X`` along the way.

        Without this step, nested refs (e.g. ``list[RedactionArea]``) are
        unresolvable once the combined schema is handed to
        datamodel-code-generator, which then emits ``RootModel[Any]`` shells
        that the downstream JSON-schema strict-mode transformers reject.
        """
        components = self.spec.get("components", {}).get("schemas", {})
        pending: list[object] = list(defs.values())
        while pending:
            for ref_name in _rewrite_refs(pending.pop()):
                # Only pull in components we haven't seen and that actually exist.
                if ref_name in defs or ref_name not in components:
                    continue
                defs[ref_name] = components[ref_name]
                pending.append(components[ref_name])

    def _resolve_ref(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Follow a top-level ``$ref`` if present; otherwise return the schema unchanged."""
        if "$ref" not in schema:
            return schema
        return self.resolver.lookup(schema["$ref"]).contents

    def _get_request_properties(self, path_item: dict[str, Any]) -> dict[str, Any] | None:
        """Return the request-body ``properties`` of the path's POST operation, or None."""
        operation = path_item.get("post")
        if not operation:
            return None
        body_content = operation.get("requestBody", {}).get("content", {})
        # Prefer multipart (file uploads) over plain JSON when both are offered.
        for media_type in ("multipart/form-data", "application/json"):
            if media_type not in body_content:
                continue
            schema = body_content[media_type].get("schema")
            if schema:
                return self._resolve_ref(schema).get("properties")
        return None

    def _filter_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
        """Drop base-class fields and binary upload fields, resolving any $refs."""
        kept: dict[str, Any] = {}
        for field_name, raw_prop in properties.items():
            if field_name in BASE_CLASS_FIELDS:
                continue
            resolved = self._resolve_ref(raw_prop)
            if resolved.get("type") == "string" and resolved.get("format") == "binary":
                continue
            kept[field_name] = resolved
        return kept
_COMPONENT_REF_PREFIX = "#/components/schemas/"
def _rewrite_refs(obj: object) -> Iterable[str]:
"""Rewrite ``#/components/schemas/X`` refs to ``#/$defs/X`` in place, yielding each
component name encountered so the caller can pull referenced schemas into ``$defs``.
"""
if isinstance(obj, dict):
ref = obj.get("$ref")
if isinstance(ref, str) and ref.startswith(_COMPONENT_REF_PREFIX):
name = ref.removeprefix(_COMPONENT_REF_PREFIX)
obj["$ref"] = "#/$defs/" + name
yield name
for value in obj.values():
yield from _rewrite_refs(value)
elif isinstance(obj, list):
for value in obj:
yield from _rewrite_refs(value)
def _tool_name_segments(path: str) -> str:
"""Extract a descriptive name from the endpoint path.
Converters use two segments (e.g. /api/v1/convert/cbr/pdf → cbr-to-pdf).
Other tools use the last segment (e.g. /api/v1/misc/compress-pdf → compress-pdf).
"""
parts = path.rstrip("/").split("/")
if "/api/v1/convert/" in path and len(parts) >= 6:
return f"{parts[-2]}-to-{parts[-1]}"
return parts[-1]
def _path_to_enum_name(path: str) -> str:
    """Turn an endpoint path into an UPPER_SNAKE enum member name."""
    name = _tool_name_segments(path)
    return name.upper().replace("-", "_")
def _path_to_class_name(path: str) -> str:
    """Turn an endpoint path into the PascalCase params-model class name."""
    pieces = _tool_name_segments(path).split("-")
    capitalized = [piece.capitalize() for piece in pieces]
    return "".join(capitalized) + "Params"
def _deduplicate(name: str, used: set[str]) -> str:
"""Return name, appending 2, 3, ... if already in used. Adds result to used."""
candidate = name
n = 2
while candidate in used:
candidate = f"{name}{n}"
n += 1
used.add(candidate)
return candidate
def generate_models_code(combined_schema: dict[str, Any]) -> str:
    """Invoke datamodel-code-generator once over the combined schema and return
    the generated Pydantic model source as a string.
    """
    serialized = json.dumps(combined_schema, sort_keys=True)
    code = generate(
        input_=serialized,
        input_file_type=InputFileType.JsonSchema,
        output_model_type=DataModelType.PydanticV2BaseModel,
        target_python_version=PythonVersion.PY_313,
        snake_case_field=True,
        base_class="stirling.models.base.ApiModel",
        field_constraints=True,
        no_alias=True,
        set_default_enum_member=True,
        additional_imports=["enum.StrEnum"],
        enable_version_header=False,
        custom_file_header=_FILE_HEADER,
        formatters=[Formatter.RUFF_FORMAT, Formatter.RUFF_CHECK],
        settings_path=_ENGINE_ROOT / "pyproject.toml",
    )
    return str(code) if code else ""
def write_output(out_path: Path, tools: list[ToolSpec], models_code: str) -> None:
    """Write the final module: the generated model code followed by the
    ParamToolModel union, the ToolEndpoint enum, and the OPERATIONS dispatch map.

    NOTE(review): the single-space indents inside the emitted string literals
    look collapsed (generated Python is conventionally 4-space indented) —
    confirm against the repository original before relying on exact output.
    """
    # PEP 695 ``type`` alias unioning every generated params model.
    union_lines = ["type ParamToolModel = ("]
    for i, tool in enumerate(tools):
        prefix = " | " if i > 0 else " "
        union_lines.append(f"{prefix}{tool.class_name}")
    union_lines.append(")")
    union_lines.append("type ParamToolModelType = type[ParamToolModel]")
    # StrEnum mapping tool names to their endpoint paths.
    enum_lines = [
        "class ToolEndpoint(StrEnum):",
        *(f' {t.enum_name} = "{t.path}"' for t in tools),
    ]
    # Endpoint → params-model dispatch table.
    ops_lines = [
        "OPERATIONS: dict[ToolEndpoint, ParamToolModelType] = {",
        *(f" ToolEndpoint.{t.enum_name}: {t.class_name}," for t in tools),
        "}",
    ]
    parts = [models_code, "\n", *union_lines, "\n", *enum_lines, "\n", *ops_lines, ""]
    out_path.write_text("\n".join(parts), encoding="utf-8")
def main() -> None:
    """CLI entry point: parse arguments, load the OpenAPI spec, generate the
    tool models, and write them to the output path.

    Raises SystemExit with a helpful message when the spec file is missing.
    """
    parser = argparse.ArgumentParser(description="Generate Python tool models from Java OpenAPI spec")
    parser.add_argument("--spec", required=True, help="Path to SwaggerDoc.json")
    parser.add_argument("--output", required=True, help="Path to output tool_models.py")
    args = parser.parse_args()

    spec_path = Path(args.spec)
    if not spec_path.exists():
        raise SystemExit(f"OpenAPI spec not found at {spec_path}\nRun 'task engine:tool-models' to generate it.")
    output_path = Path(args.output)

    # The spec is emitted by the Java backend; decode as UTF-8 regardless of locale.
    with open(spec_path, encoding="utf-8") as f:
        spec = json.load(f)

    result = ToolDiscovery(spec).discover()
    models_code = generate_models_code(result.combined_schema)
    write_output(output_path, result.tools, models_code)

    print(f"Generated {len(result.tools)} tool models from {spec_path.name}")
    for tool in result.tools:
        # Fix: path and class name were jammed together with no separator,
        # making the per-tool summary line unreadable.
        print(f"  {tool.enum_name}: {tool.path} -> {tool.class_name}")
if __name__ == "__main__":
main()