Skip to content

Commit b1bdf74

Browse files
committed
UNPICK
1 parent 0b6303e commit b1bdf74

File tree

4 files changed

+20
-174
lines changed

4 files changed

+20
-174
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -126,49 +126,6 @@ DataFusion's DataFrame API offers a wide range of operations:
126126
# Drop columns
127127
df = df.drop("temporary_column")
128128
129-
String Columns and Expressions
130-
------------------------------
131-
132-
Some ``DataFrame`` methods accept plain strings when an argument refers to an
133-
existing column. These include:
134-
135-
* :py:meth:`~datafusion.DataFrame.select`
136-
* :py:meth:`~datafusion.DataFrame.sort`
137-
* :py:meth:`~datafusion.DataFrame.drop`
138-
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139-
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140-
141-
For such methods, you can pass column names directly:
142-
143-
.. code-block:: python
144-
145-
from datafusion import col, functions as f
146-
147-
df.sort('id')
148-
df.aggregate('id', [f.count(col('value'))])
149-
150-
The same operation can also be written with an explicit column expression:
151-
152-
.. code-block:: python
153-
154-
from datafusion import col, functions as f
155-
156-
df.sort(col('id'))
157-
df.aggregate(col('id'), [f.count(col('value'))])
158-
159-
Whenever an argument represents an expression—such as in
160-
:py:meth:`~datafusion.DataFrame.filter` or
161-
:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns
162-
and wrap constant values with ``lit()`` (also available as ``literal()``):
163-
164-
.. code-block:: python
165-
166-
from datafusion import col, lit
167-
df.filter(col('age') > lit(21))
168-
169-
Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a
170-
constant value.
171-
172129
Terminal Operations
173130
-------------------
174131

python/datafusion/dataframe.py

Lines changed: 19 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, _to_expr_list, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, sort_or_default
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
4646

@@ -394,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396
"""
397-
exprs_internal = _to_expr_list(exprs)
397+
exprs_internal = [
398+
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399+
]
398400
return DataFrame(self.df.select(*exprs_internal))
399401

400402
def drop(self, *columns: str) -> DataFrame:
@@ -424,12 +426,6 @@ def filter(self, *predicates: Expr) -> DataFrame:
424426
"""
425427
df = self.df
426428
for p in predicates:
427-
if not isinstance(p, Expr):
428-
msg = (
429-
f"Expected Expr, got {type(p).__name__}. "
430-
"Use col() or lit() to construct expressions."
431-
)
432-
raise TypeError(msg)
433429
df = df.filter(p.expr)
434430
return DataFrame(df)
435431

@@ -443,12 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
443439
Returns:
444440
DataFrame with the new column.
445441
"""
446-
if not isinstance(expr, Expr):
447-
msg = (
448-
f"Expected Expr, got {type(expr).__name__}. "
449-
"Use col() or lit() to construct expressions."
450-
)
451-
raise TypeError(msg)
452442
return DataFrame(self.df.with_column(name, expr.expr))
453443

454444
def with_columns(
@@ -483,28 +473,11 @@ def _simplify_expression(
483473
if isinstance(expr, Expr):
484474
expr_list.append(expr.expr)
485475
elif isinstance(expr, Iterable):
486-
for inner_expr in expr:
487-
if not isinstance(inner_expr, Expr):
488-
msg = (
489-
f"Expected Expr, got {type(inner_expr).__name__}. "
490-
"Use col() or lit() to construct expressions."
491-
)
492-
raise TypeError(msg)
493-
expr_list.append(inner_expr.expr)
476+
expr_list.extend(inner_expr.expr for inner_expr in expr)
494477
else:
495-
msg = (
496-
f"Expected Expr, got {type(expr).__name__}. "
497-
"Use col() or lit() to construct expressions."
498-
)
499-
raise TypeError(msg)
478+
raise NotImplementedError
500479
if named_exprs:
501480
for alias, expr in named_exprs.items():
502-
if not isinstance(expr, Expr):
503-
msg = (
504-
f"Expected Expr, got {type(expr).__name__}. "
505-
"Use col() or lit() to construct expressions."
506-
)
507-
raise TypeError(msg)
508481
expr_list.append(expr.alias(alias).expr)
509482
return expr_list
510483

@@ -530,56 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
530503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
531504

532505
def aggregate(
533-
self,
534-
group_by: list[Expr | str] | Expr | str,
535-
aggs: list[Expr] | Expr,
506+
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
536507
) -> DataFrame:
537508
"""Aggregates the rows of the current DataFrame.
538509
539510
Args:
540-
group_by: List of expressions or column names to group by.
511+
group_by: List of expressions to group by.
541512
aggs: List of expressions to aggregate.
542513
543514
Returns:
544515
DataFrame after aggregation.
545516
"""
546-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
547-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
517+
group_by = group_by if isinstance(group_by, list) else [group_by]
518+
aggs = aggs if isinstance(aggs, list) else [aggs]
548519

549-
group_by_exprs = [
550-
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
551-
]
552-
aggs_exprs = []
553-
for agg in aggs_list:
554-
if not isinstance(agg, Expr):
555-
msg = (
556-
f"Expected Expr, got {type(agg).__name__}. "
557-
"Use col() or lit() to construct expressions."
558-
)
559-
raise TypeError(msg)
560-
aggs_exprs.append(agg.expr)
561-
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
520+
group_by = [e.expr for e in group_by]
521+
aggs = [e.expr for e in aggs]
522+
return DataFrame(self.df.aggregate(group_by, aggs))
562523

563-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
564-
"""Sort the DataFrame by the specified sorting expressions or column names.
524+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
525+
"""Sort the DataFrame by the specified sorting expressions.
565526
566527
Note that any expression can be turned into a sort expression by
567-
calling its ``sort`` method.
528+
calling its ``sort`` method.
568529
569530
Args:
570-
exprs: Sort expressions or column names, applied in order.
531+
exprs: Sort expressions, applied in order.
571532
572533
Returns:
573534
DataFrame after sorting.
574535
"""
575-
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
576-
raw_exprs_iter = iter(_to_expr_list(expr_seq))
577-
exprs_raw = []
578-
for e in exprs:
579-
if isinstance(e, SortExpr):
580-
exprs_raw.append(sort_or_default(e))
581-
else:
582-
exprs_raw.append(sort_or_default(Expr(next(raw_exprs_iter))))
536+
exprs_raw = [sort_or_default(expr) for expr in exprs]
583537
return DataFrame(self.df.sort(*exprs_raw))
584538

