Skip to content

Commit 44b9d1a

Browse files
committed
UNPICK
1 parent 46dbc85 commit 44b9d1a

File tree

6 files changed

+87
-357
lines changed

6 files changed

+87
-357
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -126,51 +126,6 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129-
String Columns and Expressions
130-
------------------------------
131-
132-
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133-
existing column. These include:
134-
135-
* :py:meth:`~datafusion.DataFrame.select`
136-
* :py:meth:`~datafusion.DataFrame.sort`
137-
* :py:meth:`~datafusion.DataFrame.drop`
138-
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139-
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140-
141-
For such methods, you can pass column names directly:
142-
143-
.. code-block:: python
144-
145-
from datafusion import col, functions as f
146-
147-
df.sort('id')
148-
df.aggregate('id', [f.count(col('value'))])
149-
150-
The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
151-
152-
.. code-block:: python
153-
154-
from datafusion import col, column, functions as f
155-
156-
df.sort(col('id'))
157-
df.aggregate(column('id'), [f.count(col('value'))])
158-
159-
Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action.
160-
161-
Whenever an argument represents an expression—such as in
162-
:py:meth:`~datafusion.DataFrame.filter` or
163-
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
164-
and wrap constant values with ``lit()`` (also available as ``literal()``):
165-
166-
.. code-block:: python
167-
168-
from datafusion import col, lit
169-
df.filter(col('age') > lit(21))
170-
171-
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
172-
constant value.
173-
174129
Terminal Operations
175130
-------------------
176131

python/datafusion/context.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from datafusion.catalog import Catalog, CatalogProvider, Table
3333
from datafusion.dataframe import DataFrame
34-
from datafusion.expr import SortKey, sort_list_to_raw_sort_list
34+
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
3636
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
3737

