4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , expr_list_to_raw_expr_list , sort_or_default
43+ from datafusion .expr import Expr , SortExpr , sort_or_default
4444from datafusion .plan import ExecutionPlan , LogicalPlan
4545from datafusion .record_batch import RecordBatchStream
4646
@@ -394,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396 """
397- exprs_internal = expr_list_to_raw_expr_list (exprs )
397+ exprs_internal = [
398+ Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399+ ]
398400 return DataFrame (self .df .select (* exprs_internal ))
399401
400402 def drop (self , * columns : str ) -> DataFrame :
@@ -424,10 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
424426 """
425427 df = self .df
426428 for p in predicates :
427- if isinstance (p , str ) or not isinstance (p , Expr ):
428- error = "Use col() or lit() to construct expressions"
429- raise TypeError (error )
430- df = df .filter (expr_list_to_raw_expr_list (p )[0 ])
429+ df = df .filter (p .expr )
431430 return DataFrame (df )
432431
433432 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -440,9 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
440439 Returns:
441440 DataFrame with the new column.
442441 """
443- if not isinstance (expr , Expr ):
444- error = "Use col() or lit() to construct expressions"
445- raise TypeError (error )
446442 return DataFrame (self .df .with_column (name , expr .expr ))
447443
448444 def with_columns (
@@ -476,20 +472,12 @@ def _simplify_expression(
476472 for expr in exprs :
477473 if isinstance (expr , Expr ):
478474 expr_list .append (expr .expr )
479- elif isinstance (expr , Iterable ) and not isinstance (expr , (str , Expr )):
480- for inner_expr in expr :
481- if not isinstance (inner_expr , Expr ):
482- error = "Use col() or lit() to construct expressions"
483- raise TypeError (error )
484- expr_list .append (inner_expr .expr )
475+ elif isinstance (expr , Iterable ):
476+ expr_list .extend (inner_expr .expr for inner_expr in expr )
485477 else :
486- error = "Use col() or lit() to construct expressions"
487- raise TypeError (error )
478+ raise NotImplementedError
488479 if named_exprs :
489480 for alias , expr in named_exprs .items ():
490- if not isinstance (expr , Expr ):
491- error = "Use col() or lit() to construct expressions"
492- raise TypeError (error )
493481 expr_list .append (expr .alias (alias ).expr )
494482 return expr_list
495483
@@ -515,62 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
515503 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
516504
517505 def aggregate (
518- self ,
519- group_by : list [Expr | str ] | Expr | str ,
520- aggs : list [Expr ] | Expr ,
506+ self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
521507 ) -> DataFrame :
522508 """Aggregates the rows of the current DataFrame.
523509
524510 Args:
525- group_by: List of expressions or column names to group by.
511+ group_by: List of expressions to group by.
526512 aggs: List of expressions to aggregate.
527513
528514 Returns:
529515 DataFrame after aggregation.
530516 """
531- group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
532- aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
533-
534- group_by_exprs = []
535- for e in group_by_list :
536- if isinstance (e , str ):
537- group_by_exprs .append (Expr .column (e ).expr )
538- elif isinstance (e , Expr ):
539- group_by_exprs .append (e .expr )
540- else :
541- error = (
542- "Expected Expr or column name, found:"
543- f" { type (e ).__name__ } . Use col() or lit() to construct expressions"
544- )
545- raise TypeError (error )
546- aggs_exprs = []
547- for agg in aggs_list :
548- if not isinstance (agg , Expr ):
549- error = "Use col() or lit() to construct expressions"
550- raise TypeError (error )
551- aggs_exprs .append (agg .expr )
552- return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
553-
554- def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
555- """Sort the DataFrame by the specified sorting expressions or column names.
517+ group_by = group_by if isinstance (group_by , list ) else [group_by ]
518+ aggs = aggs if isinstance (aggs , list ) else [aggs ]
519+
520+ group_by = [e .expr for e in group_by ]
521+ aggs = [e .expr for e in aggs ]
522+ return DataFrame (self .df .aggregate (group_by , aggs ))
523+
524+ def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525+ """Sort the DataFrame by the specified sorting expressions.
556526
557527 Note that any expression can be turned into a sort expression by
558- calling its ``sort`` method.
528+ calling its` ``sort`` method.
559529
560530 Args:
561- exprs: Sort expressions or column names , applied in order.
531+ exprs: Sort expressions, applied in order.
562532
563533 Returns:
564534 DataFrame after sorting.
565535 """
566- expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
567- raw_exprs_iter = iter (expr_list_to_raw_expr_list (expr_seq ))
568- exprs_raw = []
569- for e in exprs :
570- if isinstance (e , SortExpr ):
571- exprs_raw .append (sort_or_default (e ))
572- else :
573- exprs_raw .append (sort_or_default (Expr (next (raw_exprs_iter ))))
536+ exprs_raw = [sort_or_default (expr ) for expr in exprs ]
574537 return DataFrame (self .df .sort (* exprs_raw ))
575538
576539 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -794,12 +757,7 @@ def join_on(
794757 Returns:
795758 DataFrame after join.
796759 """
797- exprs = []
798- for expr in on_exprs :
799- if not isinstance (expr , Expr ):
800- error = "Use col() or lit() to construct expressions"
801- raise TypeError (error )
802- exprs .append (expr .expr )
760+ exprs = [expr .expr for expr in on_exprs ]
803761 return DataFrame (self .df .join_on (right .df , exprs , how ))
804762
805763 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments