#!/usr/bin/env python3
"""Generate tool-parameter model classes from the frontend tool definitions.

``discover_tool_specs`` runs ``frontend/scripts/export-tool-specs.ts`` under Node
(with the ``tsx`` loader) and reads the JSON it prints on stdout.
``write_models_module`` then renders a Python module containing:

* one ``ApiModel`` subclass per tool;
* a ``ParamToolModel`` union;
* an ``OperationId`` ``StrEnum``;
* an ``OPERATIONS`` mapping.

The TypeScript-source helpers (brace matching, object-literal parsing, import
resolution) are not called by ``main`` in this file.
"""

from __future__ import annotations

import argparse
import json
import keyword
import math
import re
import subprocess
import sys
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any

TOOL_MODELS_HEADER = """# AUTO-GENERATED FILE. DO NOT EDIT.
# Generated by scripts/generate_tool_models.py from frontend TypeScript sources.
# ruff: noqa: N815

"""

OPERATION_TYPE_RE = re.compile(r"operationType\s*:\s*['\"]([A-Za-z0-9_]+)['\"]")
DEFAULT_REF_RE = re.compile(r"defaultParameters\s*:\s*([A-Za-z0-9_]+)")
DEFAULT_SHORTHAND_RE = re.compile(r"\bdefaultParameters\b")
IMPORT_RE = re.compile(r"import\s*\{([^}]+)\}\s*from\s*['\"]([^'\"]+)['\"]")
VAR_OBJ_RE_TEMPLATE = r"(?:export\s+)?const\s+{name}\b[^=]*=\s*\{{"

# Characters that open a string literal in TypeScript source. Backticks (template
# literals) are included so that braces/commas inside them are not treated as code.
_STRING_DELIMITERS = frozenset({"'", '"', "`"})

_PARAMETERS_SUFFIX = "Parameters"


@dataclass
class ToolModelSpec:
    """One tool as reported by the frontend extractor."""

    tool_id: str
    params: dict[str, Any]  # parameter name -> default value rendered into the model
    param_types: dict[str, Any]  # parameter name -> type spec (bare or wrapped as {"type": {...}})


class ParseError(Exception):
    """Raised when TypeScript source or extractor output cannot be interpreted."""


def _find_matching(text: str, start: int, open_char: str, close_char: str) -> int:
    """Return the index of the ``close_char`` that balances the ``open_char`` at ``start``.

    Characters inside string literals (quotes and template literals) are skipped,
    honouring backslash escapes. Comments are not recognised.

    Raises:
        ParseError: if the end of ``text`` is reached before the block is balanced.
    """
    depth = 0
    i = start
    in_str: str | None = None
    while i < len(text):
        ch = text[i]
        if in_str:
            if ch == "\\":
                i += 2  # skip the escaped character, whatever it is
                continue
            if ch == in_str:
                in_str = None
            i += 1
            continue
        if ch in _STRING_DELIMITERS:
            in_str = ch
        elif ch == open_char:
            depth += 1
        elif ch == close_char:
            depth -= 1
            if depth == 0:
                return i
        i += 1
    raise ParseError(f"Unmatched {open_char}{close_char} block")


def _extract_block(text: str, pattern: str) -> str | None:
    """Return the ``{...}`` block that starts at the end of the first ``pattern`` match.

    ``pattern`` is expected to end with an opening brace (see ``VAR_OBJ_RE_TEMPLATE``).
    Returns None when the pattern does not match.
    """
    match = re.search(pattern, text)
    if not match:
        return None
    brace_start = text.find("{", match.end() - 1)
    if brace_start == -1:
        return None
    brace_end = _find_matching(text, brace_start, "{", "}")
    return text[brace_start : brace_end + 1]


def _split_top_level_items(obj_body: str) -> list[str]:
    """Split the inside of an object/array literal on commas at nesting depth zero.

    Only ``{}`` and ``[]`` nesting is tracked (not parentheses); strings are skipped.
    Empty pieces (e.g. from a trailing comma) are dropped and pieces are stripped.
    """
    items: list[str] = []
    depth_obj = depth_arr = 0
    in_str: str | None = None
    token_start = 0
    i = 0
    while i < len(obj_body):
        ch = obj_body[i]
        if in_str:
            if ch == "\\":
                i += 2
                continue
            if ch == in_str:
                in_str = None
            i += 1
            continue
        if ch in _STRING_DELIMITERS:
            in_str = ch
        elif ch == "{":
            depth_obj += 1
        elif ch == "}":
            depth_obj -= 1
        elif ch == "[":
            depth_arr += 1
        elif ch == "]":
            depth_arr -= 1
        elif ch == "," and depth_obj == 0 and depth_arr == 0:
            piece = obj_body[token_start:i].strip()
            if piece:
                items.append(piece)
            token_start = i + 1
        i += 1
    tail = obj_body[token_start:].strip()
    if tail:
        items.append(tail)
    return items


def _resolve_import_path(repo_root: Path, current_file: Path, module_path: str) -> Path | None:
    """Map an import specifier to an existing ``.ts``/``.tsx`` file, or None.

    ``@app/...`` is tried under ``frontend/src/core``, ``frontend/src/saas`` and then
    ``frontend/src``. Relative specifiers are resolved against ``current_file``'s
    directory. Any other specifier (e.g. a package name) yields None.
    """
    candidates: list[Path] = []
    if module_path.startswith("@app/"):
        rel = module_path[len("@app/") :]
        for base_dir in ("frontend/src/core", "frontend/src/saas", "frontend/src"):
            candidates.append(repo_root / base_dir / f"{rel}.ts")
            candidates.append(repo_root / base_dir / f"{rel}.tsx")
    elif module_path.startswith("."):
        base = (current_file.parent / module_path).resolve()
        candidates.extend([Path(f"{base}.ts"), Path(f"{base}.tsx")])
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def _parse_literal_value(value: str, resolver: Callable[[str], dict[str, Any] | None]) -> Any:
    """Convert one TypeScript literal expression to a Python value.

    Handles objects, arrays, single/double-quoted strings (no unescaping), booleans,
    ``null`` and decimal numbers. Anything else is passed to ``resolver`` as an
    identifier; unresolvable expressions yield None.
    """
    value = value.strip()
    if not value:
        return None
    if value.startswith("{") and value.endswith("}"):
        return _parse_object_literal(value, resolver)
    if value.startswith("[") and value.endswith("]"):
        inner = value[1:-1].strip()
        if not inner:
            return []
        return [_parse_literal_value(item, resolver) for item in _split_top_level_items(inner)]
    # Require matching quote characters so a lone quote or a mismatched pair is not a string.
    if len(value) >= 2 and value[0] in {"'", '"'} and value[-1] == value[0]:
        return value[1:-1]
    if value in {"true", "false"}:
        return value == "true"
    if value == "null":
        return None
    if re.fullmatch(r"-?\d+", value):
        return int(value)
    if re.fullmatch(r"-?\d+\.\d+", value):
        return float(value)
    return resolver(value)


def _parse_object_literal(obj_text: str, resolver: Callable[[str], dict[str, Any] | None]) -> dict[str, Any]:
    """Parse a ``{...}`` TypeScript object literal into a dict.

    ``...name`` spreads are merged via ``resolver``. Items without a ``:`` (shorthand
    properties, methods) are skipped. Keys have surrounding quotes removed.
    """
    body = obj_text.strip()[1:-1]
    result: dict[str, Any] = {}
    for item in _split_top_level_items(body):
        if item.startswith("..."):
            spread = resolver(item[3:].strip())
            if isinstance(spread, dict):
                result.update(spread)
            continue
        if ":" not in item:
            continue
        key, raw_value = item.split(":", 1)
        key = key.strip().strip("'\"")
        result[key] = _parse_literal_value(raw_value.strip(), resolver)
    return result


def _extract_import_bindings(source: str) -> dict[str, tuple[str, str]]:
    """Map each locally bound name from ``import { ... } from '...'`` to (module path, exported name).

    For ``import { a as b }`` both ``b`` and ``a`` are recorded, each pointing at the exported name ``a``.
    """
    bindings: dict[str, tuple[str, str]] = {}
    for names, module_path in IMPORT_RE.findall(source):
        for part in names.split(","):
            segment = part.strip()
            if not segment:
                continue
            if " as " in segment:
                original, alias = [x.strip() for x in segment.split(" as ", 1)]
                bindings[alias] = (module_path, original)
                bindings[original] = (module_path, original)
            else:
                bindings[segment] = (module_path, segment)
    return bindings


def _extract_imports(source: str) -> dict[str, str]:
    """Map each locally bound name from ``import { ... } from '...'`` to its module path."""
    return {name: module_path for name, (module_path, _) in _extract_import_bindings(source).items()}


