Files
Stirling-PDF/engine/scripts/generate_tool_models.py
2026-03-16 11:01:50 +00:00

511 lines
17 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import keyword
import re
import subprocess
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Preamble written verbatim at the top of the generated module. Besides the
# "do not edit" banner it carries a ruff directive that disables rule N815 for
# the generated file.
TOOL_MODELS_HEADER = """# AUTO-GENERATED FILE. DO NOT EDIT.
# Generated by scripts/generate_tool_models.py from frontend TypeScript sources.
# ruff: noqa: N815
"""
# Captures the quoted identifier assigned to an `operationType:` property.
OPERATION_TYPE_RE = re.compile(r"operationType\s*:\s*['\"]([A-Za-z0-9_]+)['\"]")
# Captures the identifier assigned to a `defaultParameters:` property.
DEFAULT_REF_RE = re.compile(r"defaultParameters\s*:\s*([A-Za-z0-9_]+)")
# Matches the bare word `defaultParameters` anywhere in the text.
DEFAULT_SHORTHAND_RE = re.compile(r"\bdefaultParameters\b")
# Matches `import { ... } from '<module>'`; group 1 is the text between the
# braces, group 2 is the module path.
IMPORT_RE = re.compile(r"import\s*\{([^}]+)\}\s*from\s*['\"]([^'\"]+)['\"]")
# str.format template: fill in `name` (already regex-escaped) to get a pattern
# matching `[export] const <name> ... = {`. The doubled brace at the end is the
# format escape for a literal `{`.
VAR_OBJ_RE_TEMPLATE = r"(?:export\s+)?const\s+{name}\b[^=]*=\s*\{{"
@dataclass
class ToolModelSpec:
    """Description of one tool from which a parameter model class is generated."""

    # Tool identifier; used to derive the generated class name and enum member.
    tool_id: str
    # Default values keyed by the original parameter name.
    params: dict[str, Any]
    # Type specs keyed by the original parameter name; each spec is expected to
    # be a dict carrying a "kind" entry.
    param_types: dict[str, Any]
class ParseError(Exception):
    """Raised when scanned source text has no matching closing delimiter."""
