
Commit 77da2d2

changes to review

1 parent 0d1e41e commit 77da2d2

File tree

5 files changed: +324 -35 lines changed

python/datafusion/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -44,7 +44,13 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions
+from .dataframe import (
+    DataFrame,
+    DataFrameWriteOptions,
+    InsertOp,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .dataframe_formatter import configure_formatter
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
@@ -71,9 +77,11 @@
     "Config",
     "DFSchema",
     "DataFrame",
+    "DataFrameWriteOptions",
     "Database",
     "ExecutionPlan",
     "Expr",
+    "InsertOp",
     "LogicalPlan",
     "ParquetColumnOptions",
     "ParquetWriterOptions",

python/datafusion/dataframe.py

Lines changed: 119 additions & 15 deletions
@@ -39,10 +39,13 @@
 from typing_extensions import deprecated  # Python 3.12

 from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion._internal import DataFrameWriteOptions as DataFrameWriteOptionsInternal
+from datafusion._internal import InsertOp as InsertOpInternal
 from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
 from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
 from datafusion.expr import (
     Expr,
+    SortExpr,
     SortKey,
     ensure_expr,
     ensure_expr_list,
@@ -939,21 +942,31 @@ def except_all(self, other: DataFrame) -> DataFrame:
         """
         return DataFrame(self.df.except_all(other.df))

-    def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None:
+    def write_csv(
+        self,
+        path: str | pathlib.Path,
+        with_header: bool = False,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a CSV file.

         Args:
             path: Path of the CSV file to write.
             with_header: If true, output the CSV header row.
+            write_options: Options that impact how the DataFrame is written.
         """
-        self.df.write_csv(str(path), with_header)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_csv(str(path), with_header, raw_write_options)

     @overload
     def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: str,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     @overload
@@ -962,6 +975,7 @@ def write_parquet(
         path: str | pathlib.Path,
         compression: Compression = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     @overload
@@ -970,31 +984,38 @@ def write_parquet(
         path: str | pathlib.Path,
         compression: ParquetWriterOptions,
         compression_level: None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...

     def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.

+        Available compression types are:
+
+        - "uncompressed": No compression.
+        - "snappy": Snappy compression.
+        - "gzip": Gzip compression.
+        - "brotli": Brotli compression.
+        - "lz4": LZ4 compression.
+        - "lz4_raw": LZ4_RAW compression.
+        - "zstd": Zstandard compression.
+
+        LZO compression is not yet implemented in arrow-rs and is therefore
+        excluded.
+
         Args:
             path: Path of the Parquet file to write.
             compression: Compression type to use. Default is "ZSTD".
-                Available compression types are:
-                - "uncompressed": No compression.
-                - "snappy": Snappy compression.
-                - "gzip": Gzip compression.
-                - "brotli": Brotli compression.
-                - "lz4": LZ4 compression.
-                - "lz4_raw": LZ4_RAW compression.
-                - "zstd": Zstandard compression.
-                Note: LZO is not yet implemented in arrow-rs and is therefore excluded.
             compression_level: Compression level to use. For ZSTD, the
                 recommended range is 1 to 22, with the default being 4. Higher levels
                 provide better compression but slower speed.
+            write_options: Options that impact how the DataFrame is written.
         """
         if isinstance(compression, ParquetWriterOptions):
             if compression_level is not None:
@@ -1012,10 +1033,21 @@ def write_parquet(
         ):
             compression_level = compression.get_default_level()

-        self.df.write_parquet(str(path), compression.value, compression_level)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_parquet(
+            str(path),
+            compression.value,
+            compression_level,
+            raw_write_options,
+        )

     def write_parquet_with_options(
-        self, path: str | pathlib.Path, options: ParquetWriterOptions
+        self,
+        path: str | pathlib.Path,
+        options: ParquetWriterOptions,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.

@@ -1024,6 +1056,7 @@ def write_parquet_with_options(
         Args:
             path: Path of the Parquet file to write.
             options: Sets the writer parquet options (see `ParquetWriterOptions`).
+            write_options: Options that impact how the DataFrame is written.
         """
         options_internal = ParquetWriterOptionsInternal(
             options.data_pagesize_limit,
@@ -1060,19 +1093,45 @@ def write_parquet_with_options(
                 bloom_filter_ndv=opts.bloom_filter_ndv,
             )

+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
         self.df.write_parquet_with_options(
             str(path),
             options_internal,
             column_specific_options_internal,
+            raw_write_options,
         )

-    def write_json(self, path: str | pathlib.Path) -> None:
+    def write_json(
+        self,
+        path: str | pathlib.Path,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.

         Args:
             path: Path of the JSON file to write.
+            write_options: Options that impact how the DataFrame is written.
+        """
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_json(str(path), write_options=raw_write_options)
+
+    def write_table(
+        self, table_name: str, write_options: DataFrameWriteOptions | None = None
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to a table.
+
+        The table must be registered with the session to perform this operation.
+        Not all table providers support writing operations. See the individual
+        implementations for details.
         """
-        self.df.write_json(str(path))
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_table(table_name, raw_write_options)

     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1220,3 +1279,48 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame:
             - For columns not in subset, the original column is kept unchanged
         """
         return DataFrame(self.df.fill_null(value, subset))
+
+
+class InsertOp(Enum):
+    """Insert operation mode.
+
+    These modes are used by the table writing feature to define how record
+    batches should be written to a table.
+    """
+
+    APPEND = InsertOpInternal.APPEND
+    """Appends new rows to the existing table without modifying any existing rows."""
+
+    REPLACE = InsertOpInternal.REPLACE
+    """Replace existing rows that collide with the inserted rows.
+
+    Replacement is typically based on a unique key or primary key.
+    """
+
+    OVERWRITE = InsertOpInternal.OVERWRITE
+    """Overwrites all existing rows in the table with the new rows."""
+
+
+class DataFrameWriteOptions:
+    """Writer options for DataFrame.
+
+    There is no guarantee the table provider supports all writer options.
+    See the individual implementation and documentation for details.
+    """
+
+    def __init__(
+        self,
+        insert_operation: InsertOp | None = None,
+        single_file_output: bool = False,
+        partition_by: str | Sequence[str] | None = None,
+        sort_by: Expr | SortExpr | Sequence[Expr] | Sequence[SortExpr] | None = None,
+    ) -> None:
+        """Instantiate writer options for DataFrame."""
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
+
+        sort_by_raw = sort_list_to_raw_sort_list(sort_by)
+
+        self._raw_write_options = DataFrameWriteOptionsInternal(
+            insert_operation, single_file_output, partition_by, sort_by_raw
+        )
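
Putting the pieces together: a minimal end-to-end sketch of the new write path, assuming an in-memory session. The data, column names, and output directory below are illustrative only, and (as the DataFrameWriteOptions docstring notes) not every writer or table provider supports every option:

    import pyarrow as pa
    from datafusion import SessionContext, column
    from datafusion.dataframe import DataFrameWriteOptions

    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [pa.array([3, 1, 2]), pa.array(["x", "y", "x"])], names=["a", "part"]
    )
    df = ctx.create_dataframe([[batch]])

    # Sort rows on write and partition the output by the "part" column
    # (both names here are hypothetical, chosen for the sketch).
    opts = DataFrameWriteOptions(
        partition_by="part",
        sort_by=column("a").sort(ascending=True),
    )
    df.write_parquet("out_dir", write_options=opts)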

python/tests/test_dataframe.py

Lines changed: 72 additions & 3 deletions
@@ -16,6 +16,7 @@
 # under the License.
 import ctypes
 import datetime
+import itertools
 import os
 import re
 import threading
@@ -40,6 +41,7 @@
 from datafusion import (
     functions as f,
 )
+from datafusion.dataframe import DataFrameWriteOptions
 from datafusion.dataframe_formatter import (
     DataFrameHtmlFormatter,
     configure_formatter,
@@ -58,9 +60,7 @@ def ctx():


 @pytest.fixture
-def df():
-    ctx = SessionContext()
-
+def df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
@@ -1830,6 +1830,56 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected


+sort_by_cases = [
+    (None, [1, 2, 3], "unsorted"),
+    (column("c"), [2, 1, 3], "single_column_expr"),
+    (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"),
+    ([column("c"), column("b")], [2, 1, 3], "list_col_expr"),
+    (
+        [column("c").sort(ascending=False), column("b").sort(ascending=False)],
+        [3, 1, 2],
+        "list_sort_expr",
+    ),
+]
+
+formats = ["csv", "json", "parquet", "table"]
+
+
+@pytest.mark.parametrize(
+    ("format", "sort_by", "expected_a"),
+    [
+        pytest.param(format, sort_by, expected_a, id=f"{format}_{test_id}")
+        for format, (sort_by, expected_a, test_id) in itertools.product(
+            formats, sort_by_cases
+        )
+    ],
+)
+def test_write_files_with_options(
+    ctx, df, tmp_path, format, sort_by, expected_a
+) -> None:
+    write_options = DataFrameWriteOptions(sort_by=sort_by)
+
+    if format == "csv":
+        df.write_csv(tmp_path, with_header=True, write_options=write_options)
+        ctx.register_csv("test_table", tmp_path)
+    elif format == "json":
+        df.write_json(tmp_path, write_options=write_options)
+        ctx.register_json("test_table", tmp_path)
+    elif format == "parquet":
+        df.write_parquet(tmp_path, write_options=write_options)
+        ctx.register_parquet("test_table", tmp_path)
+    elif format == "table":
+        batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema())
+        ctx.register_record_batches("test_table", [[batch]])
+        ctx.table("test_table").show()
+        df.write_table("test_table", write_options=write_options)
+
+    result = ctx.table("test_table").to_pydict()["a"]
+    ctx.table("test_table").show()
+
+    assert result == expected_a
+
+
 @pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_json(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
@@ -2322,6 +2372,25 @@ def test_write_parquet_options_error(df, tmp_path):
         df.write_parquet(str(tmp_path), options, compression_level=1)


+def test_write_table(ctx, df):
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3])],
+        names=["a"],
+    )
+
+    ctx.register_record_batches("t", [[batch]])
+
+    df = ctx.table("t").with_column("a", column("a") * literal(-1))
+
+    ctx.table("t").show()
+
+    df.write_table("t")
+    result = ctx.table("t").sort(column("a")).collect()[0][0].to_pylist()
+    expected = [-3, -2, -1, 1, 2, 3]
+
+    assert result == expected
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
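
These tests exercise sort_by across all four writers, and test_write_table shows that the default write_table behavior appends to the existing table. The other InsertOp modes are not covered here; a hedged sketch of what an overwrite could look like, reusing the table "t" registered in test_write_table above (whether a given table provider accepts REPLACE or OVERWRITE is provider-dependent):

    from datafusion.dataframe import DataFrameWriteOptions, InsertOp

    # Overwrite all existing rows instead of appending (provider permitting).
    overwrite = DataFrameWriteOptions(insert_operation=InsertOp.OVERWRITE)
    df.write_table("t", write_options=overwrite)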