@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[SortKey]] | None = None,
556+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -567,20 +567,23 @@ def register_listing_table(
567567
table_partition_cols: Partition columns.
568568
file_extension: File extension of the provided table.
569569
schema: The data source schema.
570-
file_sort_order: Sort order for the file. Each sort key can be
571-
specified as a column name (``str``), an expression
572-
(``Expr``), or a ``SortExpr``.
570+
file_sort_order: Sort order for the file.
573571
"""
574572
if table_partition_cols is None:
575573
table_partition_cols = []
576574
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
575+
file_sort_order_raw = (
576+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
577+
if file_sort_order is not None
578+
else None
579+
)
577580
self.ctx.register_listing_table(
578581
name,
579582
str(path),
580583
table_partition_cols,
581584
file_extension,
582585
schema,
583-
self._convert_file_sort_order(file_sort_order),
586+
file_sort_order_raw,
584587
)
585588

586589
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
@@ -805,7 +808,7 @@ def register_parquet(
805808
file_extension: str = ".parquet",
806809
skip_metadata: bool = True,
807810
schema: pa.Schema | None = None,
808-
file_sort_order: list[list[SortKey]] | None = None,
811+
file_sort_order: list[list[SortExpr]] | None = None,
809812
) -> None:
810813
"""Register a Parquet file as a table.
811814
@@ -824,9 +827,7 @@ def register_parquet(
824827
that may be in the file schema. This can help avoid schema
825828
conflicts due to metadata.
826829
schema: The data source schema.
827-
file_sort_order: Sort order for the file. Each sort key can be
828-
specified as a column name (``str``), an expression
829-
(``Expr``), or a ``SortExpr``.
830+
file_sort_order: Sort order for the file.
830831
"""
831832
if table_partition_cols is None:
832833
table_partition_cols = []
@@ -839,7 +840,9 @@ def register_parquet(
839840
file_extension,
840841
skip_metadata,
841842
schema,
842-
self._convert_file_sort_order(file_sort_order),
843+
[sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
844+
if file_sort_order is not None
845+
else None,
843846
)
844847

845848
def register_csv(
@@ -1096,7 +1099,7 @@ def read_parquet(
10961099
file_extension: str = ".parquet",
10971100
skip_metadata: bool = True,
10981101
schema: pa.Schema | None = None,
1099-
file_sort_order: list[list[SortKey]] | None = None,
1102+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
11001103
) -> DataFrame:
11011104
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11021105
@@ -1113,17 +1116,19 @@ def read_parquet(
11131116
schema: An optional schema representing the parquet files. If None,
11141117
the parquet reader will try to infer it based on data in the
11151118
file.
1116-
file_sort_order: Sort order for the file. Each sort key can be
1117-
specified as a column name (``str``), an expression
1118-
(``Expr``), or a ``SortExpr``.
1119+
file_sort_order: Sort order for the file.
11191120
11201121
Returns:
11211122
DataFrame representation of the read Parquet files
11221123
"""
11231124
if table_partition_cols is None:
11241125
table_partition_cols = []
11251126
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
1126-
file_sort_order = self._convert_file_sort_order(file_sort_order)
1127+
file_sort_order = (
1128+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1129+
if file_sort_order is not None
1130+
else None
1131+
)
11271132
return DataFrame(
11281133
self.ctx.read_parquet(
11291134
str(path),
@@ -1174,24 +1179,6 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11741179
"""Execute the ``plan`` and return the results."""
11751180
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
11761181

1177-
@staticmethod
1178-
def _convert_file_sort_order(
1179-
file_sort_order: list[list[SortKey]] | None,
1180-
) -> list[list[Any]] | None:
1181-
"""Convert nested ``SortKey`` lists into raw sort representations.
1182-
1183-
Each ``SortKey`` can be a column name string, an ``Expr``, or a
1184-
``SortExpr`` and will be converted using
1185-
:func:`datafusion.expr.sort_list_to_raw_sort_list`.
1186-
"""
1187-
# Convert each ``SortKey`` in the provided sort order to the low-level
1188-
# representation expected by the Rust bindings.
1189-
return (
1190-
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1191-
if file_sort_order is not None
1192-
else None
1193-
)
1194-
11951182
@staticmethod
11961183
def _convert_table_partition_cols(
11971184
table_partition_cols: list[tuple[str, str | pa.DataType]],

python/datafusion/dataframe.py

Lines changed: 28 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import (
44-
EXPR_TYPE_ERROR,
45-
Expr,
46-
SortKey,
47-
expr_list_to_raw_expr_list,
48-
sort_list_to_raw_sort_list,
49-
)
43+
from datafusion.expr import Expr, SortExpr, sort_or_default
5044
from datafusion.plan import ExecutionPlan, LogicalPlan
5145
from datafusion.record_batch import RecordBatchStream
5246

@@ -292,23 +286,6 @@ def __init__(
292286
self.bloom_filter_ndv = bloom_filter_ndv
293287

294288

295-
def _ensure_expr(value: Expr) -> expr_internal.Expr:
296-
"""Return the internal expression or raise ``TypeError`` if invalid.
297-
298-
Args:
299-
value: Candidate expression.
300-
301-
Returns:
302-
The internal expression representation.
303-
304-
Raises:
305-
TypeError: If ``value`` is not an instance of :class:`Expr`.
306-
"""
307-
if not isinstance(value, Expr):
308-
raise TypeError(EXPR_TYPE_ERROR)
309-
return value.expr
310-
311-
312289
class DataFrame:
313290
"""Two dimensional table representation of data.
314291
@@ -417,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
417394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
418395
419396
"""
420-
exprs_internal = expr_list_to_raw_expr_list(exprs)
397+
exprs_internal = [
398+
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399+
]
421400
return DataFrame(self.df.select(*exprs_internal))
422401

423402
def drop(self, *columns: str) -> DataFrame:
@@ -447,7 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
447426
"""
448427
df = self.df
449428
for p in predicates:
450-
df = df.filter(_ensure_expr(p))
429+
df = df.filter(p.expr)
451430
return DataFrame(df)
452431

453432
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -460,7 +439,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
460439
Returns:
461440
DataFrame with the new column.
462441
"""
463-
return DataFrame(self.df.with_column(name, _ensure_expr(expr)))
442+
return DataFrame(self.df.with_column(name, expr.expr))
464443

465444
def with_columns(
466445
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -489,24 +468,17 @@ def with_columns(
489468
def _simplify_expression(
490469
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
491470
) -> list[expr_internal.Expr]:
492-
expr_list: list[expr_internal.Expr] = []
471+
expr_list = []
493472
for expr in exprs:
494-
if isinstance(expr, str):
495-
raise TypeError(EXPR_TYPE_ERROR)
496-
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
497-
expr_value = list(expr)
498-
if any(isinstance(inner, str) for inner in expr_value):
499-
raise TypeError(EXPR_TYPE_ERROR)
473+
if isinstance(expr, Expr):
474+
expr_list.append(expr.expr)
475+
elif isinstance(expr, Iterable):
476+
expr_list.extend(inner_expr.expr for inner_expr in expr)
500477
else:
501-
expr_value = expr
502-
try:
503-
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
504-
except TypeError as err:
505-
raise TypeError(EXPR_TYPE_ERROR) from err
506-
for alias, expr in named_exprs.items():
507-
if not isinstance(expr, Expr):
508-
raise TypeError(EXPR_TYPE_ERROR)
509-
expr_list.append(expr.alias(alias).expr)
478+
raise NotImplementedError
479+
if named_exprs:
480+
for alias, expr in named_exprs.items():
481+
expr_list.append(expr.alias(alias).expr)
510482
return expr_list
511483

512484
expressions = _simplify_expression(*exprs, **named_exprs)
@@ -531,43 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
531503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
532504

533505
def aggregate(
534-
self,
535-
group_by: list[Expr | str] | Expr | str,
536-
aggs: list[Expr] | Expr,
506+
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
537507
) -> DataFrame:
538508
"""Aggregates the rows of the current DataFrame.
539509
540510
Args:
541-
group_by: List of expressions or column names to group by.
511+
group_by: List of expressions to group by.
542512
aggs: List of expressions to aggregate.
543513
544514
Returns:
545515
DataFrame after aggregation.
546516
"""
547-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
548-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
517+
group_by = group_by if isinstance(group_by, list) else [group_by]
518+
aggs = aggs if isinstance(aggs, list) else [aggs]
549519

550-
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
551-
aggs_exprs = []
552-
for agg in aggs_list:
553-
if not isinstance(agg, Expr):
554-
raise TypeError(EXPR_TYPE_ERROR)
555-
aggs_exprs.append(agg.expr)
556-
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
520+
group_by = [e.expr for e in group_by]
521+
aggs = [e.expr for e in aggs]
522+
return DataFrame(self.df.aggregate(group_by, aggs))
557523

558-
def sort(self, *exprs: SortKey) -> DataFrame:
559-
"""Sort the DataFrame by the specified sorting expressions or column names.
524+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525+
"""Sort the DataFrame by the specified sorting expressions.
560526
561527
Note that any expression can be turned into a sort expression by
562-
calling its ``sort`` method.
528+
calling its ``sort`` method.
563529
564530
Args:
565-
exprs: Sort expressions or column names, applied in order.
531+
exprs: Sort expressions, applied in order.
566532
567533
Returns:
568534
DataFrame after sorting.
569535
"""
570-
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
536+
exprs_raw = [sort_or_default(expr) for expr in exprs]
571537
return DataFrame(self.df.sort(*exprs_raw))
572538

573539
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -791,7 +757,7 @@ def join_on(
791757
Returns:
792758
DataFrame after join.
793759
"""
794-
exprs = [_ensure_expr(expr) for expr in on_exprs]
760+
exprs = [expr.expr for expr in on_exprs]
795761
return DataFrame(self.df.join_on(right.df, exprs, how))
796762

797763
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

0 commit comments

Comments
 (0)