def _resolve_object_identifier(repo_root: Path, file_path: Path, source: str, identifier: str) -> dict[str, Any] | None:
    """Resolve ``identifier`` to the object literal it names, following imports.

    A ``const identifier = {...}`` in ``source`` takes precedence. Otherwise the
    named import is followed into the resolved file. Returns None when neither path
    finds an object. Circular references are not detected.
    """
    bindings = _extract_import_bindings(source)

    def resolver(name: str) -> dict[str, Any] | None:
        local_block = _extract_block(source, VAR_OBJ_RE_TEMPLATE.format(name=re.escape(name)))
        if local_block:
            return _parse_object_literal(local_block, resolver)
        binding = bindings.get(name)
        if not binding:
            return None
        module_path, exported_name = binding
        resolved_file = _resolve_import_path(repo_root, file_path, module_path)
        if not resolved_file:
            return None
        imported_source = resolved_file.read_text(encoding="utf-8")
        # Look the value up under its exported name, not the local alias.
        return _resolve_object_identifier(repo_root, resolved_file, imported_source, exported_name)

    return resolver(identifier)


def _infer_py_type(value: Any) -> str:
    """Guess a Python annotation from a default value when no type spec is available."""
    if isinstance(value, bool):  # before int: bool is an int subclass
        return "bool"
    if isinstance(value, int):
        return "int"
    if isinstance(value, float):
        return "float"
    if isinstance(value, str):
        return "str"
    if isinstance(value, list):
        return "list[Any]"
    if isinstance(value, dict):
        return "dict[str, Any]"
    return "Any"


def _spec_is_none(spec: dict[str, Any]) -> bool:
    """Return True if the type spec describes ``null``."""
    return spec.get("kind") == "null"


def _unwrap_type_spec(entry: Any) -> dict[str, Any] | None:
    """Return the type spec for a ``param_types`` entry, or None if it is not a dict.

    Entries may be a bare spec or wrapped as ``{"type": {...}}``; the wrapped form is unwrapped.
    """
    if not isinstance(entry, dict):
        return None
    inner = entry.get("type")
    return inner if isinstance(inner, dict) else entry


def _ref_class_name(spec: dict[str, Any]) -> str | None:
    """For a ``ref`` spec named ``XxxParameters`` return the generated class name ``XxxParams``."""
    if spec.get("kind") != "ref":
        return None
    ref_name = spec.get("name")
    if isinstance(ref_name, str) and ref_name.endswith(_PARAMETERS_SUFFIX):
        return f"{ref_name[: -len(_PARAMETERS_SUFFIX)]}Params"
    return None


def _py_type_from_spec(spec: dict[str, Any]) -> str:
    """Translate an extractor type spec into a Python annotation string (``Any`` if unknown)."""
    kind = spec.get("kind")
    if kind in {"string", "date"}:
        return "str"
    if kind == "number":
        return "float"
    if kind == "boolean":
        return "bool"
    if kind == "enum":
        values = spec.get("values")
        if isinstance(values, list) and values:
            return f"Literal[{', '.join(_py_repr(v) for v in values)}]"
        return "Any"
    if kind == "ref":
        return _ref_class_name(spec) or "Any"
    if kind == "array":
        element = spec.get("element")
        inner = _py_type_from_spec(element) if isinstance(element, dict) else "Any"
        return f"list[{inner}]"
    if kind == "object":
        dict_value = spec.get("dictValue")
        if isinstance(dict_value, dict):
            return f"dict[str, {_py_type_from_spec(dict_value)}]"
        properties = spec.get("properties")
        if isinstance(properties, dict) and properties:
            # A uniform-valued record collapses to dict[str, T]; mixed values fall back to Any.
            property_types = {_py_type_from_spec(p) for p in properties.values() if isinstance(p, dict)}
            if len(property_types) == 1:
                return f"dict[str, {next(iter(property_types))}]"
        return "dict[str, Any]"
    return "Any"


def _to_class_name(tool_id: str) -> str:
    """``"addWatermark"`` / ``"add-watermark"`` -> ``"AddWatermarkParams"``."""
    cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", tool_id)
    cleaned = re.sub(r"[^A-Za-z0-9]+", " ", cleaned)
    return "".join(part.capitalize() for part in cleaned.split()) + "Params"


def _to_snake_case(name: str) -> str:
    """Convert ``name`` to a valid snake_case Python identifier.

    Falls back to ``param`` when nothing alphanumeric remains, prefixes a leading
    digit with ``param_`` and appends ``_`` to keywords.
    """
    snake = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
    snake = re.sub(r"[^A-Za-z0-9]+", "_", snake).strip("_").lower()
    if not snake:
        snake = "param"
    if snake[0].isdigit():
        snake = f"param_{snake}"
    if keyword.iskeyword(snake):
        snake = f"{snake}_"
    return snake


def _unique_name(base_name: str, used: set[str]) -> str:
    """Return ``base_name`` or the first free ``base_name_N`` (N >= 2), recording it in ``used``."""
    candidate = base_name
    suffix = 2
    while candidate in used:
        candidate = f"{base_name}_{suffix}"
        suffix += 1
    used.add(candidate)
    return candidate


