mirror of
https://github.com/Frooodle/Stirling-PDF.git
synced 2026-03-28 02:31:17 +01:00
Add SaaS AI engine (#5907)
This commit is contained in:
510
engine/scripts/generate_tool_models.py
Normal file
510
engine/scripts/generate_tool_models.py
Normal file
@@ -0,0 +1,510 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import keyword
|
||||
import re
|
||||
import subprocess
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
TOOL_MODELS_HEADER = """# AUTO-GENERATED FILE. DO NOT EDIT.
|
||||
# Generated by scripts/generate_tool_models.py from frontend TypeScript sources.
|
||||
# ruff: noqa: N815
|
||||
"""
|
||||
|
||||
|
||||
OPERATION_TYPE_RE = re.compile(r"operationType\s*:\s*['\"]([A-Za-z0-9_]+)['\"]")
|
||||
DEFAULT_REF_RE = re.compile(r"defaultParameters\s*:\s*([A-Za-z0-9_]+)")
|
||||
DEFAULT_SHORTHAND_RE = re.compile(r"\bdefaultParameters\b")
|
||||
IMPORT_RE = re.compile(r"import\s*\{([^}]+)\}\s*from\s*['\"]([^'\"]+)['\"]")
|
||||
VAR_OBJ_RE_TEMPLATE = r"(?:export\s+)?const\s+{name}\b[^=]*=\s*\{{"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolModelSpec:
|
||||
tool_id: str
|
||||
params: dict[str, Any]
|
||||
param_types: dict[str, Any]
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _find_matching(text: str, start: int, open_char: str, close_char: str) -> int:
|
||||
depth = 0
|
||||
i = start
|
||||
in_str: str | None = None
|
||||
while i < len(text):
|
||||
ch = text[i]
|
||||
if in_str:
|
||||
if ch == "\\":
|
||||
i += 2
|
||||
continue
|
||||
if ch == in_str:
|
||||
in_str = None
|
||||
i += 1
|
||||
continue
|
||||
if ch in {"'", '"'}:
|
||||
in_str = ch
|
||||
elif ch == open_char:
|
||||
depth += 1
|
||||
elif ch == close_char:
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return i
|
||||
i += 1
|
||||
raise ParseError(f"Unmatched {open_char}{close_char} block")
|
||||
|
||||
|
||||
def _extract_block(text: str, pattern: str) -> str | None:
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
return None
|
||||
brace_start = text.find("{", match.end() - 1)
|
||||
if brace_start == -1:
|
||||
return None
|
||||
brace_end = _find_matching(text, brace_start, "{", "}")
|
||||
return text[brace_start : brace_end + 1]
|
||||
|
||||
|
||||
def _split_top_level_items(obj_body: str) -> list[str]:
|
||||
items: list[str] = []
|
||||
depth_obj = depth_arr = 0
|
||||
in_str: str | None = None
|
||||
token_start = 0
|
||||
i = 0
|
||||
while i < len(obj_body):
|
||||
ch = obj_body[i]
|
||||
if in_str:
|
||||
if ch == "\\":
|
||||
i += 2
|
||||
continue
|
||||
if ch == in_str:
|
||||
in_str = None
|
||||
i += 1
|
||||
continue
|
||||
if ch in {"'", '"'}:
|
||||
in_str = ch
|
||||
elif ch == "{":
|
||||
depth_obj += 1
|
||||
elif ch == "}":
|
||||
depth_obj -= 1
|
||||
elif ch == "[":
|
||||
depth_arr += 1
|
||||
elif ch == "]":
|
||||
depth_arr -= 1
|
||||
elif ch == "," and depth_obj == 0 and depth_arr == 0:
|
||||
piece = obj_body[token_start:i].strip()
|
||||
if piece:
|
||||
items.append(piece)
|
||||
token_start = i + 1
|
||||
i += 1
|
||||
tail = obj_body[token_start:].strip()
|
||||
if tail:
|
||||
items.append(tail)
|
||||
return items
|
||||
|
||||
|
||||
def _resolve_import_path(repo_root: Path, current_file: Path, module_path: str) -> Path | None:
|
||||
candidates: list[Path] = []
|
||||
if module_path.startswith("@app/"):
|
||||
rel = module_path[len("@app/") :]
|
||||
candidates.extend(
|
||||
[
|
||||
repo_root / "frontend/src/core" / f"{rel}.ts",
|
||||
repo_root / "frontend/src/core" / f"{rel}.tsx",
|
||||
repo_root / "frontend/src/saas" / f"{rel}.ts",
|
||||
repo_root / "frontend/src/saas" / f"{rel}.tsx",
|
||||
repo_root / "frontend/src" / f"{rel}.ts",
|
||||
repo_root / "frontend/src" / f"{rel}.tsx",
|
||||
]
|
||||
)
|
||||
elif module_path.startswith("."):
|
||||
base = (current_file.parent / module_path).resolve()
|
||||
candidates.extend([Path(f"{base}.ts"), Path(f"{base}.tsx")])
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _parse_literal_value(value: str, resolver: Callable[[str], dict[str, Any] | None]) -> Any:
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
if value.startswith("{") and value.endswith("}"):
|
||||
return _parse_object_literal(value, resolver)
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
inner = value[1:-1].strip()
|
||||
if not inner:
|
||||
return []
|
||||
return [_parse_literal_value(item, resolver) for item in _split_top_level_items(inner)]
|
||||
if value.startswith(("'", '"')) and value.endswith(("'", '"')):
|
||||
return value[1:-1]
|
||||
if value in {"true", "false"}:
|
||||
return value == "true"
|
||||
if value == "null":
|
||||
return None
|
||||
if re.fullmatch(r"-?\d+", value):
|
||||
return int(value)
|
||||
if re.fullmatch(r"-?\d+\.\d+", value):
|
||||
return float(value)
|
||||
resolved = resolver(value)
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
return None
|
||||
|
||||
|
||||
def _parse_object_literal(obj_text: str, resolver: Callable[[str], dict[str, Any] | None]) -> dict[str, Any]:
|
||||
body = obj_text.strip()[1:-1]
|
||||
result: dict[str, Any] = {}
|
||||
for item in _split_top_level_items(body):
|
||||
if item.startswith("..."):
|
||||
spread_name = item[3:].strip()
|
||||
spread = resolver(spread_name)
|
||||
if isinstance(spread, dict):
|
||||
result.update(spread)
|
||||
continue
|
||||
if ":" not in item:
|
||||
continue
|
||||
key, raw_value = item.split(":", 1)
|
||||
key = key.strip().strip("'\"")
|
||||
result[key] = _parse_literal_value(raw_value.strip(), resolver)
|
||||
return result
|
||||
|
||||
|
||||
def _extract_imports(source: str) -> dict[str, str]:
|
||||
imports: dict[str, str] = {}
|
||||
for names, module_path in IMPORT_RE.findall(source):
|
||||
for part in names.split(","):
|
||||
segment = part.strip()
|
||||
if not segment:
|
||||
continue
|
||||
if " as " in segment:
|
||||
original, alias = [x.strip() for x in segment.split(" as ", 1)]
|
||||
imports[alias] = module_path
|
||||
imports[original] = module_path
|
||||
else:
|
||||
imports[segment] = module_path
|
||||
return imports
|
||||
|
||||
|
||||
def _resolve_object_identifier(repo_root: Path, file_path: Path, source: str, identifier: str) -> dict[str, Any] | None:
|
||||
var_pattern = VAR_OBJ_RE_TEMPLATE.format(name=re.escape(identifier))
|
||||
block = _extract_block(source, var_pattern)
|
||||
imports = _extract_imports(source)
|
||||
|
||||
def resolver(name: str) -> dict[str, Any] | None:
|
||||
local_block = _extract_block(source, VAR_OBJ_RE_TEMPLATE.format(name=re.escape(name)))
|
||||
if local_block:
|
||||
return _parse_object_literal(local_block, resolver)
|
||||
import_path = imports.get(name)
|
||||
if not import_path:
|
||||
return None
|
||||
resolved_file = _resolve_import_path(repo_root, file_path, import_path)
|
||||
if not resolved_file:
|
||||
return None
|
||||
imported_source = resolved_file.read_text(encoding="utf-8")
|
||||
return _resolve_object_identifier(repo_root, resolved_file, imported_source, name)
|
||||
|
||||
if block:
|
||||
return _parse_object_literal(block, resolver)
|
||||
import_path = imports.get(identifier)
|
||||
if not import_path:
|
||||
return None
|
||||
resolved_file = _resolve_import_path(repo_root, file_path, import_path)
|
||||
if not resolved_file:
|
||||
return None
|
||||
imported_source = resolved_file.read_text(encoding="utf-8")
|
||||
return _resolve_object_identifier(repo_root, resolved_file, imported_source, identifier)
|
||||
|
||||
|
||||
def _infer_py_type(value: Any) -> str:
|
||||
if isinstance(value, bool):
|
||||
return "bool"
|
||||
if isinstance(value, int):
|
||||
return "int"
|
||||
if isinstance(value, float):
|
||||
return "float"
|
||||
if isinstance(value, str):
|
||||
return "str"
|
||||
if isinstance(value, list):
|
||||
return "list[Any]"
|
||||
if isinstance(value, dict):
|
||||
return "dict[str, Any]"
|
||||
return "Any"
|
||||
|
||||
|
||||
def _spec_is_none(spec: dict[str, Any]) -> bool:
|
||||
return spec.get("kind") == "null"
|
||||
|
||||
|
||||
def _py_type_from_spec(spec: dict[str, Any]) -> str:
|
||||
kind = spec.get("kind")
|
||||
if kind == "string":
|
||||
return "str"
|
||||
if kind == "number":
|
||||
return "float"
|
||||
if kind == "boolean":
|
||||
return "bool"
|
||||
if kind == "date":
|
||||
return "str"
|
||||
if kind == "enum":
|
||||
values = spec.get("values")
|
||||
if isinstance(values, list) and values:
|
||||
literal_values = ", ".join(_py_repr(v) for v in values)
|
||||
return f"Literal[{literal_values}]"
|
||||
if kind == "ref":
|
||||
ref_name = spec.get("name")
|
||||
if isinstance(ref_name, str) and ref_name.endswith("Parameters"):
|
||||
return f"{ref_name[:-10]}Params"
|
||||
if kind == "array":
|
||||
element = spec.get("element")
|
||||
inner = _py_type_from_spec(element) if isinstance(element, dict) else "Any"
|
||||
return f"list[{inner}]"
|
||||
if kind == "object":
|
||||
dict_value = spec.get("dictValue")
|
||||
if isinstance(dict_value, dict):
|
||||
inner = _py_type_from_spec(dict_value)
|
||||
return f"dict[str, {inner}]"
|
||||
properties = spec.get("properties")
|
||||
if isinstance(properties, dict) and properties:
|
||||
property_types = {_py_type_from_spec(p) for p in properties.values() if isinstance(p, dict)}
|
||||
if len(property_types) == 1:
|
||||
inner = next(iter(property_types))
|
||||
return f"dict[str, {inner}]"
|
||||
return "dict[str, Any]"
|
||||
if kind in {"null"}:
|
||||
return "Any"
|
||||
return "Any"
|
||||
|
||||
|
||||
def _to_class_name(tool_id: str) -> str:
|
||||
cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", tool_id)
|
||||
cleaned = re.sub(r"[^A-Za-z0-9]+", " ", cleaned)
|
||||
parts = [part.capitalize() for part in cleaned.split() if part]
|
||||
return "".join(parts) + "Params"
|
||||
|
||||
|
||||
def _to_snake_case(name: str) -> str:
|
||||
snake = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
|
||||
snake = re.sub(r"[^A-Za-z0-9]+", "_", snake).strip("_").lower()
|
||||
if not snake:
|
||||
snake = "param"
|
||||
if snake[0].isdigit():
|
||||
snake = f"param_{snake}"
|
||||
if keyword.iskeyword(snake):
|
||||
snake = f"{snake}_"
|
||||
return snake
|
||||
|
||||
|
||||
def _build_field_name_map(params: dict[str, Any]) -> dict[str, str]:
|
||||
field_map: dict[str, str] = {}
|
||||
used: set[str] = set()
|
||||
for original_key in sorted(params):
|
||||
base_name = _to_snake_case(original_key)
|
||||
candidate = base_name
|
||||
suffix = 2
|
||||
while candidate in used:
|
||||
candidate = f"{base_name}_{suffix}"
|
||||
suffix += 1
|
||||
used.add(candidate)
|
||||
field_map[original_key] = candidate
|
||||
return field_map
|
||||
|
||||
|
||||
def _to_enum_member_name(tool_id: str) -> str:
|
||||
return _to_snake_case(tool_id).upper()
|
||||
|
||||
|
||||
def _build_enum_member_map(specs: list[ToolModelSpec]) -> dict[str, str]:
|
||||
member_map: dict[str, str] = {}
|
||||
used: set[str] = set()
|
||||
for spec in specs:
|
||||
base_name = _to_enum_member_name(spec.tool_id)
|
||||
candidate = base_name
|
||||
suffix = 2
|
||||
while candidate in used:
|
||||
candidate = f"{base_name}_{suffix}"
|
||||
suffix += 1
|
||||
used.add(candidate)
|
||||
member_map[spec.tool_id] = candidate
|
||||
return member_map
|
||||
|
||||
|
||||
def _py_repr(value: Any) -> str:
|
||||
return (
|
||||
json.dumps(value, ensure_ascii=True).replace("true", "True").replace("false", "False").replace("null", "None")
|
||||
)
|
||||
|
||||
|
||||
def discover_tool_specs(repo_root: Path) -> list[ToolModelSpec]:
|
||||
frontend_dir = repo_root / "frontend"
|
||||
extractor = frontend_dir / "scripts/export-tool-specs.ts"
|
||||
command = ["node", "--import", "tsx", str(extractor)]
|
||||
result = subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=str(frontend_dir),
|
||||
)
|
||||
raw = json.loads(result.stdout)
|
||||
specs: list[ToolModelSpec] = []
|
||||
for item in raw:
|
||||
tool_id = item.get("tool_id")
|
||||
if not isinstance(tool_id, str) or not tool_id:
|
||||
continue
|
||||
params = item.get("params")
|
||||
param_types = item.get("param_types")
|
||||
specs.append(
|
||||
ToolModelSpec(
|
||||
tool_id=tool_id,
|
||||
params=params if isinstance(params, dict) else {},
|
||||
param_types=param_types if isinstance(param_types, dict) else {},
|
||||
)
|
||||
)
|
||||
return sorted(specs, key=lambda spec: spec.tool_id)
|
||||
|
||||
|
||||
def write_models_module(out_path: Path, specs: list[ToolModelSpec]) -> None:
|
||||
lines: list[str] = [
|
||||
TOOL_MODELS_HEADER,
|
||||
"from __future__ import annotations\n\n",
|
||||
"from enum import StrEnum\n",
|
||||
"from typing import Any, Literal\n\n",
|
||||
"from models.base import ApiModel\n",
|
||||
]
|
||||
|
||||
class_names: dict[str, str] = {spec.tool_id: _to_class_name(spec.tool_id) for spec in specs}
|
||||
class_name_to_tool_id = {name: tool_id for tool_id, name in class_names.items()}
|
||||
|
||||
def extract_class_dependencies(spec: ToolModelSpec) -> set[str]:
|
||||
deps: set[str] = set()
|
||||
if not isinstance(spec.param_types, dict):
|
||||
return deps
|
||||
for entry in spec.param_types.values():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
type_spec = entry
|
||||
if "type" in entry and isinstance(entry.get("type"), dict):
|
||||
type_spec = entry["type"]
|
||||
if not isinstance(type_spec, dict):
|
||||
continue
|
||||
if type_spec.get("kind") != "ref":
|
||||
continue
|
||||
ref_name = type_spec.get("name")
|
||||
if isinstance(ref_name, str) and ref_name.endswith("Parameters"):
|
||||
ref_class = f"{ref_name[:-10]}Params"
|
||||
if ref_class in class_name_to_tool_id:
|
||||
deps.add(ref_class)
|
||||
return deps
|
||||
|
||||
dependencies_by_class: dict[str, set[str]] = {}
|
||||
for spec in specs:
|
||||
class_name = class_names[spec.tool_id]
|
||||
dependencies_by_class[class_name] = extract_class_dependencies(spec)
|
||||
|
||||
remaining = set(class_names.values())
|
||||
ordered_class_names: list[str] = []
|
||||
while remaining:
|
||||
progress = False
|
||||
for class_name in sorted(remaining):
|
||||
deps = dependencies_by_class.get(class_name, set())
|
||||
if deps.issubset(set(ordered_class_names)):
|
||||
ordered_class_names.append(class_name)
|
||||
remaining.remove(class_name)
|
||||
progress = True
|
||||
break
|
||||
if not progress:
|
||||
ordered_class_names.extend(sorted(remaining))
|
||||
break
|
||||
|
||||
ordered_specs = [next(spec for spec in specs if class_names[spec.tool_id] == name) for name in ordered_class_names]
|
||||
|
||||
for spec in ordered_specs:
|
||||
class_name = class_names[spec.tool_id]
|
||||
lines.append(f"class {class_name}(ApiModel):\n")
|
||||
all_param_keys = set(spec.params)
|
||||
if isinstance(spec.param_types, dict):
|
||||
all_param_keys.update(spec.param_types.keys())
|
||||
|
||||
if not all_param_keys:
|
||||
lines.append(" pass\n\n\n")
|
||||
continue
|
||||
|
||||
field_name_map = _build_field_name_map({key: True for key in all_param_keys})
|
||||
for key in sorted(all_param_keys):
|
||||
field_name = field_name_map[key]
|
||||
value = spec.params.get(key)
|
||||
type_spec = spec.param_types.get(key) if isinstance(spec.param_types, dict) else None
|
||||
if isinstance(type_spec, dict):
|
||||
py_type = _py_type_from_spec(type_spec)
|
||||
else:
|
||||
py_type = _infer_py_type(value)
|
||||
|
||||
if value is None and (isinstance(type_spec, dict) and _spec_is_none(type_spec)):
|
||||
if py_type != "Any" and "| None" not in py_type:
|
||||
py_type = f"{py_type} | None"
|
||||
lines.append(f" {field_name}: {py_type} = None\n")
|
||||
elif value is None:
|
||||
lines.append(f" {field_name}: {py_type} | None = None\n")
|
||||
else:
|
||||
if isinstance(type_spec, dict) and type_spec.get("kind") == "ref" and isinstance(value, dict):
|
||||
lines.append(f" {field_name}: {py_type} = {py_type}.model_validate({_py_repr(value)})\n")
|
||||
continue
|
||||
lines.append(f" {field_name}: {py_type} = {_py_repr(value)}\n")
|
||||
lines.append("\n\n")
|
||||
|
||||
if class_names:
|
||||
union_members = " | ".join(class_names[tool_id] for tool_id in sorted(class_names))
|
||||
lines.append(f"type ParamToolModel = {union_members}\n")
|
||||
lines.append("type ParamToolModelType = type[ParamToolModel]\n\n")
|
||||
else:
|
||||
lines.append("type ParamToolModel = ApiModel\n")
|
||||
lines.append("type ParamToolModelType = type[ParamToolModel]\n\n")
|
||||
|
||||
enum_member_map = _build_enum_member_map(specs)
|
||||
|
||||
lines.append("class OperationId(StrEnum):\n")
|
||||
|
||||
for spec in specs:
|
||||
lines.append(f" {enum_member_map[spec.tool_id]} = {spec.tool_id!r}\n")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"\n\n",
|
||||
"OPERATIONS: dict[OperationId, ParamToolModelType] = {\n",
|
||||
]
|
||||
)
|
||||
|
||||
for spec in specs:
|
||||
model_name = _to_class_name(spec.tool_id)
|
||||
lines.append(f" OperationId.{enum_member_map[spec.tool_id]}: {model_name},\n")
|
||||
lines.append("}\n")
|
||||
out_path.write_text("".join(lines), encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Generate tool models from frontend TypeScript tool definitions")
|
||||
parser.add_argument("--spec", help="Deprecated (ignored)", default="")
|
||||
parser.add_argument("--output", default="", help="Path to tool_models.py")
|
||||
parser.add_argument("--ai-output", default="", help="Deprecated (ignored)")
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
specs = discover_tool_specs(repo_root)
|
||||
|
||||
output_path = Path(args.output) if args.output else (repo_root / "docgen/backend/models/tool_models.py")
|
||||
|
||||
write_models_module(output_path, specs)
|
||||
print(f"Wrote {len(specs)} tool model specs")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
engine/scripts/setup_env.py
Normal file
48
engine/scripts/setup_env.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Copies .env from .env.example if missing, and errors if any keys from the example
|
||||
are absent from the actual .env file.
|
||||
|
||||
Usage:
|
||||
uv run scripts/setup_env.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import dotenv_values
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
EXAMPLE_FILE = ROOT / "config" / ".env.example"
|
||||
ENV_FILE = ROOT / ".env"
|
||||
|
||||
print("setup-env: see engine/config/.env.example for documentation")
|
||||
|
||||
if not EXAMPLE_FILE.exists():
|
||||
print(f"setup-env: {EXAMPLE_FILE.name} not found, skipping", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if not ENV_FILE.exists():
|
||||
shutil.copy(EXAMPLE_FILE, ENV_FILE)
|
||||
print("setup-env: created .env from .env.example")
|
||||
|
||||
env_keys = set(dotenv_values(ENV_FILE).keys()) | set(os.environ.keys())
|
||||
example_keys = set(dotenv_values(EXAMPLE_FILE).keys())
|
||||
missing = sorted(example_keys - env_keys)
|
||||
|
||||
if missing:
|
||||
sys.exit(
|
||||
"setup-env: .env is missing keys from .env.example:\n"
|
||||
+ "\n".join(f" {k}" for k in missing)
|
||||
+ "\n Add them manually or delete your local .env to re-copy from config/.env.example."
|
||||
)
|
||||
|
||||
extra = sorted(k for k in dotenv_values(ENV_FILE) if k.startswith("STIRLING_") and k not in example_keys)
|
||||
if extra:
|
||||
print(
|
||||
"setup-env: .env contains STIRLING_ keys not in config/.env.example:\n"
|
||||
+ "\n".join(f" {k}" for k in extra)
|
||||
+ "\n Add them to config/.env.example if they are intentional.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
Reference in New Issue
Block a user