4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , _to_expr_list , sort_or_default
43+ from datafusion .expr import Expr , SortExpr , sort_or_default
4444from datafusion .plan import ExecutionPlan , LogicalPlan
4545from datafusion .record_batch import RecordBatchStream
4646
@@ -394,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396 """
397- exprs_internal = _to_expr_list (exprs )
397+ exprs_internal = [
398+ Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399+ ]
398400 return DataFrame (self .df .select (* exprs_internal ))
399401
400402 def drop (self , * columns : str ) -> DataFrame :
@@ -424,12 +426,6 @@ def filter(self, *predicates: Expr) -> DataFrame:
424426 """
425427 df = self .df
426428 for p in predicates :
427- if not isinstance (p , Expr ):
428- msg = (
429- f"Expected Expr, got { type (p ).__name__ } . "
430- "Use col() or lit() to construct expressions."
431- )
432- raise TypeError (msg )
433429 df = df .filter (p .expr )
434430 return DataFrame (df )
435431
@@ -443,12 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
443439 Returns:
444440 DataFrame with the new column.
445441 """
446- if not isinstance (expr , Expr ):
447- msg = (
448- f"Expected Expr, got { type (expr ).__name__ } . "
449- "Use col() or lit() to construct expressions."
450- )
451- raise TypeError (msg )
452442 return DataFrame (self .df .with_column (name , expr .expr ))
453443
454444 def with_columns (
@@ -483,28 +473,11 @@ def _simplify_expression(
483473 if isinstance (expr , Expr ):
484474 expr_list .append (expr .expr )
485475 elif isinstance (expr , Iterable ):
486- for inner_expr in expr :
487- if not isinstance (inner_expr , Expr ):
488- msg = (
489- f"Expected Expr, got { type (inner_expr ).__name__ } . "
490- "Use col() or lit() to construct expressions."
491- )
492- raise TypeError (msg )
493- expr_list .append (inner_expr .expr )
476+ expr_list .extend (inner_expr .expr for inner_expr in expr )
494477 else :
495- msg = (
496- f"Expected Expr, got { type (expr ).__name__ } . "
497- "Use col() or lit() to construct expressions."
498- )
499- raise TypeError (msg )
478+ raise NotImplementedError
500479 if named_exprs :
501480 for alias , expr in named_exprs .items ():
502- if not isinstance (expr , Expr ):
503- msg = (
504- f"Expected Expr, got { type (expr ).__name__ } . "
505- "Use col() or lit() to construct expressions."
506- )
507- raise TypeError (msg )
508481 expr_list .append (expr .alias (alias ).expr )
509482 return expr_list
510483
@@ -530,56 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
530503 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
531504
532505 def aggregate (
533- self ,
534- group_by : list [Expr | str ] | Expr | str ,
535- aggs : list [Expr ] | Expr ,
506+ self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
536507 ) -> DataFrame :
537508 """Aggregates the rows of the current DataFrame.
538509
539510 Args:
540- group_by: List of expressions or column names to group by.
511+ group_by: List of expressions to group by.
541512 aggs: List of expressions to aggregate.
542513
543514 Returns:
544515 DataFrame after aggregation.
545516 """
546- group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
547- aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
517+ group_by = group_by if isinstance (group_by , list ) else [group_by ]
518+ aggs = aggs if isinstance (aggs , list ) else [aggs ]
548519
549- group_by_exprs = [
550- Expr .column (e ).expr if isinstance (e , str ) else e .expr for e in group_by_list
551- ]
552- aggs_exprs = []
553- for agg in aggs_list :
554- if not isinstance (agg , Expr ):
555- msg = (
556- f"Expected Expr, got { type (agg ).__name__ } . "
557- "Use col() or lit() to construct expressions."
558- )
559- raise TypeError (msg )
560- aggs_exprs .append (agg .expr )
561- return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
520+ group_by = [e .expr for e in group_by ]
521+ aggs = [e .expr for e in aggs ]
522+ return DataFrame (self .df .aggregate (group_by , aggs ))
562523
563- def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
564- """Sort the DataFrame by the specified sorting expressions or column names .
524+ def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525+ """Sort the DataFrame by the specified sorting expressions.
565526
566527 Note that any expression can be turned into a sort expression by
567- calling its ``sort`` method.
528+ calling its ``sort`` method.
568529
569530 Args:
570- exprs: Sort expressions or column names , applied in order.
531+ exprs: Sort expressions, applied in order.
571532
572533 Returns:
573534 DataFrame after sorting.
574535 """
575- expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
576- raw_exprs_iter = iter (_to_expr_list (expr_seq ))
577- exprs_raw = []
578- for e in exprs :
579- if isinstance (e , SortExpr ):
580- exprs_raw .append (sort_or_default (e ))
581- else :
582- exprs_raw .append (sort_or_default (Expr (next (raw_exprs_iter ))))
536+ exprs_raw = [sort_or_default (expr ) for expr in exprs ]
583537 return DataFrame (self .df .sort (* exprs_raw ))
584538
585539 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -803,15 +757,7 @@ def join_on(
803757 Returns:
804758 DataFrame after join.
805759 """
806- exprs = []
807- for expr in on_exprs :
808- if not isinstance (expr , Expr ):
809- msg = (
810- f"Expected Expr, got { type (expr ).__name__ } . "
811- "Use col() or lit() to construct expressions."
812- )
813- raise TypeError (msg )
814- exprs .append (expr .expr )
760+ exprs = [expr .expr for expr in on_exprs ]
815761 return DataFrame (self .df .join_on (right .df , exprs , how ))
816762
817763 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments