Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.265.0

### Notable Changes
* Separate generated classes between jobs and pipelines in Python support ([#3428](https://github.com/databricks/cli/pull/3428))

### Dependency updates
* Upgrade TF provider to 1.87.0 ([#3430](https://github.com/databricks/cli/pull/3430))
Expand Down
20 changes: 13 additions & 7 deletions experimental/python/codegen/codegen/generated_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ class GeneratedDataclass:


def generate_field(
namespace: str,
field_name: str,
prop: Property,
is_required: bool,
) -> GeneratedField:
field_type = generate_type(prop.ref, is_param=False)
param_type = generate_type(prop.ref, is_param=True)
field_type = generate_type(namespace, prop.ref, is_param=False)
param_type = generate_type(namespace, prop.ref, is_param=True)

field_type = variable_or_type(field_type, is_required=is_required)
param_type = variable_or_type(param_type, is_required=is_required)
Expand Down Expand Up @@ -255,10 +256,11 @@ def variable_or_dict_type(element_type: GeneratedType) -> GeneratedType:
)


def generate_type(ref: str, is_param: bool) -> GeneratedType:
def generate_type(namespace: str, ref: str, is_param: bool) -> GeneratedType:
if ref.startswith("#/$defs/slice/"):
element_ref = ref.replace("#/$defs/slice/", "#/$defs/")
element_type = generate_type(
namespace=namespace,
ref=element_ref,
is_param=is_param,
)
Expand All @@ -273,7 +275,7 @@ def generate_type(ref: str, is_param: bool) -> GeneratedType:
return dict_type()

class_name = packages.get_class_name(ref)
package = packages.get_package(ref)
package = packages.get_package(namespace, ref)

if is_param and package:
class_name += "Param"
Expand All @@ -293,20 +295,24 @@ def resource_type() -> GeneratedType:
)


def generate_dataclass(schema_name: str, schema: Schema) -> GeneratedDataclass:
def generate_dataclass(
namespace: str,
schema_name: str,
schema: Schema,
) -> GeneratedDataclass:
print(f"Generating dataclass for {schema_name}")

fields = list[GeneratedField]()
class_name = packages.get_class_name(schema_name)

for name, prop in schema.properties.items():
is_required = name in schema.required
field = generate_field(name, prop, is_required=is_required)
field = generate_field(namespace, name, prop, is_required=is_required)

fields.append(field)

extends = []
package = packages.get_package(schema_name)
package = packages.get_package(namespace, schema_name)

assert package

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ class Bar:

# see also _append_resolve_recursive_imports

models["jobs.ForEachTask"] = _quote_recursive_references_for_model(
models["jobs.ForEachTask"],
references={"Task", "TaskParam"},
)
if "jobs.ForEachTask" in models:
models["jobs.ForEachTask"] = _quote_recursive_references_for_model(
models["jobs.ForEachTask"],
references={"Task", "TaskParam"},
)


def _quote_recursive_references_for_model(
Expand Down
4 changes: 2 additions & 2 deletions experimental/python/codegen/codegen/generated_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class GeneratedEnum:
experimental: bool


def generate_enum(schema_name: str, schema: Schema) -> GeneratedEnum:
def generate_enum(namespace: str, schema_name: str, schema: Schema) -> GeneratedEnum:
assert schema.enum

class_name = packages.get_class_name(schema_name)
package = packages.get_package(schema_name)
package = packages.get_package(namespace, schema_name)
values = {}

assert package
Expand Down
12 changes: 6 additions & 6 deletions experimental/python/codegen/codegen/generated_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def append_enum_imports(
enums: dict[str, GeneratedEnum],
exclude_packages: list[str],
) -> None:
for schema_name in enums.keys():
package = packages.get_package(schema_name)
class_name = packages.get_class_name(schema_name)
for generated in enums.values():
package = generated.package
class_name = generated.class_name

if package in exclude_packages:
continue
Expand All @@ -26,9 +26,9 @@ def append_dataclass_imports(
dataclasses: dict[str, GeneratedDataclass],
exclude_packages: list[str],
) -> None:
for schema_name in dataclasses.keys():
package = packages.get_package(schema_name)
class_name = packages.get_class_name(schema_name)
for generated in dataclasses.values():
package = generated.package
class_name = generated.class_name

if package in exclude_packages:
continue
Expand Down
39 changes: 20 additions & 19 deletions experimental/python/codegen/codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,21 @@ def main(output: str):
schemas = _remove_deprecated_fields(schemas)
schemas = _remove_unused_schemas(packages.RESOURCE_TYPES, schemas)

dataclasses, enums = _generate_code(schemas)

generated_dataclass_patch.reorder_required_fields(dataclasses)
generated_dataclass_patch.quote_recursive_references(dataclasses)
# each resource has its own namespace and is generated separately so
# that there are no dependencies between namespaces as in Databricks SDK v1
for resource, namespace in packages.RESOURCE_NAMESPACE.items():
# only generate code for schemas used directly or transitively by the resource
reachable = _collect_reachable_schemas([resource], schemas)
reachable_schemas = {k: v for k, v in schemas.items() if k in reachable}

_write_code(dataclasses, enums, output)
dataclasses, enums = _generate_code(namespace, reachable_schemas)

for resource in packages.RESOURCE_TYPES:
reachable = _collect_reachable_schemas([resource], schemas)
generated_dataclass_patch.reorder_required_fields(dataclasses)
generated_dataclass_patch.quote_recursive_references(dataclasses)

resource_dataclasses = {k: v for k, v in dataclasses.items() if k in reachable}
resource_enums = {k: v for k, v in enums.items() if k in reachable}
_write_code(dataclasses, enums, output)

_write_exports(resource, resource_dataclasses, resource_enums, output)
_write_exports(namespace, dataclasses, enums, output)


def _transitively_mark_deprecated_and_private(
Expand Down Expand Up @@ -95,18 +96,21 @@ def _remove_deprecated_fields(


def _generate_code(
namespace: str,
schemas: dict[str, openapi.Schema],
) -> tuple[dict[str, GeneratedDataclass], dict[str, GeneratedEnum]]:
dataclasses = {}
enums = {}

for schema_name, schema in schemas.items():
if schema.type == openapi.SchemaType.OBJECT:
generated = generated_dataclass.generate_dataclass(schema_name, schema)
generated = generated_dataclass.generate_dataclass(
namespace, schema_name, schema
)

dataclasses[schema_name] = generated
elif schema.type == openapi.SchemaType.STRING:
generated = generated_enum.generate_enum(schema_name, schema)
generated = generated_enum.generate_enum(namespace, schema_name, schema)

enums[schema_name] = generated
else:
Expand All @@ -116,7 +120,7 @@ def _generate_code(


def _write_exports(
root: str,
namespace: str,
dataclasses: dict[str, GeneratedDataclass],
enums: dict[str, GeneratedEnum],
output: str,
Expand Down Expand Up @@ -148,14 +152,11 @@ def _write_exports(
generated_imports.append_enum_imports(b, enums, exclude_packages=[])

# FIXME should be better generalized
if root == "resources.Job":
if namespace == "jobs":
_append_resolve_recursive_imports(b)

root_package = packages.get_package(root)
assert root_package

# transform databricks.bundles.jobs._models.job -> databricks/bundles/jobs
package_path = Path(root_package.replace(".", "/")).parent.parent
root_package = packages.get_root_package(namespace)
package_path = Path(root_package.replace(".", "/"))

source_path = Path(output) / package_path / "__init__.py"
source_path.parent.mkdir(exist_ok=True, parents=True)
Expand Down
27 changes: 10 additions & 17 deletions experimental/python/codegen/codegen/packages.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import re
from typing import Optional

RESOURCE_NAMESPACE_OVERRIDE = {
# All supported resource types and their namespace
RESOURCE_NAMESPACE = {
"resources.Job": "jobs",
"resources.Pipeline": "pipelines",
"resources.JobPermission": "jobs",
"resources.JobPermissionLevel": "jobs",
"resources.PipelinePermission": "pipelines",
"resources.PipelinePermissionLevel": "pipelines",
}

# All supported resource types
RESOURCE_TYPES = [
"resources.Job",
"resources.Pipeline",
]
RESOURCE_TYPES = list(RESOURCE_NAMESPACE.keys())

# Namespaces to load from OpenAPI spec.
#
Expand Down Expand Up @@ -72,7 +65,11 @@ def should_load_ref(ref: str) -> bool:
return name in PRIMITIVES


def get_package(ref: str) -> Optional[str]:
def get_root_package(namespace: str) -> str:
    """
    Returns the root Python package for a given namespace.

    For example, namespace "jobs" maps to "databricks.bundles.jobs".
    """

    return f"databricks.bundles.{namespace}"


def get_package(namespace: str, ref: str) -> Optional[str]:
"""
Returns Python package for a given OpenAPI ref.
Returns None for builtin types.
Expand All @@ -83,11 +80,7 @@ def get_package(ref: str) -> Optional[str]:
if full_name in PRIMITIVES:
return None

[namespace, name] = full_name.split(".")

if override := RESOURCE_NAMESPACE_OVERRIDE.get(full_name):
namespace = override

[_, name] = full_name.split(".")
package_name = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

return f"databricks.bundles.{namespace}._models.{package_name}"
return f"{get_root_package(namespace)}._models.{package_name}"
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

def test_generate_type_string():
generated_type = generate_type(
"#/$defs/string",
namespace="jobs",
ref="#/$defs/string",
is_param=False,
)

Expand All @@ -27,7 +28,8 @@ def test_generate_type_string():

def test_generate_type_dict():
generated_type = generate_type(
"#/$defs/map/string",
namespace="jobs",
ref="#/$defs/map/string",
is_param=False,
)

Expand All @@ -36,6 +38,7 @@ def test_generate_type_dict():

def test_generate_dataclass():
generated = generate_dataclass(
namespace="bananas",
schema_name="jobs.Task",
schema=Schema(
type=SchemaType.OBJECT,
Expand All @@ -52,7 +55,7 @@ def test_generate_dataclass():

assert generated == GeneratedDataclass(
class_name="Task",
package="databricks.bundles.jobs._models.task",
package="databricks.bundles.bananas._models.task",
description="task description",
extends=[],
fields=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

def test_generate_enum():
generated = generate_enum(
namespace="bananas",
schema_name="jobs.MyEnum",
schema=Schema(
enum=["myEnumValue"],
Expand All @@ -14,7 +15,7 @@ def test_generate_enum():

assert generated == GeneratedEnum(
class_name="MyEnum",
package="databricks.bundles.jobs._models.my_enum",
package="databricks.bundles.bananas._models.my_enum",
values={"MY_ENUM_VALUE": "myEnumValue"},
description="enum description",
experimental=False,
Expand Down
Loading
Loading