Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
walk(base, base.__qualname__, is_metadata)

for k, v in obj.__dict__.items():
if k.startswith("__"):
# skip dunder methods bar __init__ as it might contain user defined logic with cross class references
if k.startswith("__") and k != "__init__":
continue

# Traverse methods in a class to find global references
Expand All @@ -362,10 +363,14 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
if callable(v):
# Walk the method if it's part of the object, else it's a global function and we just store it
if v.__qualname__.startswith(obj.__qualname__):
for k, v in func_globals(v).items():
walk(v, k, is_metadata)
else:
walk(v, v.__name__, is_metadata)
try:
for k, v in func_globals(v).items():
walk(v, k, is_metadata)
except (OSError, TypeError):
# __init__ may come from built-ins or wrapped callables
pass
else:
walk(v, k, is_metadata)
elif callable(obj):
for k, v in func_globals(obj).items():
walk(v, k, is_metadata)
Expand Down
7 changes: 6 additions & 1 deletion tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,14 +1506,19 @@ def test_requirements(copy_to_temp_path: t.Callable):
"dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True
).environment
requirements = {"ipywidgets", "numpy", "pandas", "test_package"}
if IS_WINDOWS:
requirements.add("pendulum")
assert environment.requirements["pandas"] == "2.2.2"
assert set(environment.requirements) == requirements

context._requirements = {"numpy": "2.1.2", "pandas": "2.2.1"}
context._excluded_requirements = {"ipywidgets", "ruamel.yaml", "ruamel.yaml.clib"}
diff = context.plan_builder("dev", skip_tests=True, skip_backfill=True).build().context_diff
assert set(diff.previous_requirements) == requirements
assert set(diff.requirements) == {"numpy", "pandas"}
reqs = {"numpy", "pandas"}
if IS_WINDOWS:
reqs.add("pendulum")
assert set(diff.requirements) == reqs


def test_deactivate_automatic_requirement_inference(copy_to_temp_path: t.Callable):
Expand Down
73 changes: 66 additions & 7 deletions tests/utils/test_metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,18 @@ class DataClass:
x: int


class ReferencedClass:
def __init__(self, value: int):
self.value = value

def get_value(self) -> int:
return self.value


class MyClass:
def __init__(self, x: int):
self.helper = ReferencedClass(x * 2)

@staticmethod
def foo():
return KLASS_X
Expand All @@ -95,6 +106,13 @@ def bar(cls):
def baz(self):
return KLASS_Z

def use_referenced(self, value: int) -> int:
ref = ReferencedClass(value)
return ref.get_value()

def compute_with_reference(self) -> int:
return self.helper.get_value() + 10


def other_func(a: int) -> int:
import sqlglot
Expand All @@ -103,7 +121,8 @@ def other_func(a: int) -> int:
pd.DataFrame([{"x": 1}])
to_table("y")
my_lambda() # type: ignore
return X + a + W
obj = MyClass(a)
return X + a + W + obj.compute_with_reference()


@contextmanager
Expand Down Expand Up @@ -131,7 +150,7 @@ def function_with_custom_decorator():
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
"""DOC STRING"""
sqlglot.parse_one("1")
MyClass()
MyClass(47)
DataClass(x=y)
normalize_model_name("test" + SQLGLOT_META)
fetch_data()
Expand Down Expand Up @@ -177,6 +196,7 @@ def test_func_globals() -> None:
assert func_globals(other_func) == {
"X": 1,
"W": 0,
"MyClass": MyClass,
"my_lambda": my_lambda,
"pd": pd,
"to_table": to_table,
Expand All @@ -202,7 +222,7 @@ def test_normalize_source() -> None:
== """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
MyClass()
MyClass(47)
DataClass(x=y)
normalize_model_name('test' + SQLGLOT_META)
fetch_data()
Expand All @@ -223,7 +243,8 @@ def closure(z: int):
pd.DataFrame([{'x': 1}])
to_table('y')
my_lambda()
return X + a + W"""
obj = MyClass(a)
return X + a + W + obj.compute_with_reference()"""
)


Expand Down Expand Up @@ -252,7 +273,7 @@ def test_serialize_env() -> None:
payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
):
sqlglot.parse_one('1')
MyClass()
MyClass(47)
DataClass(x=y)
normalize_model_name('test' + SQLGLOT_META)
fetch_data()
Expand Down Expand Up @@ -295,6 +316,9 @@ class DataClass:
path="test_metaprogramming.py",
payload="""class MyClass:

def __init__(self, x: int):
self.helper = ReferencedClass(x * 2)

@staticmethod
def foo():
return KLASS_X
Expand All @@ -304,7 +328,26 @@ def bar(cls):
return KLASS_Y

def baz(self):
return KLASS_Z""",
return KLASS_Z

def use_referenced(self, value: int):
ref = ReferencedClass(value)
return ref.get_value()

def compute_with_reference(self):
return self.helper.get_value() + 10""",
),
"ReferencedClass": Executable(
kind=ExecutableKind.DEFINITION,
name="ReferencedClass",
path="test_metaprogramming.py",
payload="""class ReferencedClass:

def __init__(self, value: int):
self.value = value

def get_value(self):
return self.value""",
),
"dataclass": Executable(
payload="from dataclasses import dataclass", kind=ExecutableKind.IMPORT
Expand Down Expand Up @@ -341,7 +384,8 @@ def sample_context_manager():
pd.DataFrame([{'x': 1}])
to_table('y')
my_lambda()
return X + a + W""",
obj = MyClass(a)
return X + a + W + obj.compute_with_reference()""",
),
"sample_context_manager": Executable(
payload="""@contextmanager
Expand Down Expand Up @@ -424,6 +468,21 @@ def function_with_custom_decorator():
assert all(is_metadata for (_, is_metadata) in env.values())
assert serialized_env == expected_env

# Check that class references inside init are captured
init_globals = func_globals(MyClass.__init__)
assert "ReferencedClass" in init_globals

env = {}
build_env(other_func, env=env, name="other_func_test", path=path)
serialized_env = serialize_env(env, path=path)

assert "MyClass" in serialized_env
assert "ReferencedClass" in serialized_env

prepared_env = prepare_env(serialized_env)
result = eval("other_func_test(2)", prepared_env)
assert result == 17


def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
path = Path("tests/utils")
Expand Down