diff --git a/src/codegen/gscli/backend/typestub_utils.py b/src/codegen/gscli/backend/typestub_utils.py index 4bab38674..d8bd51482 100644 --- a/src/codegen/gscli/backend/typestub_utils.py +++ b/src/codegen/gscli/backend/typestub_utils.py @@ -3,6 +3,7 @@ import re from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor +from typing import TypeVar, Union import astor @@ -10,6 +11,9 @@ logger = get_logger(__name__) +# Define a type variable for AST nodes +ASTNode = TypeVar("ASTNode", ast.FunctionDef, ast.AnnAssign, ast.Assign) + class MethodRemover(ast.NodeTransformer): def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): @@ -44,7 +48,7 @@ def should_remove(self, node: ast.FunctionDef | ast.AnnAssign) -> bool: class FieldRemover(ast.NodeTransformer): - def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): + def __init__(self, conditions: list[Callable[[Union[ast.AnnAssign, ast.Assign]], bool]]): self.conditions = conditions def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: @@ -79,20 +83,22 @@ def _remove_methods(source: str, conditions: list[Callable[[ast.FunctionDef], bo return astor.to_source(modified_tree) -def _remove_fields(source: str, conditions: list[Callable[[ast.FunctionDef], bool]]) -> str: +def _remove_fields(source: str, conditions: list[Callable[[Union[ast.AnnAssign, ast.Assign]], bool]]) -> str: tree = ast.parse(source) transformer = FieldRemover(conditions) modified_tree = transformer.visit(tree) return astor.to_source(modified_tree) -def _starts_with_underscore(node: ast.FunctionDef | ast.AnnAssign | ast.Assign) -> bool: +def _starts_with_underscore(node: Union[ast.FunctionDef, ast.AnnAssign, ast.Assign]) -> bool: if isinstance(node, ast.FunctionDef): return node.name.startswith("_") and (not node.name.startswith("__") and not node.name.endswith("__")) elif isinstance(node, ast.Assign): - return node.targets[0].id.startswith("_") + if isinstance(node.targets[0], ast.Name): + return node.targets[0].id.startswith("_") elif isinstance(node, ast.AnnAssign): - return node.target.id.startswith("_") + if isinstance(node.target, ast.Name): + return node.target.id.startswith("_") return False @@ -121,7 +127,9 @@ def _strip_internal_symbols(file: str, root: str) -> None: _has_decorator("noapidoc"), ] - modified_content = _remove_fields(original_content, [_starts_with_underscore]) + # Type cast _starts_with_underscore to the correct type for _remove_fields + field_condition = _starts_with_underscore + modified_content = _remove_fields(original_content, [field_condition]) modified_content = _remove_methods(modified_content, conditions) if modified_content.strip().endswith(":"): diff --git a/src/codegen/gscli/generate/commands.py b/src/codegen/gscli/generate/commands.py index 3f96419ff..8a90d0a40 100644 --- a/src/codegen/gscli/generate/commands.py +++ b/src/codegen/gscli/generate/commands.py @@ -2,6 +2,8 @@ import os import re import shutil +import sys +from typing import Any import click from termcolor import colored @@ -66,7 +68,7 @@ def _generate_codebase_typestubs() -> None: # right now this command expects you to run it from here if not initial_dir.endswith("codegen/codegen-backend"): print(colored("Error: Must be in a directory ending with 'codegen/codegen-backend'", "red")) - exit(1) + sys.exit(1) out_dir = os.path.abspath(os.path.join(initial_dir, "typings")) frontend_typestubs_dir = os.path.abspath(os.path.join(initial_dir, os.pardir, "codegen-frontend/assets/typestubs/graphsitter")) @@ -113,6 +115,7 @@ def generate_docs(docs_dir: str) -> None: @generate.command() @click.argument("filepath", default=sdk.__path__[0] + "/system-prompt.txt", required=False) def system_prompt(filepath: str) -> None: + """Generate the system prompt and write it to the specified file""" print(f"Generating system prompt and writing to {filepath}...") new_system_prompt = get_system_prompt() with open(filepath, "w") as f: @@ -121,6 +124,7 @@ def system_prompt(filepath: str) -> None: def get_snippet_pattern(target_name: str) -> str: + """Generate a regex pattern to match code snippets with the given target name""" pattern = rf"\[//\]: # \(--{re.escape(target_name)}--\)\s*(?:\[//\]: # \(--{re.escape(AUTO_GENERATED_COMMENT)}--\)\s*)?" pattern += CODE_SNIPPETS_REGEX return pattern @@ -153,9 +157,9 @@ def generate_codegen_sdk_docs(docs_dir: str) -> None: # Write the generated docs to the file system, splitting between core, python, and typescript # keep track of where we put each one so we can update the mint.json - python_set = set() - typescript_set = set() - core_set = set() + python_set: set[str] = set() + typescript_set: set[str] = set() + core_set: set[str] = set() # TODO replace this with new `get_mdx_for_class` function for class_doc in gs_docs.classes: class_name = class_doc.title @@ -178,7 +182,7 @@ def generate_codegen_sdk_docs(docs_dir: str) -> None: # Update the core, python, and typescript page sets in mint.json mint_file_path = os.path.join(docs_dir, "mint.json") with open(mint_file_path) as mint_file: - mint_data = json.load(mint_file) + mint_data: dict[str, Any] = json.load(mint_file) # Find the "Codebase SDK" group where we want to add the pages codebase_sdk_group = next(group for group in mint_data["navigation"] if group["group"] == "API Reference") diff --git a/src/codegen/gscli/generate/runner_imports.py b/src/codegen/gscli/generate/runner_imports.py index d07b86062..e05fde711 100644 --- a/src/codegen/gscli/generate/runner_imports.py +++ b/src/codegen/gscli/generate/runner_imports.py @@ -51,9 +51,9 @@ def get_generated_imports(): ) -def fix_ruff_imports(objects: list[DocumentedObject]): +def fix_ruff_imports(objects: list[DocumentedObject]) -> None: root, _ = split_git_path(str(Path(__file__))) - to_add = [] + to_add: list[str] = [] for obj in objects: to_add.append(f"{obj.module}.{obj.name}") generics = tomlkit.array() @@ -66,12 +66,12 @@ def fix_ruff_imports(objects: list[DocumentedObject]): config.write_text(tomlkit.dumps(toml_config)) -def get_runner_imports(include_codegen=True, include_private_imports: bool = True) -> str: +def get_runner_imports(include_codegen: bool = True, include_private_imports: bool = True) -> str: # get the imports from the apidoc, py_apidoc, and ts_apidoc gs_objects = get_documented_objects() gs_public_objects = list(chain(gs_objects["apidoc"], gs_objects["py_apidoc"], gs_objects["ts_apidoc"])) fix_ruff_imports(gs_public_objects) - gs_public_imports = {f"from {obj.module} import {obj.name}" for obj in gs_public_objects} + gs_public_imports: set[str] = {f"from {obj.module} import {obj.name}" for obj in gs_public_objects} # construct import string with all imports ret = IMPORT_STRING_TEMPLATE.format( diff --git a/src/codegen/gscli/generate/system_prompt.py b/src/codegen/gscli/generate/system_prompt.py index 33b4a18a5..988a3afa4 100644 --- a/src/codegen/gscli/generate/system_prompt.py +++ b/src/codegen/gscli/generate/system_prompt.py @@ -1,27 +1,35 @@ import json from pathlib import Path +from typing import Any, Optional docs = Path("./docs") -mint = json.load(open(docs / "mint.json")) +mint: dict[str, Any] = json.load(open(docs / "mint.json")) -def render_page(page_str: str): +def render_page(page_str: str) -> str: + """Render a single page from the docs""" return open(docs / (page_str + ".mdx")).read() -def render_group(page_strs: list[str]): +def render_group(page_strs: list[str]) -> str: + """Render a group of pages from the docs""" return "\n\n".join([render_page(x) for x in page_strs]) -def get_group(name) -> list[str]: +def get_group(name: str) -> Optional[list[str]]: + """Get a group of pages by name from the mint.json file""" group = next((x for x in mint["navigation"] if x.get("group") == name), None) if group: return group["pages"] + return None def render_groups(group_names: list[str]) -> str: + """Render multiple groups of pages from the docs""" groups = [get_group(x) for x in group_names] - return "\n\n".join([render_group(g) for g in groups]) + # Filter out None values + filtered_groups = [g for g in groups if g is not None] + return "\n\n".join([render_group(g) for g in filtered_groups]) def get_system_prompt() -> str: diff --git a/src/codegen/gscli/generate/utils.py b/src/codegen/gscli/generate/utils.py index d579f9288..af053a8e2 100644 --- a/src/codegen/gscli/generate/utils.py +++ b/src/codegen/gscli/generate/utils.py @@ -12,11 +12,11 @@ class LanguageType(StrEnum): BOTH = "BOTH" -def generate_builtins_file(path_to_builtins: str, language_type: LanguageType): +def generate_builtins_file(path_to_builtins: str, language_type: LanguageType) -> None: """Generates and writes the builtins file""" documented_imports = get_documented_objects() all_objects = chain(documented_imports["apidoc"], documented_imports["py_apidoc"], documented_imports["ts_apidoc"]) - unique_imports = {f"from {obj.module} import {obj.name} as {obj.name}" for obj in all_objects} + unique_imports: set[str] = {f"from {obj.module} import {obj.name} as {obj.name}" for obj in all_objects} all_imports = "\n".join(sorted(unique_imports)) # TODO: re-use code with runner_imports list # TODO: also auto generate import string for CodemodContext + MessageType