def _build_field_name_map(params: dict[str, Any]) -> dict[str, str]:
    """Map each key of ``params`` to a unique snake_case field name (assigned in sorted key order)."""
    used: set[str] = set()
    return {key: _unique_name(_to_snake_case(key), used) for key in sorted(params)}


def _to_enum_member_name(tool_id: str) -> str:
    """``"addWatermark"`` -> ``"ADD_WATERMARK"``."""
    return _to_snake_case(tool_id).upper()


def _build_enum_member_map(specs: list[ToolModelSpec]) -> dict[str, str]:
    """Map each tool id to a unique ``OperationId`` member name (assigned in ``specs`` order)."""
    used: set[str] = set()
    return {spec.tool_id: _unique_name(_to_enum_member_name(spec.tool_id), used) for spec in specs}


def _py_repr(value: Any) -> str:
    """Render a JSON-like value as Python source.

    Strings use ASCII-escaped, double-quoted JSON form (also valid Python). Dict
    keys are stringified the way ``json.dumps`` does. Values are serialised
    structurally so string contents such as ``"nullable"`` are never rewritten.

    Raises:
        TypeError: for values ``json.dumps`` cannot serialise.
    """
    if value is None:
        return "None"
    if isinstance(value, bool):
        return "True" if value else "False"
    if isinstance(value, str):
        return json.dumps(value, ensure_ascii=True)
    if isinstance(value, float) and not math.isfinite(value):
        return f'float("{value!r}")'
    if isinstance(value, (int, float)):
        return json.dumps(value)
    if isinstance(value, dict):
        parts = []
        for key, item in value.items():
            key_text = key if isinstance(key, str) else json.dumps(key)
            parts.append(f"{json.dumps(key_text, ensure_ascii=True)}: {_py_repr(item)}")
        return "{" + ", ".join(parts) + "}"
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(_py_repr(item) for item in value) + "]"
    return json.dumps(value, ensure_ascii=True)


def discover_tool_specs(repo_root: Path) -> list[ToolModelSpec]:
    """Run the frontend extractor and return its tool specs sorted by ``tool_id``.

    Requires ``node`` with the ``tsx`` loader. Items without a non-empty string
    ``tool_id`` are skipped. Non-dict ``params``/``param_types`` become ``{}``.

    Raises:
        subprocess.CalledProcessError: if the extractor exits non-zero (its stderr is echoed first).
        ParseError: if the extractor output is not a JSON array.
    """
    frontend_dir = repo_root / "frontend"
    extractor = frontend_dir / "scripts/export-tool-specs.ts"
    command = ["node", "--import", "tsx", str(extractor)]
    try:
        result = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            cwd=str(frontend_dir),
        )
    except subprocess.CalledProcessError as exc:
        # The exception message omits the captured stderr; surface it before re-raising.
        if exc.stderr:
            sys.stderr.write(exc.stderr)
        raise
    raw = json.loads(result.stdout)
    if not isinstance(raw, list):
        raise ParseError(f"Expected a JSON array from {extractor.name}, got {type(raw).__name__}")
    specs: list[ToolModelSpec] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        tool_id = item.get("tool_id")
        if not isinstance(tool_id, str) or not tool_id:
            continue
        params = item.get("params")
        param_types = item.get("param_types")
        specs.append(
            ToolModelSpec(
                tool_id=tool_id,
                params=params if isinstance(params, dict) else {},
                param_types=param_types if isinstance(param_types, dict) else {},
            )
        )
    return sorted(specs, key=lambda spec: spec.tool_id)


def _class_dependencies(spec: ToolModelSpec, known_classes: set[str]) -> set[str]:
    """Return generated classes that ``spec`` references directly through a top-level ``ref`` param."""
    deps: set[str] = set()
    if not isinstance(spec.param_types, dict):
        return deps
    for entry in spec.param_types.values():
        type_spec = _unwrap_type_spec(entry)
        if type_spec is None:
            continue
        ref_class = _ref_class_name(type_spec)
        if ref_class in known_classes:
            deps.add(ref_class)
    return deps


