Skip to content

Commit c446690

Browse files
committed
Reapply "feat: add ensure_expr_list function to flatten and validate nested expressions"
This reverts commit 5964e47.
1 parent 5964e47 commit c446690

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

python/datafusion/dataframe.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
Expr,
4646
SortKey,
4747
ensure_expr,
48+
ensure_expr_list,
4849
expr_list_to_raw_expr_list,
4950
sort_list_to_raw_sort_list,
5051
)
@@ -488,17 +489,7 @@ def with_columns(
488489
Returns:
489490
DataFrame with the new columns added.
490491
"""
491-
492-
def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]:
493-
for expr in items:
494-
if isinstance(expr, str):
495-
yield expr
496-
elif isinstance(expr, Iterable) and not isinstance(expr, Expr):
497-
yield from _iter_exprs(expr)
498-
else:
499-
yield expr
500-
501-
expressions = [ensure_expr(e) for e in _iter_exprs(exprs)]
492+
expressions = ensure_expr_list(exprs)
502493
for alias, expr in named_exprs.items():
503494
ensure_expr(expr)
504495
expressions.append(expr.alias(alias).expr)

python/datafusion/expr.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -219,6 +219,7 @@
219219
"WindowFrame",
220220
"WindowFrameBound",
221221
"ensure_expr",
222+
"ensure_expr_list",
222223
]
223224

224225

@@ -243,6 +244,31 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr:
243244
return value.expr
244245

245246

247+
def ensure_expr_list(
    exprs: Iterable[Expr | Iterable[Expr]],
) -> list[expr_internal.Expr]:
    """Flatten an iterable of expressions, validating each via ``ensure_expr``.

    Nested iterables are flattened depth-first, preserving order. String and
    bytes-like values are never flattened: they are passed to ``ensure_expr``
    as single items so that a bare column name such as ``"a"`` is rejected
    with the usual ``TypeError`` rather than being walked character by
    character.

    Args:
        exprs: Possibly nested iterable containing expressions.

    Returns:
        A flat list of raw expressions.

    Raises:
        TypeError: If any item is not an instance of :class:`Expr`.
    """

    def _iter(
        items: Iterable[Expr | Iterable[Expr]],
    ) -> Iterable[expr_internal.Expr]:
        for item in items:
            # ``str`` is itself iterable and iterating it yields one-character
            # strings, which are iterable again; descending into it would
            # recurse without end (or silently drop an empty string). Treat
            # string-like objects as atomic so ``ensure_expr`` can reject them.
            if isinstance(item, Iterable) and not isinstance(
                item, (Expr, str, bytes, bytearray)
            ):
                yield from _iter(item)
            else:
                yield ensure_expr(item)

    return list(_iter(exprs))
270+
271+
246272
def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
247273
"""Convert a Python expression or column name to its raw variant.
248274

python/tests/test_dataframe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,14 @@ def test_with_columns_invalid_expr(df):
436436
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
437437
):
438438
df.with_columns(c="a")
439+
with pytest.raises(
440+
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
441+
):
442+
df.with_columns(["a"])
443+
with pytest.raises(
444+
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
445+
):
446+
df.with_columns(c=["a"])
439447

440448

441449
def test_cast(df):

0 commit comments

Comments
 (0)