def _find_matching(text: str, start: int, open_char: str, close_char: str) -> int:
depth = 0
i = start
in_str: str | None = None
while i < len(text):
ch = text[i]
if in_str:
if ch == "\\":
i += 2
continue
if ch == in_str:
in_str = None
i += 1
continue
if ch in {"'", '"'}:
in_str = ch
elif ch == open_char:
depth += 1
elif ch == close_char:
depth -= 1
if depth == 0:
return i
i += 1
raise ParseError(f"Unmatched {open_char}{close_char} block")
def _extract_block(text: str, pattern: str) -> str | None:
match = re.search(pattern, text)
if not match:
return None
brace_start = text.find("{", match.end() - 1)
if brace_start == -1:
return None
brace_end = _find_matching(text, brace_start, "{", "}")
return text[brace_start : brace_end + 1]
def _split_top_level_items(obj_body: str) -> list[str]:
items: list[str] = []
depth_obj = depth_arr = 0
in_str: str | None = None
token_start = 0
i = 0
while i < len(obj_body):
ch = obj_body[i]
if in_str:
if ch == "\\":
i += 2
continue
if ch == in_str:
in_str = None
i += 1
continue
if ch in {"'", '"'}:
in_str = ch
elif ch == "{":
depth_obj += 1
elif ch == "}":
depth_obj -= 1
elif ch == "[":
depth_arr += 1
elif ch == "]":
depth_arr -= 1
elif ch == "," and depth_obj == 0 and depth_arr == 0:
piece = obj_body[token_start:i].strip()
if piece:
items.append(piece)
token_start = i + 1
i += 1
tail = obj_body[token_start:].strip()
if tail:
items.append(tail)
return items
def _resolve_import_path(repo_root: Path, current_file: Path, module_path: str) -> Path | None:
candidates: list[Path] = []
if module_path.startswith("@app/"):
rel = module_path[len("@app/") :]
candidates.extend(
[
repo_root / "frontend/src/core" / f"{rel}.ts",
repo_root / "frontend/src/core" / f"{rel}.tsx",
repo_root / "frontend/src/saas" / f"{rel}.ts",
repo_root / "frontend/src/saas" / f"{rel}.tsx",
repo_root / "frontend/src" / f"{rel}.ts",
repo_root / "frontend/src" / f"{rel}.tsx",
]
)
elif module_path.startswith("."):
base = (current_file.parent / module_path).resolve()
candidates.extend([Path(f"{base}.ts"), Path(f"{base}.tsx")])
for candidate in candidates:
if candidate.exists():
return candidate
return None
def _parse_literal_value(value: str, resolver: Callable[[str], dict[str, Any] | None]) -> Any:
value = value.strip()
if not value:
return None
if value.startswith("{") and value.endswith("}"):
return _parse_object_literal(value, resolver)
if value.startswith("[") and value.endswith("]"):
inner = value[1:-1].strip()
if not inner:
return []
return [_parse_literal_value(item, resolver) for item in _split_top_level_items(inner)]
if value.startswith(("'", '"')) and value.endswith(("'", '"')):
return value[1:-1]
if value in {"true", "false"}:
return value == "true"
if value == "null":
return None
if re.fullmatch(r"-?\d+", value):
return int(value)
if re.fullmatch(r"-?\d+\.\d+", value):
return float(value)
resolved = resolver(value)
if resolved is not None:
return resolved
return None
def _parse_object_literal(obj_text: str, resolver: Callable[[str], dict[str, Any] | None]) -> dict[str, Any]:
    """Parse the text of a ``{...}`` object literal into a dict.

    ``...name`` spread entries are expanded through ``resolver`` (and ignored
    when it does not return a dict). Entries without a ``:`` are skipped.
    Keys have surrounding whitespace and quote characters removed; values are
    parsed with ``_parse_literal_value``. Later entries overwrite earlier ones.

    Args:
        obj_text: Literal text including its outer braces.
        resolver: Callback used for spreads and identifier values.
    """
    inner = obj_text.strip()[1:-1]
    parsed: dict[str, Any] = {}
    for entry in _split_top_level_items(inner):
        if entry.startswith("..."):
            spread_target = resolver(entry[3:].strip())
            if isinstance(spread_target, dict):
                parsed.update(spread_target)
            continue
        name, colon, raw = entry.partition(":")
        if not colon:
            continue
        parsed[name.strip().strip("'\"")] = _parse_literal_value(raw.strip(), resolver)
    return parsed
def _extract_imports(source: str) -> dict[str, str]:
    """Map every name in the ``import { ... } from '...'`` statements to its module.

    For ``original as alias`` entries both the alias and the original name are
    registered against the module path.
    """
    table: dict[str, str] = {}
    for statement in IMPORT_RE.finditer(source):
        names, module_path = statement.group(1), statement.group(2)
        for segment in (part.strip() for part in names.split(",")):
            if not segment:
                continue
            original, separator, alias = segment.partition(" as ")
            if separator:
                table[alias.strip()] = module_path
                table[original.strip()] = module_path
            else:
                table[segment] = module_path
    return table
def _resolve_object_identifier(repo_root: Path, file_path: Path, source: str, identifier: str) -> dict[str, Any] | None:
    """Resolve an identifier to the object literal it names.

    A ``const <identifier> = {...}`` declaration in ``source`` takes
    precedence. Otherwise, if the identifier is imported, the imported file is
    read and searched recursively. Identifiers met while parsing the literal
    (spreads and values) are resolved by the same rules against this file.

    Args:
        repo_root: Repository root used to resolve ``@app/`` imports.
        file_path: Path of the file ``source`` was read from.
        source: Text of that file.
        identifier: Name to resolve.

    Returns:
        The parsed object, or ``None`` when the identifier cannot be resolved.
    """
    imported_names = _extract_imports(source)

    def resolver(name: str) -> dict[str, Any] | None:
        declaration = _extract_block(source, VAR_OBJ_RE_TEMPLATE.format(name=re.escape(name)))
        if declaration:
            return _parse_object_literal(declaration, resolver)
        module_path = imported_names.get(name)
        if not module_path:
            return None
        target_file = _resolve_import_path(repo_root, file_path, module_path)
        if not target_file:
            return None
        target_source = target_file.read_text(encoding="utf-8")
        return _resolve_object_identifier(repo_root, target_file, target_source, name)

    # The requested identifier follows exactly the same lookup as nested names.
    return resolver(identifier)
