diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 2cc3f4ce2e..02fe69380d 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -168,3 +168,28 @@ jobs: for file in ./bundle/internal/schema/testdata/fail/*.yml; do ajv test -s schema.json -d $file --invalid -c=./keywords.js done + + validate-python-codegen: + needs: cleanups + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Install uv + uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1 + with: + version: "0.6.5" + + - name: Verify that python/codegen is up to date + working-directory: experimental/python + run: | + make codegen + + if ! ( git diff --exit-code ); then + echo "Generated Python code is not up-to-date. Please run 'pushd experimental/python && make codegen' and commit the changes." + + # TODO block PR if this fails once diffs are fixed + # exit 1 + fi diff --git a/experimental/python/Makefile b/experimental/python/Makefile index 1878aab767..5270be3103 100644 --- a/experimental/python/Makefile +++ b/experimental/python/Makefile @@ -16,6 +16,15 @@ lint: uv run pyright uv run ruff format --diff +codegen: + find databricks -name _models | xargs rm -rf + + cd codegen; uv run -m pytest codegen_tests + cd codegen; uv run -m codegen.main --output .. + + uv run ruff check --fix $(sources) || true + uv run ruff format + test: uv run python -m pytest databricks_tests --cov=databricks.bundles --cov-report html -vv @@ -23,4 +32,4 @@ build: rm -rf build dist uv build . -.PHONY: fmt docs lint test build +.PHONY: fmt docs lint codegen test build diff --git a/experimental/python/codegen/codegen/code_builder.py b/experimental/python/codegen/codegen/code_builder.py new file mode 100644 index 0000000000..7e01acc2cc --- /dev/null +++ b/experimental/python/codegen/codegen/code_builder.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + + +class CodeBuilder: + def __init__(self): + self._code = "" + + def append(self, *args: str) -> "Self": + for arg in args: + self._code += arg + + return self + + def indent(self): + return self.append(" ") + + def newline(self) -> "Self": + return self.append("\n") + + def append_list(self, args: list[str], sep: str = ",") -> "Self": + return self.append(sep.join(args)) + + def append_dict(self, args: dict[str, str], sep: str = ",") -> "Self": + return self.append_list([f"{k}={v}" for k, v in args.items()], sep) + + def append_triple_quote(self) -> "Self": + return self.append('"""') + + def append_repr(self, value) -> "Self": + return self.append(repr(value)) + + def build(self): + return self._code diff --git a/experimental/python/codegen/codegen/generated_dataclass.py b/experimental/python/codegen/codegen/generated_dataclass.py new file mode 100644 index 0000000000..37c81a75a2 --- /dev/null +++ b/experimental/python/codegen/codegen/generated_dataclass.py @@ -0,0 +1,479 @@ +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + +import codegen.packages as packages +from codegen.code_builder import CodeBuilder +from codegen.jsonschema import Property, Schema +from codegen.packages import is_resource + + +@dataclass +class GeneratedType: + """ + GeneratedType is a type that can be used in GeneratedField. + + GeneratedType is self-recursive, so it can represent complex types like lists of dataclasses. 
+ """ + + name: str + """ + The name of the type, e.g., "Task" + """ + + package: Optional[str] + """ + The package of the type, e.g., "databricks.bundles.jobs._models.task". + + If type is builtin, package is None. + """ + + parameters: list["Self"] + """ + Parameters of the type, e.g., for list[str]: + + GeneratedType( + name="list", + parameters=[ + GeneratedType(name="str"), + ], + ) + + """ + + +@dataclass +class GeneratedField: + """ + GeneratedField is a field in GeneratedDataclass. + """ + + field_name: str + """ + The name of the field, e.g., "task_key" + """ + + type_name: GeneratedType + """ + The type of the field, e.g., GeneratedType(name="Task", ...) + """ + + param_type_name: GeneratedType + """ + The type of the field in TypedDict, e.g., GeneratedType(name="TaskParam", ...) + """ + + create_func_type_name: GeneratedType + """ + Type type of the field in static "create" function, e.g., GeneratedType(name="TaskParam", ...) + + It can be different from param_type_name because lists are made optional in "create" function + to avoid problems with mutable default arguments. + """ + + description: Optional[str] + """ + The description of the field to be included into a docstring. + """ + + default: Optional[str] + """ + The default value of the field, e.g., "None" + """ + + create_func_default: Optional[str] + """ + The default value of the field in "create" function. + + It can be different from default because lists are made optional in "create" function + to avoid problems with mutable default arguments. + """ + + default_factory: Optional[str] + """ + Factory method for creating a default value, used for lists and dicts. + """ + + def __post_init__(self): + if self.default_factory is not None and self.default is not None: + raise ValueError("Can't have both default and default_factory", self) + + +@dataclass +class GeneratedDataclass: + """ + GeneratedDataclass represents a dataclass to be generated. + """ + + class_name: str + """ + The name of the dataclass, e.g., "Task". + """ + + package: str + """ + Package of the dataclass, e.g., "databricks.bundles.jobs._models.task". + """ + + description: Optional[str] + """ + The description of the dataclass to be included into a docstring. 
+ """ + + fields: list[GeneratedField] + extends: list[GeneratedType] + + +def generate_field( + field_name: str, + prop: Property, + is_required: bool, +) -> GeneratedField: + field_type = generate_type(prop.ref, is_param=False) + param_type = generate_type(prop.ref, is_param=True) + + field_type = variable_or_type(field_type, is_required=is_required) + param_type = variable_or_type(param_type, is_required=is_required) + + if field_type.name == "VariableOrDict": + return GeneratedField( + field_name=field_name, + type_name=field_type, + param_type_name=param_type, + create_func_type_name=optional_type(param_type), + description=prop.description, + default=None, + default_factory="dict", + create_func_default="None", + ) + elif field_type.name == "VariableOrList": + return GeneratedField( + field_name=field_name, + type_name=field_type, + param_type_name=param_type, + create_func_type_name=optional_type(param_type), + description=prop.description, + default=None, + default_factory="list", + create_func_default="None", + ) + elif is_required: + return GeneratedField( + field_name=field_name, + type_name=field_type, + param_type_name=param_type, + create_func_type_name=param_type, + description=prop.description, + default=None, + default_factory=None, + create_func_default=None, + ) + else: + return GeneratedField( + field_name=field_name, + type_name=field_type, + param_type_name=param_type, + create_func_type_name=param_type, + description=prop.description, + default="None", + default_factory=None, + create_func_default="None", + ) + + +def optional_type(generated: GeneratedType) -> GeneratedType: + return GeneratedType( + name="Optional", + package="typing", + parameters=[generated], + ) + + +def str_type() -> GeneratedType: + return GeneratedType( + name="str", + package=None, + parameters=[], + ) + + +def dict_type() -> GeneratedType: + return GeneratedType( + name="dict", + package=None, + parameters=[str_type(), str_type()], + ) + + +def variable_or_type(type: GeneratedType, is_required: bool) -> GeneratedType: + if type.name == "list": + [param] = type.parameters + + return variable_or_list_type(param) + elif type.name == "dict": + [key_param, value_param] = type.parameters + + assert key_param.name == "str" + + return variable_or_dict_type(value_param) + else: + name = "VariableOr" if is_required else "VariableOrOptional" + + return GeneratedType( + name=name, + package="databricks.bundles.core", + parameters=[type], + ) + + +def variable_or_list_type(element_type: GeneratedType) -> GeneratedType: + return GeneratedType( + name="VariableOrList", + package="databricks.bundles.core", + parameters=[element_type], + ) + + +def variable_or_dict_type(element_type: GeneratedType) -> GeneratedType: + return GeneratedType( + name="VariableOrDict", + package="databricks.bundles.core", + parameters=[element_type], + ) + + +def generate_type(ref: str, is_param: bool) -> GeneratedType: + if ref.startswith("#/$defs/slice/"): + element_ref = ref.replace("#/$defs/slice/", "#/$defs/") + element_type = generate_type( + ref=element_ref, + is_param=is_param, + ) + + return GeneratedType( + name="list", + package=None, + parameters=[element_type], + ) + + if ref == "#/$defs/map/string": + return dict_type() + + class_name = packages.get_class_name(ref) + package = packages.get_package(ref) + + if is_param and package: + class_name += "Param" + + return GeneratedType( + name=class_name, + package=package, + parameters=[], + ) + + +def resource_type() -> GeneratedType: + return GeneratedType( + 
name="Resource", + package="databricks.bundles.core", + parameters=[], + ) + + +def generate_dataclass(schema_name: str, schema: Schema) -> GeneratedDataclass: + print(f"Generating dataclass for {schema_name}") + + fields = list[GeneratedField]() + class_name = packages.get_class_name(schema_name) + + for name, prop in schema.properties.items(): + is_required = name in schema.required + field = generate_field(name, prop, is_required=is_required) + + fields.append(field) + + extends = [] + package = packages.get_package(schema_name) + + assert package + + if is_resource(schema_name): + extends.append(resource_type()) + + return GeneratedDataclass( + class_name=class_name, + package=package, + description=schema.description, + fields=fields, + extends=extends, + ) + + +def _get_type_code(generated: GeneratedType, quote: bool = True) -> str: + if generated.parameters: + parameters = ", ".join( + map(lambda x: _get_type_code(x, quote), generated.parameters) + ) + + return f"{generated.name}[{parameters}]" + else: + if quote: + return '"' + generated.name + '"' + else: + return generated.name + + +def _append_dataclass(b: CodeBuilder, generated: GeneratedDataclass): + # Example: + # + # @dataclass + # class Job(Resource): + # """docstring""" + + b.append("@dataclass(kw_only=True)") + + b.newline() + b.append("class ", generated.class_name) + + if generated.extends: + b.append("(") + b.append_list( + [_get_type_code(extend, quote=False) for extend in generated.extends] + ) + b.append(")") + + b.append(":").newline() + + # FIXME should contain class docstring + if not generated.description: + b.indent().append_triple_quote().append_triple_quote().newline().newline() + else: + _append_description(b, generated.description) + + +def _append_field(b: CodeBuilder, field: GeneratedField): + # Example: + # + # foo: list[str] = field(default_factory=list) + + b.indent().append(field.field_name).append(": ") + + # don't quote types because it breaks reflection + b.append(_get_type_code(field.type_name, quote=False)) + + if field.default_factory: + b.append(" = field(") + b.append_dict({"default_factory": field.default_factory}) + b.append(")") + elif field.default: + b.append(" = ") + b.append(field.default) + + b.newline() + + +def _append_from_dict(b: CodeBuilder, generated: GeneratedDataclass): + # Example: + # + # @classmethod + # def from_dict(cls, value: 'JobDict') -> 'Job': + # return _transform(cls, value) + + b.indent().append("@classmethod").newline() + + ( + b.indent() + .append("def from_dict(cls, value: ") + .append("'") + .append(generated.class_name + "Dict") + .append("'") + .append(") -> 'Self':") + .newline() + ) + + b.indent().indent().append("return _transform(cls, value)").newline() + b.newline() + + +def _append_as_dict(b: CodeBuilder, generated: GeneratedDataclass): + # Example: + # + # def as_dict(self) -> 'JobDict': + # return _transform_to_json_value(self) # type:ignore + # + + b.indent().append("def as_dict(self) -> '").append(generated.class_name).append( + "Dict':" + ).newline() + b.indent().indent().append( + "return _transform_to_json_value(self) # type:ignore", + ).newline() + b.newline() + + +def _append_typed_dict(b: CodeBuilder, generated: GeneratedDataclass): + # Example: + # + # class JobDict(TypedDict, total=False): + # """docstring""" + # + + b.append("class ").append(generated.class_name).append( + "Dict(TypedDict, total=False):" + ).newline() + + # FIXME should contain class description + b.indent().append_triple_quote().append_triple_quote().newline().newline() + 
+ +def _append_description(b: CodeBuilder, description: Optional[str]): + if description: + b.indent().append_triple_quote().newline() + for line in description.split("\n"): + b.indent().append(line).newline() + b.indent().append_triple_quote().newline() + + +def _append_typed_dict_field(b: CodeBuilder, field: GeneratedField): + b.indent().append(field.field_name).append(": ") + b.append(_get_type_code(field.param_type_name, quote=False)) + b.newline() + + +def get_code(generated: GeneratedDataclass) -> str: + b = CodeBuilder() + + _append_dataclass(b, generated) + + for field in generated.fields: + _append_field(b, field) + _append_description(b, field.description) + + b.newline() + + _append_from_dict(b, generated) + _append_as_dict(b, generated) + + b.newline().newline() + + _append_typed_dict(b, generated) + + for field in generated.fields: + _append_typed_dict_field(b, field) + _append_description(b, field.description) + + b.newline() + + # Example: FooParam = FooDict | Foo + + b.newline() + b.append(generated.class_name).append("Param") + b.append(" = ") + b.append(generated.class_name).append("Dict") + b.append(" | ") + b.append(generated.class_name) + b.newline() + + return b.build() diff --git a/experimental/python/codegen/codegen/generated_dataclass_patch.py b/experimental/python/codegen/codegen/generated_dataclass_patch.py new file mode 100644 index 0000000000..50bcd3205a --- /dev/null +++ b/experimental/python/codegen/codegen/generated_dataclass_patch.py @@ -0,0 +1,78 @@ +from dataclasses import replace + +from codegen.generated_dataclass import ( + GeneratedDataclass, + GeneratedField, + GeneratedType, +) + + +def reorder_required_fields(models: dict[str, GeneratedDataclass]): + """ + Reorder fields in dataclasses so that required fields come first. + It's necessary for kwargs in the constructor to work correctly. + """ + for name, model in models.items(): + if not model.fields: + continue + + required_fields = [field for field in model.fields if _is_required(field)] + optional_fields = [field for field in model.fields if not _is_required(field)] + + models[name] = replace(model, fields=required_fields + optional_fields) + + +def quote_recursive_references(models: dict[str, GeneratedDataclass]): + """ + If there is a cycle between two dataclasses, we need to quote one of them. 
+ + Example: + + class Foo: + bar: Optional[Bar] + + class Bar: + foo: "Foo" + """ + + # see also _append_resolve_recursive_imports + + models["jobs.ForEachTask"] = _quote_recursive_references_for_model( + models["jobs.ForEachTask"], + references={"Task", "TaskParam"}, + ) + + +def _quote_recursive_references_for_model( + model: GeneratedDataclass, + references: set[str], +) -> GeneratedDataclass: + def update_type_name(type_name: GeneratedType): + if type_name.name in references: + return replace( + type_name, + name=f'"{type_name.name}"', + ) + elif type_name.parameters: + return replace( + type_name, + parameters=[update_type_name(param) for param in type_name.parameters], + ) + else: + return type_name + + def update_field(field: GeneratedField): + return replace( + field, + type_name=update_type_name(field.type_name), + param_type_name=update_type_name(field.param_type_name), + ) + + return replace( + model, + fields=[update_field(field) for field in model.fields], + ) + + +def _is_required(field: GeneratedField) -> bool: + return field.default is None and field.default_factory is None diff --git a/experimental/python/codegen/codegen/generated_enum.py b/experimental/python/codegen/codegen/generated_enum.py new file mode 100644 index 0000000000..0ea33aa898 --- /dev/null +++ b/experimental/python/codegen/codegen/generated_enum.py @@ -0,0 +1,76 @@ +import re +from dataclasses import dataclass +from typing import Optional + +import codegen.packages as packages +from codegen.code_builder import CodeBuilder +from codegen.generated_dataclass import _append_description +from codegen.jsonschema import Schema + + +@dataclass(kw_only=True) +class GeneratedEnum: + class_name: str + package: str + values: dict[str, str] + description: Optional[str] + + +def generate_enum(schema_name: str, schema: Schema) -> GeneratedEnum: + assert schema.enum + + class_name = packages.get_class_name(schema_name) + package = packages.get_package(schema_name) + values = {} + + assert package + + for value in schema.enum: + values[_camel_to_upper_snake(value)] = value + + return GeneratedEnum( + class_name=class_name, + package=package, + values=values, + description=schema.description, + ) + + +def get_code(generated: GeneratedEnum) -> str: + b = CodeBuilder() + + # Example: + # + # class Color(Enum): + # + b.append(f"class {generated.class_name}(Enum):") + b.newline() + + _append_description(b, generated.description) + + # Example: + # + # RED = "RED" + # + for key, value in generated.values.items(): + b.indent().append(f'{key} = "{value}"') + b.newline() + + b.newline() + + # Example: + # + # ColorParam = Literal["RED", "GREEN", "BLUE"] | Color + + b.append(generated.class_name).append('Param = Literal["') + b.append_list(list(generated.values.values()), sep='", "') + b.append('"] | ', generated.class_name) + b.newline() + + return b.build() + + +def _camel_to_upper_snake(value): + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", value) + + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).upper() diff --git a/experimental/python/codegen/codegen/generated_imports.py b/experimental/python/codegen/codegen/generated_imports.py new file mode 100644 index 0000000000..f9217b644f --- /dev/null +++ b/experimental/python/codegen/codegen/generated_imports.py @@ -0,0 +1,82 @@ +from textwrap import dedent + +import codegen.packages as packages +from codegen.code_builder import CodeBuilder +from codegen.generated_dataclass import GeneratedDataclass +from codegen.generated_enum import GeneratedEnum + + +def append_enum_imports( + b: 
CodeBuilder, + enums: dict[str, GeneratedEnum], + exclude_packages: list[str], +) -> None: + for schema_name in enums.keys(): + package = packages.get_package(schema_name) + class_name = packages.get_class_name(schema_name) + + if package in exclude_packages: + continue + + b.append(f"from {package} import {class_name}, {class_name}Param\n").newline() + + +def append_dataclass_imports( + b: CodeBuilder, + dataclasses: dict[str, GeneratedDataclass], + exclude_packages: list[str], +) -> None: + for schema_name in dataclasses.keys(): + package = packages.get_package(schema_name) + class_name = packages.get_class_name(schema_name) + + if package in exclude_packages: + continue + + b.append( + f"from {package} import {class_name}, {class_name}Dict, {class_name}Param" + ).newline() + + +def get_code( + dataclasses: dict[str, GeneratedDataclass], + enums: dict[str, GeneratedEnum], + typechecking_imports: dict[str, list[str]], + exclude_packages: list[str], +) -> str: + b = CodeBuilder() + + b.append( + "from typing import Literal, Optional, TypedDict, ClassVar, TYPE_CHECKING\n" + ) + b.append("from enum import Enum\n") + b.append("from dataclasses import dataclass, replace, field\n") + b.append("\n") + b.append("from databricks.bundles.core._resource import Resource\n") + b.append("from databricks.bundles.core._transform import _transform\n") + b.append( + "from databricks.bundles.core._transform_to_json import _transform_to_json_value\n" + ) + b.append( + "from databricks.bundles.core._variable import VariableOr, VariableOrOptional, VariableOrList, VariableOrDict\n" + ) + b.newline() + + runtime_dataclasses = { + k: v + for k, v in dataclasses.items() + if v.class_name not in typechecking_imports.get(v.package, []) + } + + append_dataclass_imports(b, runtime_dataclasses, exclude_packages) + append_enum_imports(b, enums, exclude_packages) + + # typechecking_imports is special case because it's only for TYPE_CHECKING + # and formatter doesn't eliminate unused imports for TYPE_CHECKING + if typechecking_imports: + b.newline() + b.append("if TYPE_CHECKING:").newline() + for package, imports in typechecking_imports.items(): + b.indent().append(f"from {package} import {', '.join(imports)}").newline() + + return b.build() diff --git a/experimental/python/codegen/codegen/jsonschema.py b/experimental/python/codegen/codegen/jsonschema.py new file mode 100644 index 0000000000..e230d6d0b3 --- /dev/null +++ b/experimental/python/codegen/codegen/jsonschema.py @@ -0,0 +1,153 @@ +import json +from pathlib import Path +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import codegen.packages as packages + + +@dataclass +class Property: + ref: str + description: Optional[str] = None + + +class SchemaType(Enum): + OBJECT = "object" + STRING = "string" + + +@dataclass +class Schema: + type: SchemaType + enum: list[str] = field(default_factory=list) + properties: dict[str, Property] = field(default_factory=dict) + required: list[str] = field(default_factory=list) + description: Optional[str] = None + + def __post_init__(self): + match self.type: + case SchemaType.OBJECT: + assert not self.enum + + case SchemaType.STRING: + assert not self.properties + assert not self.required + assert self.enum + case _: + raise ValueError(f"Unknown type: {self.type}") + + for item in self.enum: + assert isinstance(item, str) + + for item in self.required: + assert isinstance(item, str) + + +@dataclass +class Spec: + schemas: dict[str, Schema] + + +def _unwrap_variable(schema: dict): + # 
we assume that each field can be a variable + + if anyOf := schema.get("anyOf") or schema.get("oneOf"): + if len(anyOf) != 2: + return None + + [primary, variable] = anyOf + + pattern = variable.get("pattern", "") + type = variable.get("type", "") + + if ( + type == "string" + and pattern.startswith("\\$\\{") + and pattern.endswith("\\}") + ): + return primary + + return None + + +def _parse_schema(schema: dict) -> Schema: + schema = _unwrap_variable(schema) or schema + properties = {} + + for k, v in schema.get("properties", {}).items(): + assert v.get("type") is None + assert v.get("anyOf") is None + assert v.get("properties") is None + assert v.get("items") is None + + assert v.get("$ref") + + prop = Property( + ref=v["$ref"], + description=v.get("description"), + ) + + properties[k] = prop + + assert schema.get("type") in [ + "object", + "string", + ], f"{schema} type not in ['object', 'string']" + + return Schema( + type=SchemaType(schema["type"]), + enum=schema.get("enum", []), + properties=properties, + required=schema.get("required", []), + description=schema.get("description"), + ) + + +def _load_spec() -> dict: + path = ( + Path(__file__).parent # ./experimental/python/codegen/codegen + / ".." # ./experimental/python/codegen + / ".." # ./experimental/python/ + / ".." # ./experimental + / ".." # ./ + / "./bundle/schema/jsonschema.json" + ) + + return json.load(path.open()) + + +def get_schemas(): + output = dict[str, Schema]() + spec = _load_spec() + + sdk_types_spec = _get_spec_path( + spec, + ["$defs", "github.com", "databricks", "databricks-sdk-go", "service"], + ) + resource_types_spec = _get_spec_path( + spec, + ["$defs", "github.com", "databricks", "cli", "bundle", "config"], + ) + + # we don't need all spec, only get supported types + flat_spec = {**sdk_types_spec, **resource_types_spec} + flat_spec = { + key: value for key, value in flat_spec.items() if packages.should_load_ref(key) + } + + for name, schema in flat_spec.items(): + try: + output[name] = _parse_schema(schema) + except Exception as e: + raise ValueError(f"Failed to parse schema for {name}") from e + + return output + + +def _get_spec_path(spec: dict, path: list[str]) -> dict: + for key in path: + spec = spec[key] + + return spec diff --git a/experimental/python/codegen/codegen/jsonschema_patch.py b/experimental/python/codegen/codegen/jsonschema_patch.py new file mode 100644 index 0000000000..1f0dfdf944 --- /dev/null +++ b/experimental/python/codegen/codegen/jsonschema_patch.py @@ -0,0 +1,92 @@ +from dataclasses import replace + +from codegen.jsonschema import Schema + +REMOVED_FIELDS = { + "jobs.RunJobTask": { + # all params except job_parameters should be deprecated and should not be supported + "jar_params", + "notebook_params", + "python_params", + "spark_submit_params", + "python_named_params", + "sql_params", + "dbt_commands", + # except pipeline_params, that is not deprecated + }, + "jobs.TriggerSettings": { + # Old table trigger settings name. Deprecated in favor of `table_update` + "table", + }, + "compute.ClusterSpec": { + # doesn't work, openapi schema needs to be updated to be enum + "kind", + }, + "jobs.TaskEmailNotifications": { + # Deprecated + "no_alert_for_skipped_runs", + }, + "jobs.SparkJarTask": { + # Deprecated. A value of `false` is no longer supported. 
+ "run_as_repl", + # Deprecated + "jar_uri", + }, + "resources.Pipeline": { + # Deprecated + "trigger", + }, + "pipelines.PipelineLibrary": { + # Deprecated + "whl", + }, +} + +EXTRA_REQUIRED_FIELDS: dict[str, list[str]] = { + "jobs.SparkJarTask": ["main_class_name"], +} + + +def add_extra_required_fields(schemas: dict[str, Schema]): + output = {} + + for name, schema in schemas.items(): + if extra_required := EXTRA_REQUIRED_FIELDS.get(name): + new_required = [*schema.required, *extra_required] + new_required = list(set(new_required)) + + if set(new_required) == set(schema.required): + raise ValueError( + f"Extra required fields for {name} are already present in the schema" + ) + + new_schema = replace(schema, required=new_required) + + output[name] = new_schema + else: + output[name] = schema + + return output + + +def remove_unsupported_fields(schemas: dict[str, Schema]): + output = {} + + for name, schema in schemas.items(): + if removed_fields := REMOVED_FIELDS.get(name): + new_properties = { + field: prop + for field, prop in schema.properties.items() + if field not in removed_fields + } + + if new_properties.keys() == schema.properties.keys(): + raise ValueError(f"No fields to remove in schema {name}") + + new_schema = replace(schema, properties=new_properties) + + output[name] = new_schema + else: + output[name] = schema + + return output diff --git a/experimental/python/codegen/codegen/main.py b/experimental/python/codegen/codegen/main.py new file mode 100644 index 0000000000..735309ff4c --- /dev/null +++ b/experimental/python/codegen/codegen/main.py @@ -0,0 +1,255 @@ +import argparse +from pathlib import Path +from textwrap import dedent + +import codegen.generated_dataclass as generated_dataclass +import codegen.generated_dataclass_patch as generated_dataclass_patch +import codegen.generated_enum as generated_enum +import codegen.generated_imports as generated_imports +import codegen.jsonschema as openapi +import codegen.jsonschema_patch as openapi_patch +import codegen.packages as packages + +from codegen.code_builder import CodeBuilder +from codegen.generated_dataclass import GeneratedDataclass, GeneratedType +from codegen.generated_enum import GeneratedEnum + + +def main(output: str): + schemas = openapi.get_schemas() + schemas = openapi_patch.add_extra_required_fields(schemas) + schemas = openapi_patch.remove_unsupported_fields(schemas) + schemas = _remove_unused_schemas(packages.RESOURCE_TYPES, schemas) + + dataclasses, enums = _generate_code(schemas) + + generated_dataclass_patch.reorder_required_fields(dataclasses) + generated_dataclass_patch.quote_recursive_references(dataclasses) + + _write_code(dataclasses, enums, output) + + for resource in packages.RESOURCE_TYPES: + reachable = _collect_reachable_schemas([resource], schemas) + + resource_dataclasses = {k: v for k, v in dataclasses.items() if k in reachable} + resource_enums = {k: v for k, v in enums.items() if k in reachable} + + _write_exports(resource, resource_dataclasses, resource_enums, output) + + +def _generate_code( + schemas: dict[str, openapi.Schema], +) -> tuple[dict[str, GeneratedDataclass], dict[str, GeneratedEnum]]: + dataclasses = {} + enums = {} + + for schema_name, schema in schemas.items(): + if schema.type == openapi.SchemaType.OBJECT: + generated = generated_dataclass.generate_dataclass(schema_name, schema) + + dataclasses[schema_name] = generated + elif schema.type == openapi.SchemaType.STRING: + generated = generated_enum.generate_enum(schema_name, schema) + + enums[schema_name] = generated + 
+        else:
+            raise ValueError(f"Unknown type: {schema.type}")
+
+    return dataclasses, enums
+
+
+def _write_exports(
+    root: str,
+    dataclasses: dict[str, GeneratedDataclass],
+    enums: dict[str, GeneratedEnum],
+    output: str,
+):
+    exports = []
+
+    for _, dataclass in dataclasses.items():
+        exports += [
+            dataclass.class_name,
+            f"{dataclass.class_name}Dict",
+            f"{dataclass.class_name}Param",
+        ]
+
+    for _, enum in enums.items():
+        exports += [enum.class_name, f"{enum.class_name}Param"]
+
+    exports.sort()
+
+    b = CodeBuilder()
+
+    b.append("__all__ = [\n")
+    for export in exports:
+        b.indent().append_repr(export).append(",").newline()
+    b.append("]").newline()
+    b.newline()
+    b.newline()
+
+    generated_imports.append_dataclass_imports(b, dataclasses, exclude_packages=[])
+    generated_imports.append_enum_imports(b, enums, exclude_packages=[])
+
+    # FIXME should be better generalized
+    if root == "resources.Job":
+        _append_resolve_recursive_imports(b)
+
+    root_package = packages.get_package(root)
+    assert root_package
+
+    # transform databricks.bundles.jobs._models.job -> databricks/bundles/jobs
+    package_path = Path(root_package.replace(".", "/")).parent.parent
+
+    source_path = Path(output) / package_path / "__init__.py"
+    source_path.parent.mkdir(exist_ok=True, parents=True)
+    source_path.write_text(b.build())
+
+    print(f"Writing exports into {source_path}")
+
+
+def _append_resolve_recursive_imports(b: CodeBuilder):
+    """
+    Resolve forward references for recursive imports so we can assume that there are no forward references
+    while inspecting type annotations.
+    """
+
+    b.append(
+        dedent("""
+        def _resolve_recursive_imports():
+            import typing
+
+            from databricks.bundles.core._variable import VariableOr
+            from databricks.bundles.jobs._models.task import Task
+
+            ForEachTask.__annotations__ = typing.get_type_hints(
+                ForEachTask,
+                globalns={"Task": Task, "VariableOr": VariableOr},
+            )
+
+        _resolve_recursive_imports()
+        """)
+    )
+
+
+def _collect_typechecking_imports(
+    generated: GeneratedDataclass,
+) -> dict[str, list[str]]:
+    out = {}
+
+    def visit_type(type_name: GeneratedType):
+        if type_name.name.startswith('"'):
+            out[type_name.package] = out.get(type_name.package, [])
+            out[type_name.package].append(type_name.name.strip('"'))
+
+        for parameter in type_name.parameters:
+            visit_type(parameter)
+
+    for field in generated.fields:
+        visit_type(field.type_name)
+        visit_type(field.param_type_name)
+
+    return out
+
+
+def _collect_reachable_schemas(
+    roots: list[str],
+    schemas: dict[str, openapi.Schema],
+    include_private: bool = True,
+    include_deprecated: bool = True,
+) -> set[str]:
+    """
+    Collect the names of all schemas reachable from the roots. Unreachable
+    schemas are skipped because we don't want to generate code for them.
+    """
+
+    reachable = set(packages.PRIMITIVES)
+    stack = []
+
+    for root in roots:
+        stack.append(root)
+
+    while stack:
+        current = stack.pop()
+        if current in reachable:
+            continue
+
+        reachable.add(current)
+
+        schema = schemas[current]
+
+        if schema.type == openapi.SchemaType.OBJECT:
+            for field in schema.properties.values():
+                if field.ref:
+                    name = field.ref.split("/")[-1]
+
+                    if name not in reachable:
+                        stack.append(name)
+
+    return reachable
+
+
+def _remove_unused_schemas(
+    roots: list[str],
+    schemas: dict[str, openapi.Schema],
+) -> dict[str, openapi.Schema]:
+    """
+    Remove schemas that are not reachable from the roots, because we
+    don't want to generate code for them.
+ """ + + reachable = _collect_reachable_schemas(roots, schemas) + + return {k: v for k, v in schemas.items() if k in reachable} + + +def _write_code( + dataclasses: dict[str, GeneratedDataclass], + enums: dict[str, GeneratedEnum], + output: str, +): + package_code = {} + typechecking_imports = {} + + for schema_name, generated in dataclasses.items(): + package = generated.package + code = generated_dataclass.get_code(generated) + + typechecking_imports[package] = _collect_typechecking_imports(generated) + typechecking_imports[package]["typing_extensions"] = ["Self"] + + package_code[package] = package_code.get(package, "") + package_code[package] += "\n" + code + + for schema_name, generated in enums.items(): + package = generated.package + code = generated_enum.get_code(generated) + + package_code[package] = package_code.get(package, "") + package_code[package] += "\n" + code + + package_code = { + package: generated_imports.get_code( + dataclasses, + enums, + # don't import package from itself + exclude_packages=[package], + typechecking_imports=typechecking_imports.get(package, {}), + ) + + code + for package, code in package_code.items() + } + + for package, code in package_code.items(): + package_path = package.replace(".", "/") + source_path = Path(output) / (package_path + ".py") + + source_path.parent.mkdir(exist_ok=True, parents=True) + source_path.write_text(code) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output", type=str) + args = parser.parse_args() + + main(args.output) diff --git a/experimental/python/codegen/codegen/packages.py b/experimental/python/codegen/codegen/packages.py new file mode 100644 index 0000000000..b81e9cad4d --- /dev/null +++ b/experimental/python/codegen/codegen/packages.py @@ -0,0 +1,93 @@ +import re +from typing import Optional + +RESOURCE_NAMESPACE_OVERRIDE = { + "resources.Job": "jobs", + "resources.Pipeline": "pipelines", + "resources.JobPermission": "jobs", + "resources.JobPermissionLevel": "jobs", + "resources.PipelinePermission": "pipelines", + "resources.PipelinePermissionLevel": "pipelines", +} + +# All supported resource types +RESOURCE_TYPES = [ + "resources.Job", + "resources.Pipeline", +] + +# Namespaces to load from OpenAPI spec. +# +# We can't load all types because of errors while loading some of them. +LOADED_NAMESPACES = [ + "compute", + "jobs", + "pipelines", + "resources", +] + +RENAMES = { + "string": "str", + "boolean": "bool", + "integer": "int", + "number": "float", + "int64": "int", + "float64": "float", +} + +PRIMITIVES = [ + "string", + "boolean", + "integer", + "number", + "bool", + "int", + "int64", + "float64", +] + + +def get_class_name(ref: str) -> str: + name = ref.split("/")[-1] + name = name.split(".")[-1] + + return RENAMES.get(name, name) + + +def is_resource(ref: str) -> bool: + return ref in RESOURCE_TYPES + + +def should_load_ref(ref: str) -> bool: + name = ref.split("/")[-1] + + # FIXME doesn't work, looks like enum, but doesn't have any values specified + if name == "compute.Kind": + return False + + for namespace in LOADED_NAMESPACES: + if name.startswith(f"{namespace}."): + return True + + return name in PRIMITIVES + + +def get_package(ref: str) -> Optional[str]: + """ + Returns Python package for a given OpenAPI ref. + Returns None for builtin types. 
+ """ + + full_name = ref.split("/")[-1] + + if full_name in PRIMITIVES: + return None + + [namespace, name] = full_name.split(".") + + if override := RESOURCE_NAMESPACE_OVERRIDE.get(full_name): + namespace = override + + package_name = re.sub(r"(?