def _order_specs_by_dependency(specs: list[ToolModelSpec], class_names: dict[str, str]) -> list[ToolModelSpec]:
    """Order specs so referenced classes are emitted before classes whose defaults use them.

    Among ready classes the alphabetically first is emitted next. On a cycle, the
    remaining classes are emitted alphabetically. If two tool ids map to the same
    class name, only the first spec is emitted.
    """
    spec_by_class: dict[str, ToolModelSpec] = {}
    for spec in specs:
        spec_by_class.setdefault(class_names[spec.tool_id], spec)
    known_classes = set(spec_by_class)
    deps_by_class = {name: _class_dependencies(spec, known_classes) for name, spec in spec_by_class.items()}

    remaining = set(spec_by_class)
    ordered: list[str] = []
    emitted: set[str] = set()
    while remaining:
        ready = next((name for name in sorted(remaining) if deps_by_class[name] <= emitted), None)
        if ready is None:
            ordered.extend(sorted(remaining))
            break
        ordered.append(ready)
        emitted.add(ready)
        remaining.remove(ready)
    return [spec_by_class[name] for name in ordered]


def _render_field(field_name: str, value: Any, type_spec: dict[str, Any] | None) -> str:
    """Render one model field line (4-space indented) with its annotation and default."""
    py_type = _py_type_from_spec(type_spec) if type_spec is not None else _infer_py_type(value)
    if value is None:
        if type_spec is not None and _spec_is_none(type_spec):
            if py_type != "Any" and "| None" not in py_type:
                py_type = f"{py_type} | None"
            return f"    {field_name}: {py_type} = None\n"
        return f"    {field_name}: {py_type} | None = None\n"
    # Nested model defaults are built via model_validate, but only when the ref maps to a
    # generated class; an unmapped ref is typed Any and must not emit Any.model_validate(...).
    if type_spec is not None and isinstance(value, dict) and _ref_class_name(type_spec) is not None:
        return f"    {field_name}: {py_type} = {py_type}.model_validate({_py_repr(value)})\n"
    return f"    {field_name}: {py_type} = {_py_repr(value)}\n"


def write_models_module(out_path: Path, specs: list[ToolModelSpec]) -> None:
    """Write the generated models module for ``specs`` to ``out_path`` (UTF-8, overwritten)."""
    lines: list[str] = [
        TOOL_MODELS_HEADER,
        "from __future__ import annotations\n\n",
        "from enum import StrEnum\n",
        "from typing import Any, Literal\n\n",
        "from models.base import ApiModel\n\n\n",
    ]
    class_names: dict[str, str] = {spec.tool_id: _to_class_name(spec.tool_id) for spec in specs}

    for spec in _order_specs_by_dependency(specs, class_names):
        lines.append(f"class {class_names[spec.tool_id]}(ApiModel):\n")
        param_types = spec.param_types if isinstance(spec.param_types, dict) else {}
        all_param_keys = set(spec.params) | set(param_types)
        if not all_param_keys:
            lines.append("    pass\n\n\n")
            continue
        field_name_map = _build_field_name_map(dict.fromkeys(all_param_keys, True))
        for key in sorted(all_param_keys):
            type_spec = _unwrap_type_spec(param_types.get(key))
            lines.append(_render_field(field_name_map[key], spec.params.get(key), type_spec))
        lines.append("\n\n")

    if class_names:
        union_members = " | ".join(class_names[tool_id] for tool_id in sorted(class_names))
        lines.append(f"type ParamToolModel = {union_members}\n")
    else:
        lines.append("type ParamToolModel = ApiModel\n")
    lines.append("type ParamToolModelType = type[ParamToolModel]\n\n")

    enum_member_map = _build_enum_member_map(specs)
    lines.append("class OperationId(StrEnum):\n")
    for spec in specs:
        lines.append(f"    {enum_member_map[spec.tool_id]} = {spec.tool_id!r}\n")
    lines.extend(["\n\n", "OPERATIONS: dict[OperationId, ParamToolModelType] = {\n"])
    for spec in specs:
        lines.append(f"    OperationId.{enum_member_map[spec.tool_id]}: {class_names[spec.tool_id]},\n")
    lines.append("}\n")
    out_path.write_text("".join(lines), encoding="utf-8")


def main() -> None:
    """CLI entry point: discover specs and write the models module (``--output`` or the default path)."""
    parser = argparse.ArgumentParser(description="Generate tool models from frontend TypeScript tool definitions")
    parser.add_argument("--spec", help="Deprecated (ignored)", default="")
    parser.add_argument("--output", default="", help="Path to tool_models.py")
    parser.add_argument("--ai-output", default="", help="Deprecated (ignored)")
    args = parser.parse_args()
    # This file's 4th ancestor directory is treated as the repository root.
    repo_root = Path(__file__).resolve().parents[3]
    specs = discover_tool_specs(repo_root)
    output_path = Path(args.output) if args.output else (repo_root / "docgen/backend/models/tool_models.py")
    write_models_module(output_path, specs)
    print(f"Wrote {len(specs)} tool model specs")


if __name__ == "__main__":
    main()