Skip to content

Commit 35900a3

Browse files
committed
UNPICK
1 parent 268f2fa commit 35900a3

File tree

4 files changed

+28
-193
lines changed

4 files changed

+28
-193
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -126,49 +126,6 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129-
String Columns and Expressions
130-
------------------------------
131-
132-
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133-
existing column. These include:
134-
135-
* :py:meth:`~datafusion.DataFrame.select`
136-
* :py:meth:`~datafusion.DataFrame.sort`
137-
* :py:meth:`~datafusion.DataFrame.drop`
138-
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139-
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140-
141-
For such methods, you can pass column names directly:
142-
143-
.. code-block:: python
144-
145-
from datafusion import col, functions as f
146-
147-
df.sort('id')
148-
df.aggregate('id', [f.count(col('value'))])
149-
150-
The same operation can also be written with an explicit column expression:
151-
152-
.. code-block:: python
153-
154-
from datafusion import col, functions as f
155-
156-
df.sort(col('id'))
157-
df.aggregate(col('id'), [f.count(col('value'))])
158-
159-
Whenever an argument represents an expression—such as in
160-
:py:meth:`~datafusion.DataFrame.filter` or
161-
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162-
and wrap constant values with ``lit()`` (also available as ``literal()``):
163-
164-
.. code-block:: python
165-
166-
from datafusion import col, lit
167-
df.filter(col('age') > lit(21))
168-
169-
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170-
constant value.
171-
172129
Terminal Operations
173130
-------------------
174131

python/datafusion/dataframe.py

Lines changed: 23 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, expr_list_to_raw_expr_list, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, sort_or_default
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
4646

@@ -394,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396
"""
397-
exprs_internal = expr_list_to_raw_expr_list(exprs)
397+
exprs_internal = [
398+
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399+
]
398400
return DataFrame(self.df.select(*exprs_internal))
399401

400402
def drop(self, *columns: str) -> DataFrame:
@@ -424,10 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
424426
"""
425427
df = self.df
426428
for p in predicates:
427-
if isinstance(p, str) or not isinstance(p, Expr):
428-
error = "Use col() or lit() to construct expressions"
429-
raise TypeError(error)
430-
df = df.filter(expr_list_to_raw_expr_list(p)[0])
429+
df = df.filter(p.expr)
431430
return DataFrame(df)
432431

433432
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -440,9 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
440439
Returns:
441440
DataFrame with the new column.
442441
"""
443-
if not isinstance(expr, Expr):
444-
error = "Use col() or lit() to construct expressions"
445-
raise TypeError(error)
446442
return DataFrame(self.df.with_column(name, expr.expr))
447443

448444
def with_columns(
@@ -476,20 +472,12 @@ def _simplify_expression(
476472
for expr in exprs:
477473
if isinstance(expr, Expr):
478474
expr_list.append(expr.expr)
479-
elif isinstance(expr, Iterable) and not isinstance(expr, (str, Expr)):
480-
for inner_expr in expr:
481-
if not isinstance(inner_expr, Expr):
482-
error = "Use col() or lit() to construct expressions"
483-
raise TypeError(error)
484-
expr_list.append(inner_expr.expr)
475+
elif isinstance(expr, Iterable):
476+
expr_list.extend(inner_expr.expr for inner_expr in expr)
485477
else:
486-
error = "Use col() or lit() to construct expressions"
487-
raise TypeError(error)
478+
raise NotImplementedError
488479
if named_exprs:
489480
for alias, expr in named_exprs.items():
490-
if not isinstance(expr, Expr):
491-
error = "Use col() or lit() to construct expressions"
492-
raise TypeError(error)
493481
expr_list.append(expr.alias(alias).expr)
494482
return expr_list
495483

@@ -515,62 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
515503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
516504

517505
def aggregate(
518-
self,
519-
group_by: list[Expr | str] | Expr | str,
520-
aggs: list[Expr] | Expr,
506+
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
521507
) -> DataFrame:
522508
"""Aggregates the rows of the current DataFrame.
523509
524510
Args:
525-
group_by: List of expressions or column names to group by.
511+
group_by: List of expressions to group by.
526512
aggs: List of expressions to aggregate.
527513
528514
Returns:
529515
DataFrame after aggregation.
530516
"""
531-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
532-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
533-
534-
group_by_exprs = []
535-
for e in group_by_list:
536-
if isinstance(e, str):
537-
group_by_exprs.append(Expr.column(e).expr)
538-
elif isinstance(e, Expr):
539-
group_by_exprs.append(e.expr)
540-
else:
541-
error = (
542-
"Expected Expr or column name, found:"
543-
f" {type(e).__name__}. Use col() or lit() to construct expressions"
544-
)
545-
raise TypeError(error)
546-
aggs_exprs = []
547-
for agg in aggs_list:
548-
if not isinstance(agg, Expr):
549-
error = "Use col() or lit() to construct expressions"
550-
raise TypeError(error)
551-
aggs_exprs.append(agg.expr)
552-
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
553-
554-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
555-
"""Sort the DataFrame by the specified sorting expressions or column names.
517+
group_by = group_by if isinstance(group_by, list) else [group_by]
518+
aggs = aggs if isinstance(aggs, list) else [aggs]
519+
520+
group_by = [e.expr for e in group_by]
521+
aggs = [e.expr for e in aggs]
522+
return DataFrame(self.df.aggregate(group_by, aggs))
523+
524+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525+
"""Sort the DataFrame by the specified sorting expressions.
556526
557527
Note that any expression can be turned into a sort expression by
558-
calling its ``sort`` method.
528+
calling its` ``sort`` method.
559529
560530
Args:
561-
exprs: Sort expressions or column names, applied in order.
531+
exprs: Sort expressions, applied in order.
562532
563533
Returns:
564534
DataFrame after sorting.
565535
"""
566-
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
567-
raw_exprs_iter = iter(expr_list_to_raw_expr_list(expr_seq))
568-
exprs_raw = []
569-
for e in exprs:
570-
if isinstance(e, SortExpr):
571-
exprs_raw.append(sort_or_default(e))
572-
else:
573-
exprs_raw.append(sort_or_default(Expr(next(raw_exprs_iter))))
536+
exprs_raw = [sort_or_default(expr) for expr in exprs]
574537
return DataFrame(self.df.sort(*exprs_raw))
575538

576539
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -794,12 +757,7 @@ def join_on(
794757
Returns:
795758
DataFrame after join.
796759
"""
797-
exprs = []
798-
for expr in on_exprs:
799-
if not isinstance(expr, Expr):
800-
error = "Use col() or lit() to construct expressions"
801-
raise TypeError(error)
802-
exprs.append(expr.expr)
760+
exprs = [expr.expr for expr in on_exprs]
803761
return DataFrame(self.df.join_on(right.df, exprs, how))
804762

805763
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional
2626

2727
import pyarrow as pa
2828

@@ -216,26 +216,12 @@
216216

217217

218218
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[Sequence[Expr | str] | Expr | str],
219+
expr_list: Optional[list[Expr] | Expr],
220220
) -> Optional[list[expr_internal.Expr]]:
221-
"""Convert a sequence of expressions or column names to raw expressions."""
222-
if isinstance(expr_list, (Expr, str)):
221+
"""Helper function to convert an optional list to raw expressions."""
222+
if isinstance(expr_list, Expr):
223223
expr_list = [expr_list]
224-
if expr_list is None:
225-
return None
226-
raw_exprs: list[expr_internal.Expr] = []
227-
for e in expr_list:
228-
if isinstance(e, str):
229-
raw_exprs.append(Expr.column(e).expr)
230-
elif isinstance(e, Expr):
231-
raw_exprs.append(e.expr)
232-
else:
233-
error = (
234-
"Expected Expr or column name, found:"
235-
f" {type(e).__name__}. Use col() or lit() to construct expressions."
236-
)
237-
raise TypeError(error)
238-
return raw_exprs
224+
return [e.expr for e in expr_list] if expr_list is not None else None
239225

240226

241227
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:

python/tests/test_dataframe.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,6 @@ def test_select_mixed_expr_string(df):
227227
assert result.column(1) == pa.array([1, 2, 3])
228228

229229

230-
def test_select_unsupported(df):
231-
with pytest.raises(
232-
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
233-
):
234-
df.select(1)
235-
236-
237230
def test_filter(df):
238231
df1 = df.filter(column("a") > literal(2)).select(
239232
column("a") + column("b"),
@@ -275,34 +268,6 @@ def test_sort(df):
275268
assert table.to_pydict() == expected
276269

277270

278-
def test_sort_string_and_expression_equivalent(df):
279-
from datafusion import col
280-
281-
result_str = df.sort("a").to_pydict()
282-
result_expr = df.sort(col("a")).to_pydict()
283-
assert result_str == result_expr
284-
285-
286-
def test_sort_unsupported(df):
287-
with pytest.raises(
288-
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
289-
):
290-
df.sort(1)
291-
292-
293-
def test_aggregate_string_and_expression_equivalent(df):
294-
from datafusion import col
295-
296-
result_str = df.aggregate("a", [f.count()]).to_pydict()
297-
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
298-
assert result_str == result_expr
299-
300-
301-
def test_filter_string_unsupported(df):
302-
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
303-
df.filter("a > 1")
304-
305-
306271
def test_drop(df):
307272
df = df.drop("c")
308273

@@ -372,11 +337,6 @@ def test_with_column(df):
372337
assert result.column(2) == pa.array([5, 7, 9])
373338

374339

375-
def test_with_column_invalid_expr(df):
376-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
377-
df.with_column("c", "a")
378-
379-
380340
def test_with_columns(df):
381341
df = df.with_columns(
382342
(column("a") + column("b")).alias("c"),
@@ -408,13 +368,6 @@ def test_with_columns(df):
408368
assert result.column(6) == pa.array([5, 7, 9])
409369

410370

411-
def test_with_columns_invalid_expr(df):
412-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
413-
df.with_columns("a")
414-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
415-
df.with_columns(c="a")
416-
417-
418371
def test_cast(df):
419372
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
420373
expected = pa.schema(
@@ -573,25 +526,6 @@ def test_join_on():
573526
assert table.to_pydict() == expected
574527

575528

576-
def test_join_on_invalid_expr():
577-
ctx = SessionContext()
578-
579-
batch = pa.RecordBatch.from_arrays(
580-
[pa.array([1, 2]), pa.array([4, 5])],
581-
names=["a", "b"],
582-
)
583-
df = ctx.create_dataframe([[batch]], "l")
584-
df1 = ctx.create_dataframe([[batch]], "r")
585-
586-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
587-
df.join_on(df1, "a")
588-
589-
590-
def test_aggregate_invalid_aggs(df):
591-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
592-
df.aggregate([], "a")
593-
594-
595529
def test_distinct():
596530
ctx = SessionContext()
597531

0 commit comments

Comments
 (0)