def _infer_py_type(value: Any) -> str:
if isinstance(value, bool):
return "bool"
if isinstance(value, int):
return "int"
if isinstance(value, float):
return "float"
if isinstance(value, str):
return "str"
if isinstance(value, list):
return "list[Any]"
if isinstance(value, dict):
return "dict[str, Any]"
return "Any"
def _spec_is_none(spec: dict[str, Any]) -> bool:
return spec.get("kind") == "null"
def _py_type_from_spec(spec: dict[str, Any]) -> str:
kind = spec.get("kind")
if kind == "string":
return "str"
if kind == "number":
return "float"
if kind == "boolean":
return "bool"
if kind == "date":
return "str"
if kind == "enum":
values = spec.get("values")
if isinstance(values, list) and values:
literal_values = ", ".join(_py_repr(v) for v in values)
return f"Literal[{literal_values}]"
if kind == "ref":
ref_name = spec.get("name")
if isinstance(ref_name, str) and ref_name.endswith("Parameters"):
return f"{ref_name[:-10]}Params"
if kind == "array":
element = spec.get("element")
inner = _py_type_from_spec(element) if isinstance(element, dict) else "Any"
return f"list[{inner}]"
if kind == "object":
dict_value = spec.get("dictValue")
if isinstance(dict_value, dict):
inner = _py_type_from_spec(dict_value)
return f"dict[str, {inner}]"
properties = spec.get("properties")
if isinstance(properties, dict) and properties:
property_types = {_py_type_from_spec(p) for p in properties.values() if isinstance(p, dict)}
if len(property_types) == 1:
inner = next(iter(property_types))
return f"dict[str, {inner}]"
return "dict[str, Any]"
if kind in {"null"}:
return "Any"
return "Any"
def _to_class_name(tool_id: str) -> str:
cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", tool_id)
cleaned = re.sub(r"[^A-Za-z0-9]+", " ", cleaned)
parts = [part.capitalize() for part in cleaned.split() if part]
return "".join(parts) + "Params"
def _to_snake_case(name: str) -> str:
snake = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
snake = re.sub(r"[^A-Za-z0-9]+", "_", snake).strip("_").lower()
if not snake:
snake = "param"
if snake[0].isdigit():
snake = f"param_{snake}"
if keyword.iskeyword(snake):
snake = f"{snake}_"
return snake
def _build_field_name_map(params: dict[str, Any]) -> dict[str, str]:
    """Assign a unique snake_case field name to every key of ``params``.

    Keys are processed in sorted order. When two keys convert to the same
    name, later ones receive a numeric suffix starting at ``_2``.
    """
    taken: set[str] = set()
    mapping: dict[str, str] = {}
    for key in sorted(params):
        stem = _to_snake_case(key)
        name = stem
        attempt = 1
        while name in taken:
            attempt += 1
            name = f"{stem}_{attempt}"
        taken.add(name)
        mapping[key] = name
    return mapping
def _to_enum_member_name(tool_id: str) -> str:
    """Return the upper-cased snake_case form of a tool id."""
    snake = _to_snake_case(tool_id)
    return snake.upper()
def _build_enum_member_map(specs: list[ToolModelSpec]) -> dict[str, str]:
    """Assign a unique enum member name to the tool id of every spec.

    Specs are processed in the order given. When two tool ids convert to the
    same member name, later ones receive a numeric suffix starting at ``_2``.
    """
    taken: set[str] = set()
    members: dict[str, str] = {}
    for spec in specs:
        stem = _to_enum_member_name(spec.tool_id)
        name = stem
        attempt = 1
        while name in taken:
            attempt += 1
            name = f"{stem}_{attempt}"
        taken.add(name)
        members[spec.tool_id] = name
    return members
