4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import (
44- EXPR_TYPE_ERROR ,
45- Expr ,
46- SortKey ,
47- expr_list_to_raw_expr_list ,
48- sort_list_to_raw_sort_list ,
49- )
43+ from datafusion .expr import Expr , SortExpr , sort_or_default
5044from datafusion .plan import ExecutionPlan , LogicalPlan
5145from datafusion .record_batch import RecordBatchStream
5246
@@ -292,23 +286,6 @@ def __init__(
292286 self .bloom_filter_ndv = bloom_filter_ndv
293287
294288
295- def _ensure_expr (value : Expr ) -> expr_internal .Expr :
296- """Return the internal expression or raise ``TypeError`` if invalid.
297-
298- Args:
299- value: Candidate expression.
300-
301- Returns:
302- The internal expression representation.
303-
304- Raises:
305- TypeError: If ``value`` is not an instance of :class:`Expr`.
306- """
307- if not isinstance (value , Expr ):
308- raise TypeError (EXPR_TYPE_ERROR )
309- return value .expr
310-
311-
312289class DataFrame :
313290 """Two dimensional table representation of data.
314291
@@ -417,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
417394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
418395
419396 """
420- exprs_internal = expr_list_to_raw_expr_list (exprs )
397+ exprs_internal = [
398+ Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399+ ]
421400 return DataFrame (self .df .select (* exprs_internal ))
422401
423402 def drop (self , * columns : str ) -> DataFrame :
@@ -447,7 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
447426 """
448427 df = self .df
449428 for p in predicates :
450- df = df .filter (_ensure_expr ( p ) )
429+ df = df .filter (p . expr )
451430 return DataFrame (df )
452431
453432 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -460,7 +439,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
460439 Returns:
461440 DataFrame with the new column.
462441 """
463- return DataFrame (self .df .with_column (name , _ensure_expr ( expr ) ))
442+ return DataFrame (self .df .with_column (name , expr . expr ))
464443
465444 def with_columns (
466445 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
@@ -489,24 +468,17 @@ def with_columns(
489468 def _simplify_expression (
490469 * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
491470 ) -> list [expr_internal .Expr ]:
492- expr_list : list [ expr_internal . Expr ] = []
471+ expr_list = []
493472 for expr in exprs :
494- if isinstance (expr , str ):
495- raise TypeError (EXPR_TYPE_ERROR )
496- if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
497- expr_value = list (expr )
498- if any (isinstance (inner , str ) for inner in expr_value ):
499- raise TypeError (EXPR_TYPE_ERROR )
473+ if isinstance (expr , Expr ):
474+ expr_list .append (expr .expr )
475+ elif isinstance (expr , Iterable ):
476+ expr_list .extend (inner_expr .expr for inner_expr in expr )
500477 else :
501- expr_value = expr
502- try :
503- expr_list .extend (expr_list_to_raw_expr_list (expr_value ))
504- except TypeError as err :
505- raise TypeError (EXPR_TYPE_ERROR ) from err
506- for alias , expr in named_exprs .items ():
507- if not isinstance (expr , Expr ):
508- raise TypeError (EXPR_TYPE_ERROR )
509- expr_list .append (expr .alias (alias ).expr )
478+ raise NotImplementedError
479+ if named_exprs :
480+ for alias , expr in named_exprs .items ():
481+ expr_list .append (expr .alias (alias ).expr )
510482 return expr_list
511483
512484 expressions = _simplify_expression (* exprs , ** named_exprs )
@@ -531,43 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
531503 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
532504
533505 def aggregate (
534- self ,
535- group_by : list [Expr | str ] | Expr | str ,
536- aggs : list [Expr ] | Expr ,
506+ self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
537507 ) -> DataFrame :
538508 """Aggregates the rows of the current DataFrame.
539509
540510 Args:
541- group_by: List of expressions or column names to group by.
511+ group_by: List of expressions to group by.
542512 aggs: List of expressions to aggregate.
543513
544514 Returns:
545515 DataFrame after aggregation.
546516 """
547- group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
548- aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
517+ group_by = group_by if isinstance (group_by , list ) else [group_by ]
518+ aggs = aggs if isinstance (aggs , list ) else [aggs ]
549519
550- group_by_exprs = expr_list_to_raw_expr_list (group_by_list )
551- aggs_exprs = []
552- for agg in aggs_list :
553- if not isinstance (agg , Expr ):
554- raise TypeError (EXPR_TYPE_ERROR )
555- aggs_exprs .append (agg .expr )
556- return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
520+ group_by = [e .expr for e in group_by ]
521+ aggs = [e .expr for e in aggs ]
522+ return DataFrame (self .df .aggregate (group_by , aggs ))
557523
558- def sort (self , * exprs : SortKey ) -> DataFrame :
559- """Sort the DataFrame by the specified sorting expressions or column names .
524+ def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525+ """Sort the DataFrame by the specified sorting expressions.
560526
561527 Note that any expression can be turned into a sort expression by
562- calling its ``sort`` method.
528+ calling its` ``sort`` method.
563529
564530 Args:
565- exprs: Sort expressions or column names , applied in order.
531+ exprs: Sort expressions, applied in order.
566532
567533 Returns:
568534 DataFrame after sorting.
569535 """
570- exprs_raw = sort_list_to_raw_sort_list ( list ( exprs ))
536+ exprs_raw = [ sort_or_default ( expr ) for expr in exprs ]
571537 return DataFrame (self .df .sort (* exprs_raw ))
572538
573539 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -791,7 +757,7 @@ def join_on(
791757 Returns:
792758 DataFrame after join.
793759 """
794- exprs = [_ensure_expr ( expr ) for expr in on_exprs ]
760+ exprs = [expr . expr for expr in on_exprs ]
795761 return DataFrame (self .df .join_on (right .df , exprs , how ))
796762
797763 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments