Skip to content

Commit f4cd590

Browse files
committed
Revert "UNPICK"
This reverts commit 35900a3.
1 parent 35900a3 commit f4cd590

File tree

4 files changed

+193
-28
lines changed

4 files changed

+193
-28
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,49 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129+
String Columns and Expressions
130+
------------------------------
131+
132+
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133+
existing column. These include:
134+
135+
* :py:meth:`~datafusion.DataFrame.select`
136+
* :py:meth:`~datafusion.DataFrame.sort`
137+
* :py:meth:`~datafusion.DataFrame.drop`
138+
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139+
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140+
141+
For such methods, you can pass column names directly:
142+
143+
.. code-block:: python
144+
145+
from datafusion import col, functions as f
146+
147+
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
149+
150+
The same operation can also be written with an explicit column expression:
151+
152+
.. code-block:: python
153+
154+
from datafusion import col, functions as f
155+
156+
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
158+
159+
Whenever an argument represents an expression—such as in
160+
:py:meth:`~datafusion.DataFrame.filter` or
161+
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162+
and wrap constant values with ``lit()`` (also available as ``literal()``):
163+
164+
.. code-block:: python
165+
166+
from datafusion import col, lit
167+
df.filter(col('age') > lit(21))
168+
169+
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170+
constant value.
171+
129172
Terminal Operations
130173
-------------------
131174

python/datafusion/dataframe.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, expr_list_to_raw_expr_list, sort_or_default
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
4646

@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
397+
exprs_internal = expr_list_to_raw_expr_list(exprs)
400398
return DataFrame(self.df.select(*exprs_internal))
401399

402400
def drop(self, *columns: str) -> DataFrame:
@@ -426,7 +424,10 @@ def filter(self, *predicates: Expr) -> DataFrame:
426424
"""
427425
df = self.df
428426
for p in predicates:
429-
df = df.filter(p.expr)
427+
if isinstance(p, str) or not isinstance(p, Expr):
428+
error = "Use col() or lit() to construct expressions"
429+
raise TypeError(error)
430+
df = df.filter(expr_list_to_raw_expr_list(p)[0])
430431
return DataFrame(df)
431432

432433
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -439,6 +440,9 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
439440
Returns:
440441
DataFrame with the new column.
441442
"""
443+
if not isinstance(expr, Expr):
444+
error = "Use col() or lit() to construct expressions"
445+
raise TypeError(error)
442446
return DataFrame(self.df.with_column(name, expr.expr))
443447

444448
def with_columns(
@@ -472,12 +476,20 @@ def _simplify_expression(
472476
for expr in exprs:
473477
if isinstance(expr, Expr):
474478
expr_list.append(expr.expr)
475-
elif isinstance(expr, Iterable):
476-
expr_list.extend(inner_expr.expr for inner_expr in expr)
479+
elif isinstance(expr, Iterable) and not isinstance(expr, (str, Expr)):
480+
for inner_expr in expr:
481+
if not isinstance(inner_expr, Expr):
482+
error = "Use col() or lit() to construct expressions"
483+
raise TypeError(error)
484+
expr_list.append(inner_expr.expr)
477485
else:
478-
raise NotImplementedError
486+
error = "Use col() or lit() to construct expressions"
487+
raise TypeError(error)
479488
if named_exprs:
480489
for alias, expr in named_exprs.items():
490+
if not isinstance(expr, Expr):
491+
error = "Use col() or lit() to construct expressions"
492+
raise TypeError(error)
481493
expr_list.append(expr.alias(alias).expr)
482494
return expr_list
483495

@@ -503,37 +515,62 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503515
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504516

505517
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
518+
self,
519+
group_by: list[Expr | str] | Expr | str,
520+
aggs: list[Expr] | Expr,
507521
) -> DataFrame:
508522
"""Aggregates the rows of the current DataFrame.
509523
510524
Args:
511-
group_by: List of expressions to group by.
525+
group_by: List of expressions or column names to group by.
512526
aggs: List of expressions to aggregate.
513527
514528
Returns:
515529
DataFrame after aggregation.
516530
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
519-
520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
523-
524-
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525-
"""Sort the DataFrame by the specified sorting expressions.
531+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
532+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
533+
534+
group_by_exprs = []
535+
for e in group_by_list:
536+
if isinstance(e, str):
537+
group_by_exprs.append(Expr.column(e).expr)
538+
elif isinstance(e, Expr):
539+
group_by_exprs.append(e.expr)
540+
else:
541+
error = (
542+
"Expected Expr or column name, found:"
543+
f" {type(e).__name__}. Use col() or lit() to construct expressions"
544+
)
545+
raise TypeError(error)
546+
aggs_exprs = []
547+
for agg in aggs_list:
548+
if not isinstance(agg, Expr):
549+
error = "Use col() or lit() to construct expressions"
550+
raise TypeError(error)
551+
aggs_exprs.append(agg.expr)
552+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
553+
554+
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
555+
"""Sort the DataFrame by the specified sorting expressions or column names.
526556
527557
Note that any expression can be turned into a sort expression by
528-
calling its` ``sort`` method.
558+
calling its ``sort`` method.
529559
530560
Args:
531-
exprs: Sort expressions, applied in order.
561+
exprs: Sort expressions or column names, applied in order.
532562
533563
Returns:
534564
DataFrame after sorting.
535565
"""
536-
exprs_raw = [sort_or_default(expr) for expr in exprs]
566+
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
567+
raw_exprs_iter = iter(expr_list_to_raw_expr_list(expr_seq))
568+
exprs_raw = []
569+
for e in exprs:
570+
if isinstance(e, SortExpr):
571+
exprs_raw.append(sort_or_default(e))
572+
else:
573+
exprs_raw.append(sort_or_default(Expr(next(raw_exprs_iter))))
537574
return DataFrame(self.df.sort(*exprs_raw))
538575

539576
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -757,7 +794,12 @@ def join_on(
757794
Returns:
758795
DataFrame after join.
759796
"""
760-
exprs = [expr.expr for expr in on_exprs]
797+
exprs = []
798+
for expr in on_exprs:
799+
if not isinstance(expr, Expr):
800+
error = "Use col() or lit() to construct expressions"
801+
raise TypeError(error)
802+
exprs.append(expr.expr)
761803
return DataFrame(self.df.join_on(right.df, exprs, how))
762804

763805
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -216,12 +216,26 @@
216216

217217

218218
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[list[Expr] | Expr],
219+
expr_list: Optional[Sequence[Expr | str] | Expr | str],
220220
) -> Optional[list[expr_internal.Expr]]:
221-
"""Helper function to convert an optional list to raw expressions."""
222-
if isinstance(expr_list, Expr):
221+
"""Convert a sequence of expressions or column names to raw expressions."""
222+
if isinstance(expr_list, (Expr, str)):
223223
expr_list = [expr_list]
224-
return [e.expr for e in expr_list] if expr_list is not None else None
224+
if expr_list is None:
225+
return None
226+
raw_exprs: list[expr_internal.Expr] = []
227+
for e in expr_list:
228+
if isinstance(e, str):
229+
raw_exprs.append(Expr.column(e).expr)
230+
elif isinstance(e, Expr):
231+
raw_exprs.append(e.expr)
232+
else:
233+
error = (
234+
"Expected Expr or column name, found:"
235+
f" {type(e).__name__}. Use col() or lit() to construct expressions."
236+
)
237+
raise TypeError(error)
238+
return raw_exprs
225239

226240

227241
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

python/tests/test_dataframe.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ def test_select_mixed_expr_string(df):
227227
assert result.column(1) == pa.array([1, 2, 3])
228228

229229

230+
def test_select_unsupported(df):
231+
with pytest.raises(
232+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
233+
):
234+
df.select(1)
235+
236+
230237
def test_filter(df):
231238
df1 = df.filter(column("a") > literal(2)).select(
232239
column("a") + column("b"),
@@ -268,6 +275,34 @@ def test_sort(df):
268275
assert table.to_pydict() == expected
269276

270277

278+
def test_sort_string_and_expression_equivalent(df):
279+
from datafusion import col
280+
281+
result_str = df.sort("a").to_pydict()
282+
result_expr = df.sort(col("a")).to_pydict()
283+
assert result_str == result_expr
284+
285+
286+
def test_sort_unsupported(df):
287+
with pytest.raises(
288+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
289+
):
290+
df.sort(1)
291+
292+
293+
def test_aggregate_string_and_expression_equivalent(df):
294+
from datafusion import col
295+
296+
result_str = df.aggregate("a", [f.count()]).to_pydict()
297+
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
298+
assert result_str == result_expr
299+
300+
301+
def test_filter_string_unsupported(df):
302+
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
303+
df.filter("a > 1")
304+
305+
271306
def test_drop(df):
272307
df = df.drop("c")
273308

@@ -337,6 +372,11 @@ def test_with_column(df):
337372
assert result.column(2) == pa.array([5, 7, 9])
338373

339374

375+
def test_with_column_invalid_expr(df):
376+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
377+
df.with_column("c", "a")
378+
379+
340380
def test_with_columns(df):
341381
df = df.with_columns(
342382
(column("a") + column("b")).alias("c"),
@@ -368,6 +408,13 @@ def test_with_columns(df):
368408
assert result.column(6) == pa.array([5, 7, 9])
369409

370410

411+
def test_with_columns_invalid_expr(df):
412+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
413+
df.with_columns("a")
414+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
415+
df.with_columns(c="a")
416+
417+
371418
def test_cast(df):
372419
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
373420
expected = pa.schema(
@@ -526,6 +573,25 @@ def test_join_on():
526573
assert table.to_pydict() == expected
527574

528575

576+
def test_join_on_invalid_expr():
577+
ctx = SessionContext()
578+
579+
batch = pa.RecordBatch.from_arrays(
580+
[pa.array([1, 2]), pa.array([4, 5])],
581+
names=["a", "b"],
582+
)
583+
df = ctx.create_dataframe([[batch]], "l")
584+
df1 = ctx.create_dataframe([[batch]], "r")
585+
586+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
587+
df.join_on(df1, "a")
588+
589+
590+
def test_aggregate_invalid_aggs(df):
591+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
592+
df.aggregate([], "a")
593+
594+
529595
def test_distinct():
530596
ctx = SessionContext()
531597

0 commit comments

Comments
 (0)