4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , sort_or_default
43+ from datafusion .expr import Expr , SortExpr , expr_list_to_raw_expr_list , sort_or_default
4444from datafusion .plan import ExecutionPlan , LogicalPlan
4545from datafusion .record_batch import RecordBatchStream
4646
@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396 """
397- exprs_internal = [
398- Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399- ]
397+ exprs_internal = expr_list_to_raw_expr_list (exprs )
400398 return DataFrame (self .df .select (* exprs_internal ))
401399
402400 def drop (self , * columns : str ) -> DataFrame :
@@ -426,7 +424,10 @@ def filter(self, *predicates: Expr) -> DataFrame:
426424 """
427425 df = self .df
428426 for p in predicates :
429- df = df .filter (p .expr )
427+ if isinstance (p , str ) or not isinstance (p , Expr ):
428+ error = "Use col() or lit() to construct expressions"
429+ raise TypeError (error )
430+ df = df .filter (expr_list_to_raw_expr_list (p )[0 ])
430431 return DataFrame (df )
431432
432433 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -439,6 +440,9 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439440 Returns:
440441 DataFrame with the new column.
441442 """
443+ if not isinstance (expr , Expr ):
444+ error = "Use col() or lit() to construct expressions"
445+ raise TypeError (error )
442446 return DataFrame (self .df .with_column (name , expr .expr ))
443447
444448 def with_columns (
@@ -472,12 +476,20 @@ def _simplify_expression(
472476 for expr in exprs :
473477 if isinstance (expr , Expr ):
474478 expr_list .append (expr .expr )
475- elif isinstance (expr , Iterable ):
476- expr_list .extend (inner_expr .expr for inner_expr in expr )
479+ elif isinstance (expr , Iterable ) and not isinstance (expr , (str , Expr )):
480+ for inner_expr in expr :
481+ if not isinstance (inner_expr , Expr ):
482+ error = "Use col() or lit() to construct expressions"
483+ raise TypeError (error )
484+ expr_list .append (inner_expr .expr )
477485 else :
478- raise NotImplementedError
486+ error = "Use col() or lit() to construct expressions"
487+ raise TypeError (error )
479488 if named_exprs :
480489 for alias , expr in named_exprs .items ():
490+ if not isinstance (expr , Expr ):
491+ error = "Use col() or lit() to construct expressions"
492+ raise TypeError (error )
481493 expr_list .append (expr .alias (alias ).expr )
482494 return expr_list
483495
@@ -503,37 +515,62 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503515 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
504516
505517 def aggregate (
506- self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
518+ self ,
519+ group_by : list [Expr | str ] | Expr | str ,
520+ aggs : list [Expr ] | Expr ,
507521 ) -> DataFrame :
508522 """Aggregates the rows of the current DataFrame.
509523
510524 Args:
511- group_by: List of expressions to group by.
525+ group_by: List of expressions or column names to group by.
512526 aggs: List of expressions to aggregate.
513527
514528 Returns:
515529 DataFrame after aggregation.
516530 """
517- group_by = group_by if isinstance (group_by , list ) else [group_by ]
518- aggs = aggs if isinstance (aggs , list ) else [aggs ]
519-
520- group_by = [e .expr for e in group_by ]
521- aggs = [e .expr for e in aggs ]
522- return DataFrame (self .df .aggregate (group_by , aggs ))
523-
524- def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525- """Sort the DataFrame by the specified sorting expressions.
531+ group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
532+ aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
533+
534+ group_by_exprs = []
535+ for e in group_by_list :
536+ if isinstance (e , str ):
537+ group_by_exprs .append (Expr .column (e ).expr )
538+ elif isinstance (e , Expr ):
539+ group_by_exprs .append (e .expr )
540+ else :
541+ error = (
542+ "Expected Expr or column name, found:"
543+ f" { type (e ).__name__ } . Use col() or lit() to construct expressions"
544+ )
545+ raise TypeError (error )
546+ aggs_exprs = []
547+ for agg in aggs_list :
548+ if not isinstance (agg , Expr ):
549+ error = "Use col() or lit() to construct expressions"
550+ raise TypeError (error )
551+ aggs_exprs .append (agg .expr )
552+ return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
553+
554+ def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
555+ """Sort the DataFrame by the specified sorting expressions or column names.
526556
527557 Note that any expression can be turned into a sort expression by
528- calling its` ``sort`` method.
558+ calling its ``sort`` method.
529559
530560 Args:
531- exprs: Sort expressions, applied in order.
561+ exprs: Sort expressions or column names , applied in order.
532562
533563 Returns:
534564 DataFrame after sorting.
535565 """
536- exprs_raw = [sort_or_default (expr ) for expr in exprs ]
566+ expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
567+ raw_exprs_iter = iter (expr_list_to_raw_expr_list (expr_seq ))
568+ exprs_raw = []
569+ for e in exprs :
570+ if isinstance (e , SortExpr ):
571+ exprs_raw .append (sort_or_default (e ))
572+ else :
573+ exprs_raw .append (sort_or_default (Expr (next (raw_exprs_iter ))))
537574 return DataFrame (self .df .sort (* exprs_raw ))
538575
539576 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -757,7 +794,12 @@ def join_on(
757794 Returns:
758795 DataFrame after join.
759796 """
760- exprs = [expr .expr for expr in on_exprs ]
797+ exprs = []
798+ for expr in on_exprs :
799+ if not isinstance (expr , Expr ):
800+ error = "Use col() or lit() to construct expressions"
801+ raise TypeError (error )
802+ exprs .append (expr .expr )
761803 return DataFrame (self .df .join_on (right .df , exprs , how ))
762804
763805 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments