Skip to content

Commit 075e1ec

Browse files
test: imporve tests
1 parent c88674d commit 075e1ec

File tree

1 file changed

+17
-63
lines changed

1 file changed

+17
-63
lines changed

python/tests/test_plans.py

Lines changed: 17 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -48,62 +48,47 @@ def test_logical_plan_to_proto(ctx, df) -> None:
4848
assert str(original_execution_plan) == str(execution_plan)
4949

5050

51-
def test_execution_plan_metrics() -> None:
51+
def test_metrics_tree_walk() -> None:
5252
ctx = SessionContext()
5353
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
5454
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
55-
5655
df.collect()
5756
plan = df.execution_plan()
5857

58+
results = plan.collect_metrics()
59+
assert len(results) >= 1
5960
found_metrics = False
60-
61-
def _check(node):
62-
nonlocal found_metrics
63-
ms = node.metrics()
64-
if ms is not None and ms.output_rows is not None and ms.output_rows > 0:
61+
for name, ms in results:
62+
assert isinstance(name, str)
63+
assert isinstance(ms, MetricsSet)
64+
if ms.output_rows is not None and ms.output_rows > 0:
6565
found_metrics = True
66-
for child in node.children():
67-
_check(child)
68-
69-
_check(plan)
7066
assert found_metrics
7167

7268

7369
def test_metric_properties() -> None:
7470
ctx = SessionContext()
7571
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
7672
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
77-
7873
df.collect()
7974
plan = df.execution_plan()
8075

8176
for _, ms in plan.collect_metrics():
77+
r = repr(ms)
78+
assert isinstance(r, str)
8279
for metric in ms.metrics():
8380
assert isinstance(metric, Metric)
8481
assert isinstance(metric.name, str)
8582
assert len(metric.name) > 0
8683
assert metric.partition is None or isinstance(metric.partition, int)
8784
assert isinstance(metric.labels(), dict)
85+
mr = repr(metric)
86+
assert isinstance(mr, str)
87+
assert len(mr) > 0
8888
return
8989
pytest.skip("No metrics found")
9090

9191

92-
def test_metrics_tree_walk() -> None:
93-
ctx = SessionContext()
94-
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'a'), (4, 'b')")
95-
df = ctx.sql("SELECT column2, COUNT(*) FROM t GROUP BY column2")
96-
97-
df.collect()
98-
plan = df.execution_plan()
99-
100-
results = plan.collect_metrics()
101-
assert len(results) >= 2
102-
for name, ms in results:
103-
assert isinstance(name, str)
104-
assert isinstance(ms, MetricsSet)
105-
106-
10792
def test_no_metrics_before_execution() -> None:
10893
ctx = SessionContext()
10994
ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
@@ -113,35 +98,14 @@ def test_no_metrics_before_execution() -> None:
11398
assert ms is None or ms.output_rows is None or ms.output_rows == 0
11499

115100

116-
def test_metrics_repr() -> None:
117-
ctx = SessionContext()
118-
ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
119-
df = ctx.sql("SELECT * FROM t")
120-
121-
df.collect()
122-
plan = df.execution_plan()
123-
124-
for _, ms in plan.collect_metrics():
125-
r = repr(ms)
126-
assert isinstance(r, str)
127-
for metric in ms.metrics():
128-
mr = repr(metric)
129-
assert isinstance(mr, str)
130-
assert len(mr) > 0
131-
return
132-
pytest.skip("No metrics found")
133-
134-
135101
def test_collect_partitioned_metrics() -> None:
136102
ctx = SessionContext()
137103
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
138104
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
139105

140-
partitions = df.collect_partitioned()
106+
df.collect_partitioned()
141107
plan = df.execution_plan()
142-
assert len(partitions) == plan.partition_count
143108

144-
# Metrics should be populated after collecting
145109
found_metrics = False
146110
for _, ms in plan.collect_metrics():
147111
if ms.output_rows is not None and ms.output_rows > 0:
@@ -154,18 +118,12 @@ def test_execute_stream_metrics() -> None:
154118
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
155119
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
156120

157-
stream = df.execute_stream()
158-
159-
# Consume the stream (iterates over RecordBatches)
160-
batches = list(stream)
161-
assert len(batches) >= 1
121+
for _ in df.execute_stream():
122+
pass
162123

163-
# Metrics should be populated after consuming the stream
164124
plan = df.execution_plan()
165125
found_metrics = False
166-
for name, ms in plan.collect_metrics():
167-
assert isinstance(name, str)
168-
assert isinstance(ms, MetricsSet)
126+
for _, ms in plan.collect_metrics():
169127
if ms.output_rows is not None and ms.output_rows > 0:
170128
found_metrics = True
171129
assert found_metrics
@@ -176,14 +134,10 @@ def test_execute_stream_partitioned_metrics() -> None:
176134
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
177135
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
178136

179-
streams = df.execute_stream_partitioned()
180-
181-
# Consume all partition streams
182-
for stream in streams:
137+
for stream in df.execute_stream_partitioned():
183138
for _ in stream:
184139
pass
185140

186-
# Metrics should be populated (FilterExec reports output_rows)
187141
plan = df.execution_plan()
188142
found_metrics = False
189143
for _, ms in plan.collect_metrics():

0 commit comments

Comments
 (0)