|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from __future__ import annotations |
| 15 | +from typing import Sequence |
| 16 | +import ast |
| 17 | +""" |
| 18 | +Entrypoint for initiating an async -> sync conversion using CrossSync |
| 19 | +
|
| 20 | +Finds all python files rooted in a given directory, and uses |
| 21 | +transformers.CrossSyncFileProcessor to handle any files marked with |
| 22 | +__CROSS_SYNC_OUTPUT__ |
| 23 | +""" |
| 24 | + |
| 25 | + |
| 26 | +def extract_header_comments(file_path) -> str: |
| 27 | + """ |
| 28 | + Extract the file header. Header is defined as the top-level |
| 29 | + comments before any code or imports |
| 30 | + """ |
| 31 | + header = [] |
| 32 | + with open(file_path, "r") as f: |
| 33 | + for line in f: |
| 34 | + if line.startswith("#") or line.strip() == "": |
| 35 | + header.append(line) |
| 36 | + else: |
| 37 | + break |
| 38 | + header.append("\n# This file is automatically generated by CrossSync. Do not edit manually.\n\n") |
| 39 | + return "".join(header) |
| 40 | + |
| 41 | + |
| 42 | +class CrossSyncOutputFile: |
| 43 | + |
| 44 | + def __init__(self, output_path: str, ast_tree, header: str | None = None): |
| 45 | + self.output_path = output_path |
| 46 | + self.tree = ast_tree |
| 47 | + self.header = header or "" |
| 48 | + |
| 49 | + def render(self, with_formatter=True, save_to_disk: bool = True) -> str: |
| 50 | + """ |
| 51 | + Render the file to a string, and optionally save to disk |
| 52 | +
|
| 53 | + Args: |
| 54 | + with_formatter: whether to run the output through black before returning |
| 55 | + save_to_disk: whether to write the output to the file path |
| 56 | + """ |
| 57 | + full_str = self.header + ast.unparse(self.tree) |
| 58 | + if with_formatter: |
| 59 | + import black # type: ignore |
| 60 | + import autoflake # type: ignore |
| 61 | + |
| 62 | + full_str = black.format_str( |
| 63 | + autoflake.fix_code(full_str, remove_all_unused_imports=True), |
| 64 | + mode=black.FileMode(), |
| 65 | + ) |
| 66 | + if save_to_disk: |
| 67 | + import os |
| 68 | + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) |
| 69 | + with open(self.output_path, "w") as f: |
| 70 | + f.write(full_str) |
| 71 | + return full_str |
| 72 | + |
| 73 | + |
| 74 | +def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: |
| 75 | + import glob |
| 76 | + from transformers import CrossSyncFileProcessor |
| 77 | + |
| 78 | + # find all python files in the directory |
| 79 | + files = glob.glob(directory + "/**/*.py", recursive=True) |
| 80 | + # keep track of the output files pointed to by the annotated classes |
| 81 | + artifacts: set[CrossSyncOutputFile] = set() |
| 82 | + file_transformer = CrossSyncFileProcessor() |
| 83 | + # run each file through ast transformation to find all annotated classes |
| 84 | + for file_path in files: |
| 85 | + ast_tree = ast.parse(open(file_path).read()) |
| 86 | + output_path = file_transformer.get_output_path(ast_tree) |
| 87 | + if output_path is not None: |
| 88 | + # contains __CROSS_SYNC_OUTPUT__ annotation |
| 89 | + converted_tree = file_transformer.visit(ast_tree) |
| 90 | + header = extract_header_comments(file_path) |
| 91 | + artifacts.add(CrossSyncOutputFile(output_path, converted_tree, header)) |
| 92 | + # return set of output artifacts |
| 93 | + return artifacts |
| 94 | + |
| 95 | + |
| 96 | +def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): |
| 97 | + for a in artifacts: |
| 98 | + a.render(save_to_disk=True) |
| 99 | + |
| 100 | + |
| 101 | +if __name__ == "__main__": |
| 102 | + import sys |
| 103 | + |
| 104 | + search_root = sys.argv[1] |
| 105 | + outputs = convert_files_in_dir(search_root) |
| 106 | + print(f"Generated {len(outputs)} artifacts: {[a.output_path for a in outputs]}") |
| 107 | + save_artifacts(outputs) |
0 commit comments