def _py_repr(value: Any) -> str:
return (
json.dumps(value, ensure_ascii=True).replace("true", "True").replace("false", "False").replace("null", "None")
)
def discover_tool_specs(repo_root: Path) -> list[ToolModelSpec]:
    """Run the frontend extractor script and turn its JSON output into specs.

    ``node --import tsx frontend/scripts/export-tool-specs.ts`` is executed
    with ``frontend`` as the working directory and its stdout is parsed as a
    JSON list. Entries that are not objects, or that lack a non-empty string
    ``tool_id``, are skipped; ``params`` and ``param_types`` default to empty
    dicts when they are not objects.

    Args:
        repo_root: Repository root containing the ``frontend`` directory.

    Returns:
        The specs sorted by ``tool_id``.

    Raises:
        subprocess.CalledProcessError: If the extractor exits non-zero. Its
            captured stderr is forwarded to this process's stderr first.
        json.JSONDecodeError: If the extractor's stdout is not valid JSON.
        ValueError: If the decoded JSON is not a list.
    """
    import sys  # local import: only needed to forward the extractor's stderr

    frontend_dir = repo_root / "frontend"
    extractor = frontend_dir / "scripts/export-tool-specs.ts"
    command = ["node", "--import", "tsx", str(extractor)]
    try:
        result = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            cwd=str(frontend_dir),
        )
    except subprocess.CalledProcessError as err:
        # The output was captured, so without this the real failure reason is lost.
        if err.stderr:
            print(err.stderr, file=sys.stderr, end="")
        raise

    raw = json.loads(result.stdout)
    if not isinstance(raw, list):
        raise ValueError(f"Expected a JSON list of tool specs, got {type(raw).__name__}")

    specs: list[ToolModelSpec] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        tool_id = item.get("tool_id")
        if not isinstance(tool_id, str) or not tool_id:
            continue
        params = item.get("params")
        param_types = item.get("param_types")
        specs.append(
            ToolModelSpec(
                tool_id=tool_id,
                params=params if isinstance(params, dict) else {},
                param_types=param_types if isinstance(param_types, dict) else {},
            )
        )
    return sorted(specs, key=lambda spec: spec.tool_id)
def _collect_ref_dependencies(spec: ToolModelSpec, known_classes: set[str]) -> set[str]:
    """Return the generated class names that ``spec`` references through ``ref`` type specs."""
    deps: set[str] = set()
    if not isinstance(spec.param_types, dict):
        return deps
    for entry in spec.param_types.values():
        if not isinstance(entry, dict):
            continue
        # An entry may carry its actual type spec under a nested "type" key.
        nested = entry.get("type")
        type_spec = nested if isinstance(nested, dict) else entry
        if type_spec.get("kind") != "ref":
            continue
        ref_name = type_spec.get("name")
        if isinstance(ref_name, str) and ref_name.endswith("Parameters"):
            ref_class = f"{ref_name[:-10]}Params"
            if ref_class in known_classes:
                deps.add(ref_class)
    return deps


def _order_by_dependencies(dependencies_by_class: dict[str, set[str]]) -> list[str]:
    """Order class names so that each one follows the classes it depends on, where possible."""
    remaining = set(dependencies_by_class)
    emitted: set[str] = set()
    ordered: list[str] = []
    while remaining:
        # Always take the alphabetically first class whose dependencies are all emitted.
        ready = next(
            (name for name in sorted(remaining) if dependencies_by_class[name].issubset(emitted)),
            None,
        )
        if ready is None:
            # A cycle or self-reference: emit whatever is left in sorted order.
            ordered.extend(sorted(remaining))
            break
        ordered.append(ready)
        emitted.add(ready)
        remaining.remove(ready)
    return ordered


def _render_field_line(field_name: str, value: Any, type_spec: Any) -> str:
    """Render one field declaration line of a generated model class."""
    has_spec = isinstance(type_spec, dict)
    py_type = _py_type_from_spec(type_spec) if has_spec else _infer_py_type(value)
    if value is None:
        if has_spec and _spec_is_none(type_spec):
            if py_type != "Any" and "| None" not in py_type:
                py_type = f"{py_type} | None"
            return f" {field_name}: {py_type} = None\n"
        return f" {field_name}: {py_type} | None = None\n"
    default = _py_repr(value)
    is_ref_object = has_spec and type_spec.get("kind") == "ref" and isinstance(value, dict)
    # A ref that could not be mapped to a class is typed "Any"; typing.Any has
    # no model_validate, so such a field keeps its plain dict default.
    if is_ref_object and py_type != "Any":
        default = f"{py_type}.model_validate({default})"
    return f" {field_name}: {py_type} = {default}\n"


def write_models_module(out_path: Path, specs: list[ToolModelSpec]) -> None:
    """Generate the tool models module and write it to ``out_path``.

    The generated text consists of, in order:

    * the auto-generated header and the module's imports;
    * one ``ApiModel`` subclass per distinct class name, ordered so that a
      class follows the classes its ``ref`` parameters point to (classes in a
      dependency cycle are appended in sorted order);
    * the ``ParamToolModel`` / ``ParamToolModelType`` type aliases;
    * the ``OperationId`` string enum with one member per spec (``pass`` when
      there are no specs, so the class body is never empty);
    * the ``OPERATIONS`` mapping from enum member to model class.

    Args:
        out_path: Destination file; it is overwritten using UTF-8.
        specs: Tool specs to generate models for.
    """
    lines: list[str] = [
        TOOL_MODELS_HEADER,
        "from __future__ import annotations\n\n",
        "from enum import StrEnum\n",
        "from typing import Any, Literal\n\n",
        "from models.base import ApiModel\n",
    ]
    class_names: dict[str, str] = {spec.tool_id: _to_class_name(spec.tool_id) for spec in specs}
    known_classes = set(class_names.values())

    dependencies_by_class: dict[str, set[str]] = {}
    spec_by_class: dict[str, ToolModelSpec] = {}
    for spec in specs:
        class_name = class_names[spec.tool_id]
        dependencies_by_class[class_name] = _collect_ref_dependencies(spec, known_classes)
        # When several tool ids share a class name, the first spec defines the class.
        spec_by_class.setdefault(class_name, spec)

    for class_name in _order_by_dependencies(dependencies_by_class):
        spec = spec_by_class[class_name]
        lines.append(f"class {class_name}(ApiModel):\n")
        has_type_map = isinstance(spec.param_types, dict)
        all_param_keys = set(spec.params)
        if has_type_map:
            all_param_keys.update(spec.param_types)
        if not all_param_keys:
            lines.append(" pass\n\n\n")
            continue
        field_name_map = _build_field_name_map(dict.fromkeys(all_param_keys, True))
        for key in sorted(all_param_keys):
            type_spec = spec.param_types.get(key) if has_type_map else None
            lines.append(_render_field_line(field_name_map[key], spec.params.get(key), type_spec))
        lines.append("\n\n")

    if class_names:
        union_members = " | ".join(class_names[tool_id] for tool_id in sorted(class_names))
        lines.append(f"type ParamToolModel = {union_members}\n")
    else:
        lines.append("type ParamToolModel = ApiModel\n")
    lines.append("type ParamToolModelType = type[ParamToolModel]\n\n")

    enum_member_map = _build_enum_member_map(specs)
    lines.append("class OperationId(StrEnum):\n")
    if not specs:
        # A class statement without a body is a syntax error in the generated file.
        lines.append(" pass\n")
    for spec in specs:
        lines.append(f" {enum_member_map[spec.tool_id]} = {spec.tool_id!r}\n")

    lines.extend(
        [
            "\n\n",
            "OPERATIONS: dict[OperationId, ParamToolModelType] = {\n",
        ]
    )
    for spec in specs:
        lines.append(f" OperationId.{enum_member_map[spec.tool_id]}: {class_names[spec.tool_id]},\n")
    lines.append("}\n")
    out_path.write_text("".join(lines), encoding="utf-8")
def main() -> None:
    """Command-line entry point: discover the tool specs and write the models module.

    ``--output`` selects the destination file; when omitted, the module is
    written to ``docgen/backend/models/tool_models.py`` under the repository
    root. ``--spec`` and ``--ai-output`` are accepted but ignored.
    """
    cli = argparse.ArgumentParser(description="Generate tool models from frontend TypeScript tool definitions")
    cli.add_argument("--spec", help="Deprecated (ignored)", default="")
    cli.add_argument("--output", default="", help="Path to tool_models.py")
    cli.add_argument("--ai-output", default="", help="Deprecated (ignored)")
    options = cli.parse_args()

    # NOTE(review): parents[3] assumes this script sits exactly three
    # directories below the repository root - confirm against its location.
    root = Path(__file__).resolve().parents[3]
    tool_specs = discover_tool_specs(root)
    if options.output:
        destination = Path(options.output)
    else:
        destination = root / "docgen/backend/models/tool_models.py"
    write_models_module(destination, tool_specs)
    print(f"Wrote {len(tool_specs)} tool model specs")


if __name__ == "__main__":
    main()