585539
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -803,15 +757,7 @@ def join_on(
803757
Returns:
804758
DataFrame after join.
805759
"""
806-
exprs = []
807-
for expr in on_exprs:
808-
if not isinstance(expr, Expr):
809-
msg = (
810-
f"Expected Expr, got {type(expr).__name__}. "
811-
"Use col() or lit() to construct expressions."
812-
)
813-
raise TypeError(msg)
814-
exprs.append(expr.expr)
760+
exprs = [expr.expr for expr in on_exprs]
815761
return DataFrame(self.df.join_on(right.df, exprs, how))
816762

817763
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional
2626

2727
import pyarrow as pa
2828

@@ -215,11 +215,6 @@
215215
]
216216

217217

218-
def _to_expr_list(exprs: Sequence[Expr | str]) -> list[expr_internal.Expr]:
219-
"""Convert a sequence of expressions or column names to raw expressions."""
220-
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in exprs]
221-
222-
223218
def expr_list_to_raw_expr_list(
224219
expr_list: Optional[list[Expr] | Expr],
225220
) -> Optional[list[expr_internal.Expr]]:

python/tests/test_dataframe.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -268,27 +268,6 @@ def test_sort(df):
268268
assert table.to_pydict() == expected
269269

270270

271-
def test_sort_string_and_expression_equivalent(df):
272-
from datafusion import col
273-
274-
result_str = df.sort("a").to_pydict()
275-
result_expr = df.sort(col("a")).to_pydict()
276-
assert result_str == result_expr
277-
278-
279-
def test_aggregate_string_and_expression_equivalent(df):
280-
from datafusion import col
281-
282-
result_str = df.aggregate("a", [f.count()]).to_pydict()
283-
result_expr = df.aggregate(col("a"), [f.count()]).to_pydict()
284-
assert result_str == result_expr
285-
286-
287-
def test_filter_string_unsupported(df):
288-
with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"):
289-
df.filter("a > 1")
290-
291-
292271
def test_drop(df):
293272
df = df.drop("c")
294273

@@ -358,11 +337,6 @@ def test_with_column(df):
358337
assert result.column(2) == pa.array([5, 7, 9])
359338

360339

361-
def test_with_column_invalid_expr(df):
362-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
363-
df.with_column("c", "a")
364-
365-
366340
def test_with_columns(df):
367341
df = df.with_columns(
368342
(column("a") + column("b")).alias("c"),
@@ -394,13 +368,6 @@ def test_with_columns(df):
394368
assert result.column(6) == pa.array([5, 7, 9])
395369

396370

397-
def test_with_columns_invalid_expr(df):
398-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
399-
df.with_columns("a")
400-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
401-
df.with_columns(c="a")
402-
403-
404371
def test_cast(df):
405372
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
406373
expected = pa.schema(
@@ -559,25 +526,6 @@ def test_join_on():
559526
assert table.to_pydict() == expected
560527

561528

562-
def test_join_on_invalid_expr():
563-
ctx = SessionContext()
564-
565-
batch = pa.RecordBatch.from_arrays(
566-
[pa.array([1, 2]), pa.array([4, 5])],
567-
names=["a", "b"],
568-
)
569-
df = ctx.create_dataframe([[batch]], "l")
570-
df1 = ctx.create_dataframe([[batch]], "r")
571-
572-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
573-
df.join_on(df1, "a")
574-
575-
576-
def test_aggregate_invalid_aggs(df):
577-
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
578-
df.aggregate([], "a")
579-
580-
581529
def test_distinct():
582530
ctx = SessionContext()
583531

0 commit comments

Comments
 (0)