4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343from datafusion .expr import (
44- _EXPR_TYPE_ERROR ,
44+ EXPR_TYPE_ERROR ,
4545 Expr ,
4646 SortExpr ,
4747 expr_list_to_raw_expr_list ,
@@ -430,9 +430,9 @@ def filter(self, *predicates: Expr) -> DataFrame:
430430 """
431431 df = self .df
432432 for p in predicates :
433- if not isinstance (p , Expr ):
434- raise TypeError (_EXPR_TYPE_ERROR )
435- df = df .filter (p . expr )
433+ if isinstance ( p , str ) or not isinstance (p , Expr ):
434+ raise TypeError (EXPR_TYPE_ERROR )
435+ df = df .filter (expr_list_to_raw_expr_list ( p )[ 0 ] )
436436 return DataFrame (df )
437437
438438 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -446,7 +446,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
446446 DataFrame with the new column.
447447 """
448448 if not isinstance (expr , Expr ):
449- raise TypeError (_EXPR_TYPE_ERROR )
449+ raise TypeError (EXPR_TYPE_ERROR )
450450 return DataFrame (self .df .with_column (name , expr .expr ))
451451
452452 def with_columns (
@@ -478,20 +478,19 @@ def _simplify_expression(
478478 ) -> list [expr_internal .Expr ]:
479479 expr_list : list [expr_internal .Expr ] = []
480480 for expr in exprs :
481- if isinstance (expr , str ) or (
482- isinstance (expr , Iterable )
483- and not isinstance (expr , Expr )
484- and any (isinstance (inner , str ) for inner in expr )
485- ):
486- raise TypeError (_EXPR_TYPE_ERROR )
487- try :
488- expr_list .extend (expr_list_to_raw_expr_list (expr ))
489- except TypeError as err :
490- raise TypeError (_EXPR_TYPE_ERROR ) from err
491- for alias , expr in named_exprs .items ():
492- if not isinstance (expr , Expr ):
493- raise TypeError (_EXPR_TYPE_ERROR )
494- expr_list .append (expr .alias (alias ).expr )
481+ if isinstance (expr , str ):
482+ raise TypeError (EXPR_TYPE_ERROR )
483+ if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
484+ if any (not isinstance (inner_expr , Expr ) for inner_expr in expr ):
485+ raise TypeError (EXPR_TYPE_ERROR )
486+ elif not isinstance (expr , Expr ):
487+ raise TypeError (EXPR_TYPE_ERROR )
488+ expr_list .extend (expr_list_to_raw_expr_list (expr ))
489+ if named_exprs :
490+ for alias , expr in named_exprs .items ():
491+ if not isinstance (expr , Expr ):
492+ raise TypeError (EXPR_TYPE_ERROR )
493+ expr_list .append (expr .alias (alias ).expr )
495494 return expr_list
496495
497496 expressions = _simplify_expression (* exprs , ** named_exprs )
@@ -536,7 +535,7 @@ def aggregate(
536535 aggs_exprs = []
537536 for agg in aggs_list :
538537 if not isinstance (agg , Expr ):
539- raise TypeError (_EXPR_TYPE_ERROR )
538+ raise TypeError (EXPR_TYPE_ERROR )
540539 aggs_exprs .append (agg .expr )
541540 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
542541
@@ -552,7 +551,20 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
552551 Returns:
553552 DataFrame after sorting.
554553 """
555- exprs_raw = sort_list_to_raw_sort_list (list (exprs ))
554+ exprs_raw = []
555+ for e in exprs :
556+ if isinstance (e , SortExpr ):
557+ exprs_raw .append (sort_or_default (e ))
558+ elif isinstance (e , str ):
559+ exprs_raw .append (sort_or_default (Expr .column (e )))
560+ elif isinstance (e , Expr ):
561+ exprs_raw .append (sort_or_default (e ))
562+ else :
563+ error = (
564+ "Expected Expr or column name, found:"
565+ f" { type (e ).__name__ } . { EXPR_TYPE_ERROR } ."
566+ )
567+ raise TypeError (error )
556568 return DataFrame (self .df .sort (* exprs_raw ))
557569
558570 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -779,7 +791,7 @@ def join_on(
779791 exprs = []
780792 for expr in on_exprs :
781793 if not isinstance (expr , Expr ):
782- raise TypeError (_EXPR_TYPE_ERROR )
794+ raise TypeError (EXPR_TYPE_ERROR )
783795 exprs .append (expr .expr )
784796 return DataFrame (self .df .join_on (right .df , exprs , how ))
785797
0 commit comments