Skip to content

Commit 268f2fa

Browse files
committed
Add type checks for expressions in DataFrame methods and expr_list_to_raw_expr_list for improved error handling
1 parent 54e7f6d commit 268f2fa

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

python/datafusion/dataframe.py

Lines changed: 44 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -424,6 +424,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
424424
"""
425425
df = self.df
426426
for p in predicates:
427+
if isinstance(p, str) or not isinstance(p, Expr):
428+
error = "Use col() or lit() to construct expressions"
429+
raise TypeError(error)
427430
df = df.filter(expr_list_to_raw_expr_list(p)[0])
428431
return DataFrame(df)
429432

@@ -437,7 +440,10 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
437440
Returns:
438441
DataFrame with the new column.
439442
"""
440-
return DataFrame(self.df.with_column(name, expr_list_to_raw_expr_list(expr)[0]))
443+
if not isinstance(expr, Expr):
444+
error = "Use col() or lit() to construct expressions"
445+
raise TypeError(error)
446+
return DataFrame(self.df.with_column(name, expr.expr))
441447

442448
def with_columns(
443449
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -468,13 +474,22 @@ def _simplify_expression(
468474
) -> list[expr_internal.Expr]:
469475
expr_list = []
470476
for expr in exprs:
471-
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
472-
expr_list.extend(expr_list_to_raw_expr_list(inner_expr)[0] for inner_expr in expr)
477+
if isinstance(expr, Expr):
478+
expr_list.append(expr.expr)
479+
elif isinstance(expr, Iterable) and not isinstance(expr, (str, Expr)):
480+
for inner_expr in expr:
481+
if not isinstance(inner_expr, Expr):
482+
error = "Use col() or lit() to construct expressions"
483+
raise TypeError(error)
484+
expr_list.append(inner_expr.expr)
473485
else:
474-
expr_list.append(expr_list_to_raw_expr_list(expr)[0])
486+
error = "Use col() or lit() to construct expressions"
487+
raise TypeError(error)
475488
if named_exprs:
476489
for alias, expr in named_exprs.items():
477-
expr_list_to_raw_expr_list(expr)[0]
490+
if not isinstance(expr, Expr):
491+
error = "Use col() or lit() to construct expressions"
492+
raise TypeError(error)
478493
expr_list.append(expr.alias(alias).expr)
479494
return expr_list
480495

@@ -516,10 +531,24 @@ def aggregate(
516531
group_by_list = group_by if isinstance(group_by, list) else [group_by]
517532
aggs_list = aggs if isinstance(aggs, list) else [aggs]
518533

519-
group_by_exprs = [
520-
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
521-
]
522-
aggs_exprs = [expr_list_to_raw_expr_list(agg)[0] for agg in aggs_list]
534+
group_by_exprs = []
535+
for e in group_by_list:
536+
if isinstance(e, str):
537+
group_by_exprs.append(Expr.column(e).expr)
538+
elif isinstance(e, Expr):
539+
group_by_exprs.append(e.expr)
540+
else:
541+
error = (
542+
"Expected Expr or column name, found:"
543+
f" {type(e).__name__}. Use col() or lit() to construct expressions"
544+
)
545+
raise TypeError(error)
546+
aggs_exprs = []
547+
for agg in aggs_list:
548+
if not isinstance(agg, Expr):
549+
error = "Use col() or lit() to construct expressions"
550+
raise TypeError(error)
551+
aggs_exprs.append(agg.expr)
523552
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
524553

525554
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
@@ -765,7 +794,12 @@ def join_on(
765794
Returns:
766795
DataFrame after join.
767796
"""
768-
exprs = [expr_list_to_raw_expr_list(expr)[0] for expr in on_exprs]
797+
exprs = []
798+
for expr in on_exprs:
799+
if not isinstance(expr, Expr):
800+
error = "Use col() or lit() to construct expressions"
801+
raise TypeError(error)
802+
exprs.append(expr.expr)
769803
return DataFrame(self.df.join_on(right.df, exprs, how))
770804

771805
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,19 @@ def expr_list_to_raw_expr_list(
223223
expr_list = [expr_list]
224224
if expr_list is None:
225225
return None
226-
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in expr_list]
226+
raw_exprs: list[expr_internal.Expr] = []
227+
for e in expr_list:
228+
if isinstance(e, str):
229+
raw_exprs.append(Expr.column(e).expr)
230+
elif isinstance(e, Expr):
231+
raw_exprs.append(e.expr)
232+
else:
233+
error = (
234+
"Expected Expr or column name, found:"
235+
f" {type(e).__name__}. Use col() or lit() to construct expressions."
236+
)
237+
raise TypeError(error)
238+
return raw_exprs
227239

228240

229241
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

0 commit comments

Comments
 (0)