Skip to content

Commit f9cafb8

Browse files
committed
refactor: improve DataFrame expression handling, type checking, and docs
- Refactor expression handling and `_simplify_expression` for stronger type checking and clearer error handling
- Improve type annotations for `file_sort_order` and `order_by` to support string inputs
- Refactor DataFrame `filter` method to better validate predicates
- Replace internal error message variable with public constant
- Clarify usage of `col()` and `column()` in DataFrame examples
1 parent 61f981b commit f9cafb8

File tree

6 files changed

+218
-52
lines changed

6 files changed

+218
-52
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,51 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, column, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operation can also be written with an explicit column expression:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, column, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Note that ``column()`` is an alias of ``col()``, so you can use either name.
160+
161+
Whenever an argument represents an expression—such as in
162+
:py:meth:`~datafusion.DataFrame.filter` or
163+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
164+
and wrap constant values with ``lit()`` (also available as ``literal()``):
165+
166+
.. code-block:: python
167+
168+
from datafusion import col, lit
169+
df.filter(col('age') > lit(21))
170+
171+
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
172+
constant value.
173+
129174
Terminal Operations
130175
-------------------
131176

python/datafusion/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
556+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -808,7 +808,7 @@ def register_parquet(
808808
file_extension: str = ".parquet",
809809
skip_metadata: bool = True,
810810
schema: pa.Schema | None = None,
811-
file_sort_order: list[list[SortExpr]] | None = None,
811+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
812812
) -> None:
813813
"""Register a Parquet file as a table.
814814
@@ -1099,7 +1099,7 @@ def read_parquet(
10991099
file_extension: str = ".parquet",
11001100
skip_metadata: bool = True,
11011101
schema: pa.Schema | None = None,
1102-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
1102+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
11031103
) -> DataFrame:
11041104
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.
11051105

python/datafusion/dataframe.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import (
44+
_EXPR_TYPE_ERROR,
45+
Expr,
46+
SortExpr,
47+
expr_list_to_raw_expr_list,
48+
sort_list_to_raw_sort_list,
49+
)
4450
from datafusion.plan import ExecutionPlan, LogicalPlan
4551
from datafusion.record_batch import RecordBatchStream
4652

@@ -394,9 +400,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394400
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395401
396402
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
403+
exprs_internal = expr_list_to_raw_expr_list(exprs)
400404
return DataFrame(self.df.select(*exprs_internal))
401405

402406
def drop(self, *columns: str) -> DataFrame:
@@ -426,6 +430,8 @@ def filter(self, *predicates: Expr) -> DataFrame:
426430
"""
427431
df = self.df
428432
for p in predicates:
433+
if not isinstance(p, Expr):
434+
raise TypeError(_EXPR_TYPE_ERROR)
429435
df = df.filter(p.expr)
430436
return DataFrame(df)
431437

@@ -439,6 +445,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439445
Returns:
440446
DataFrame with the new column.
441447
"""
448+
if not isinstance(expr, Expr):
449+
raise TypeError(_EXPR_TYPE_ERROR)
442450
return DataFrame(self.df.with_column(name, expr.expr))
443451

444452
def with_columns(
@@ -468,17 +476,22 @@ def with_columns(
468476
def _simplify_expression(
469477
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
470478
) -> list[expr_internal.Expr]:
471-
expr_list = []
479+
expr_list: list[expr_internal.Expr] = []
472480
for expr in exprs:
473-
if isinstance(expr, Expr):
474-
expr_list.append(expr.expr)
475-
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
477-
else:
478-
raise NotImplementedError
479-
if named_exprs:
480-
for alias, expr in named_exprs.items():
481-
expr_list.append(expr.alias(alias).expr)
481+
if isinstance(expr, str) or (
482+
isinstance(expr, Iterable)
483+
and not isinstance(expr, Expr)
484+
and any(isinstance(inner, str) for inner in expr)
485+
):
486+
raise TypeError(_EXPR_TYPE_ERROR)
487+
try:
488+
expr_list.extend(expr_list_to_raw_expr_list(expr))
489+
except TypeError as err:
490+
raise TypeError(_EXPR_TYPE_ERROR) from err
491+
for alias, expr in named_exprs.items():
492+
if not isinstance(expr, Expr):
493+
raise TypeError(_EXPR_TYPE_ERROR)
494+
expr_list.append(expr.alias(alias).expr)
482495
return expr_list
483496

484497
expressions = _simplify_expression(*exprs, **named_exprs)
@@ -503,37 +516,43 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503516
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504517

505518
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
519+
self,
520+
group_by: list[Expr | str] | Expr | str,
521+
aggs: list[Expr] | Expr,
507522
) -> DataFrame:
508523
"""Aggregates the rows of the current DataFrame.
509524
510525
Args:
511-
group_by: List of expressions to group by.
526+
group_by: List of expressions or column names to group by.
512527
aggs: List of expressions to aggregate.
513528
514529
Returns:
515530
DataFrame after aggregation.
516531
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
532+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
533+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519534

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
535+
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
536+
aggs_exprs = []
537+
for agg in aggs_list:
538+
if not isinstance(agg, Expr):
539+
raise TypeError(_EXPR_TYPE_ERROR)
540+
aggs_exprs.append(agg.expr)
541+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523542

524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
543+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
544+
"""Sort the DataFrame by the specified sorting expressions or column names.
526545
527546
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
547+
calling its ``sort`` method.
529548
530549
Args:
531-
exprs: Sort expressions, applied in order.
550+
exprs: Sort expressions or column names, applied in order.
532551
533552
Returns:
534553
DataFrame after sorting.
535554
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
555+
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
537556
return DataFrame(self.df.sort(*exprs_raw))
538557

539558
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +776,11 @@ def join_on(
757776
Returns:
758777
DataFrame after join.
759778
"""
760-
exprs = [expr.expr for expr in on_exprs]
779+
exprs = []
780+
for expr in on_exprs:
781+
if not isinstance(expr, Expr):
782+
raise TypeError(_EXPR_TYPE_ERROR)
783+
exprs.append(expr.expr)
761784
return DataFrame(self.df.join_on(right.df, exprs, how))
762785

763786
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -39,6 +39,10 @@
3939
if TYPE_CHECKING:
4040
from datafusion.plan import LogicalPlan
4141

42+
43+
# Standard error message for invalid expression types
44+
_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45+
4246
# The following are imported from the internal representation. We may choose to
4347
# give these all proper wrappers, or to simply leave as is. These were added
4448
# in order to support passing the `test_imports` unit test.
@@ -216,12 +220,26 @@
216220

217221

218222
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[list[Expr] | Expr],
223+
expr_list: Optional[Sequence[Expr | str] | Expr | str],
220224
) -> Optional[list[expr_internal.Expr]]:
221-
"""Helper function to convert an optional list to raw expressions."""
222-
if isinstance(expr_list, Expr):
225+
"""Convert a sequence of expressions or column names to raw expressions."""
226+
if isinstance(expr_list, (Expr, str)):
223227
expr_list = [expr_list]
224-
return [e.expr for e in expr_list] if expr_list is not None else None
228+
if expr_list is None:
229+
return None
230+
raw_exprs: list[expr_internal.Expr] = []
231+
for e in expr_list:
232+
if isinstance(e, str):
233+
raw_exprs.append(Expr.column(e).expr)
234+
elif isinstance(e, Expr):
235+
raw_exprs.append(e.expr)
236+
else:
237+
error = (
238+
"Expected Expr or column name, found:"
239+
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
240+
)
241+
raise TypeError(error)
242+
return raw_exprs
225243

226244

227245
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
@@ -232,12 +250,27 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
232250

233251

234252
def sort_list_to_raw_sort_list(
235-
sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
253+
sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str],
236254
) -> Optional[list[expr_internal.SortExpr]]:
237255
"""Helper function to convert an optional sort list to its raw variant."""
238-
if isinstance(sort_list, (Expr, SortExpr)):
256+
if isinstance(sort_list, (Expr, SortExpr, str)):
239257
sort_list = [sort_list]
240-
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
258+
if sort_list is None:
259+
return None
260+
raw_sort_list = []
261+
for item in sort_list:
262+
if isinstance(item, str):
263+
expr_obj = Expr.column(item)
264+
elif isinstance(item, (Expr, SortExpr)):
265+
expr_obj = item
266+
else:
267+
error = (
268+
"Expected Expr or column name, found:"
269+
f" {type(item).__name__}. {_EXPR_TYPE_ERROR}."
270+
)
271+
raise TypeError(error)
272+
raw_sort_list.append(sort_or_default(expr_obj))
273+
return raw_sort_list
241274

242275

243276
class Expr:

0 commit comments

Comments
 (0)