Merged
50 changes: 41 additions & 9 deletions pyiceberg/io/pyarrow.py
@@ -201,6 +201,8 @@
ICEBERG_SCHEMA = b"iceberg.schema"
# The PARQUET: in front means that it is Parquet specific, in this case the field_id
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
# ORC field ID key for Iceberg field IDs in ORC metadata
ORC_FIELD_ID_KEY = b"iceberg.id"
Contributor Author: @Fokko I don't have a ton of context on this. Do you think this is required for this PR? Could it be a separate PR?

PYARROW_FIELD_DOC_KEY = b"doc"
LIST_ELEMENT_NAME = "element"
MAP_KEY_NAME = "key"
@@ -690,16 +692,20 @@ def schema_to_pyarrow(
schema: Union[Schema, IcebergType],
metadata: Dict[bytes, bytes] = EMPTY_DICT,
include_field_ids: bool = True,
file_format: FileFormat = FileFormat.PARQUET,
) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids, file_format))


class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
_metadata: Dict[bytes, bytes]

def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
def __init__(
self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True, file_format: Optional[FileFormat] = None
) -> None:
self._metadata = metadata
self._include_field_ids = include_field_ids
self._file_format = file_format

def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
return pa.schema(list(struct_result), metadata=self._metadata)
@@ -712,7 +718,12 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
# Add field ID based on file format
if self._file_format == FileFormat.ORC:
metadata[ORC_FIELD_ID_KEY] = str(field.field_id)
else:
# Default to Parquet for backward compatibility
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)

return pa.field(
name=field.name,
@@ -1011,6 +1022,10 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi
def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
if file_format == FileFormat.PARQUET:
return ds.ParquetFileFormat(**kwargs)
elif file_format == FileFormat.ORC:
# ORC doesn't support pre_buffer and buffer_size parameters
orc_kwargs = {k: v for k, v in kwargs.items() if k not in ["pre_buffer", "buffer_size"]}
return ds.OrcFileFormat(**orc_kwargs)
else:
raise ValueError(f"Unsupported file format: {file_format}")

@@ -1027,6 +1042,15 @@ def _read_deletes(io: FileIO, data_file: DataFile) -> Dict[str, pa.ChunkedArray]
file.as_py(): table.filter(pc.field("file_path") == file).column("pos")
for file in table.column("file_path").chunks[0].dictionary
}
elif data_file.file_format == FileFormat.ORC:
with io.new_input(data_file.file_path).open() as fi:
delete_fragment = _get_file_format(data_file.file_format).make_fragment(fi)
table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table()
# For ORC, the file_path column is not dictionary-encoded, so we use unique() directly
return {
path.as_py(): table.filter(pc.field("file_path") == path).column("pos")
for path in table.column("file_path").unique()
}
elif data_file.file_format == FileFormat.PUFFIN:
with io.new_input(data_file.file_path).open() as fi:
payload = fi.read()
@@ -1228,11 +1252,17 @@ def primitive(self, primitive: pa.DataType) -> T:


def _get_field_id(field: pa.Field) -> Optional[int]:
return (
int(field_id_str.decode())
if (field.metadata and (field_id_str := field.metadata.get(PYARROW_PARQUET_FIELD_ID_KEY)))
else None
)
"""Return the Iceberg field ID from Parquet or ORC metadata if available."""
if field.metadata:
# Try Parquet field ID first
if field_id_bytes := field.metadata.get(PYARROW_PARQUET_FIELD_ID_KEY):
return int(field_id_bytes.decode())

# Fallback: try ORC field ID
if field_id_bytes := field.metadata.get(ORC_FIELD_ID_KEY):
return int(field_id_bytes.decode())

return None


class _HasIds(PyArrowSchemaVisitor[bool]):
@@ -1495,7 +1525,7 @@ def _task_to_record_batches(
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
downcast_ns_timestamp_to_us: Optional[bool] = None,
) -> Iterator[pa.RecordBatch]:
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with io.new_input(task.file.file_path).open() as fin:
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema
@@ -1845,6 +1875,8 @@ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Fi
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
# For projection visitor, we don't know the file format, so default to Parquet
# This is used for schema conversion during reads, not writes
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)

return pa.field(
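To illustrate the field-ID handling introduced above (this is not part of the diff): a minimal pyarrow-only sketch of how a field tagged with the new ORC key round-trips through the same lookup order as _get_field_id. The helper below is a standalone re-implementation for illustration only, not the function from pyiceberg.io.pyarrow.

from typing import Optional

import pyarrow as pa

# Metadata keys taken from the diff above.
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
ORC_FIELD_ID_KEY = b"iceberg.id"


def field_id_from_metadata(field: pa.Field) -> Optional[int]:
    # Mirrors _get_field_id: try the Parquet key first, then fall back to the ORC key.
    if field.metadata:
        if field_id_bytes := field.metadata.get(PYARROW_PARQUET_FIELD_ID_KEY):
            return int(field_id_bytes.decode())
        if field_id_bytes := field.metadata.get(ORC_FIELD_ID_KEY):
            return int(field_id_bytes.decode())
    return None


# A field tagged the way the ORC branch of _ConvertToArrowSchema.field tags it ...
orc_field = pa.field("foo", pa.string(), metadata={ORC_FIELD_ID_KEY: b"1"})
# ... and one tagged the way the existing Parquet branch tags it.
parquet_field = pa.field("foo", pa.string(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: b"1"})

assert field_id_from_metadata(orc_field) == 1
assert field_id_from_metadata(parquet_field) == 1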
3 changes: 3 additions & 0 deletions pyiceberg/table/__init__.py
@@ -211,6 +211,9 @@ class TableProperties:
WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT = True

WRITE_DATA_PATH = "write.data.path"

WRITE_FILE_FORMAT = "write.format.default"
WRITE_FILE_FORMAT_DEFAULT = "parquet"
WRITE_METADATA_PATH = "write.metadata.path"

DELETE_MODE = "write.delete.mode"
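For context on the new property (not part of the diff): a minimal sketch of how write.format.default could be resolved to a FileFormat, assuming the WRITE_FILE_FORMAT constants added above. Whether the write path consults the property exactly this way is up to the implementation in this PR.

from typing import Dict

from pyiceberg.manifest import FileFormat
from pyiceberg.table import TableProperties


def resolve_write_format(properties: Dict[str, str]) -> FileFormat:
    # "parquet" / "orc" (case-insensitive) map to FileFormat.PARQUET / FileFormat.ORC.
    name = properties.get(TableProperties.WRITE_FILE_FORMAT, TableProperties.WRITE_FILE_FORMAT_DEFAULT)
    return FileFormat(name.upper())


assert resolve_write_format({}) == FileFormat.PARQUET
assert resolve_write_format({"write.format.default": "orc"}) == FileFormat.ORC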
44 changes: 44 additions & 0 deletions tests/conftest.py
@@ -2413,6 +2413,32 @@ def example_task(data_file: str) -> FileScanTask:
)


@pytest.fixture
def data_file_orc(table_schema_simple: Schema, tmp_path: str) -> str:
import pyarrow as pa
import pyarrow.orc as orc

from pyiceberg.io.pyarrow import schema_to_pyarrow

table = pa.table(
{"foo": ["a", "b", "c"], "bar": [1, 2, 3], "baz": [True, False, None]},
schema=schema_to_pyarrow(table_schema_simple),
)

file_path = f"{tmp_path}/0000-data.orc"
orc.write_table(table=table, where=file_path)
return file_path


@pytest.fixture
def example_task_orc(data_file_orc: str) -> FileScanTask:
datafile = DataFile.from_args(file_path=data_file_orc, file_format=FileFormat.ORC, file_size_in_bytes=1925)
datafile.spec_id = 0
return FileScanTask(
data_file=datafile,
)


@pytest.fixture(scope="session")
def warehouse(tmp_path_factory: pytest.TempPathFactory) -> Path:
return tmp_path_factory.mktemp("test_sql")
@@ -2442,6 +2468,24 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
)


@pytest.fixture
def table_v2_orc(example_table_metadata_v2: Dict[str, Any]) -> Table:
import copy

metadata_dict = copy.deepcopy(example_table_metadata_v2)
if not metadata_dict["properties"]:
metadata_dict["properties"] = {}
metadata_dict["properties"]["write.format.default"] = "ORC"
table_metadata = TableMetadataV2(**metadata_dict)
return Table(
identifier=("database", "table_orc"),
metadata=table_metadata,
metadata_location=f"{table_metadata.location}/uuid.metadata.json",
io=load_file_io(),
catalog=NoopCatalog("NoopCatalog"),
)


@pytest.fixture
def table_v2_with_fixed_and_decimal_types(
table_metadata_v2_with_fixed_and_decimal_types: Dict[str, Any],
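As a usage note (not part of the diff): the data_file_orc fixture relies on pyarrow.orc for writing. The sketch below mirrors that round trip without the Iceberg schema conversion, using a hypothetical /tmp path in place of pytest's tmp_path.

import pyarrow as pa
import pyarrow.orc as orc

# Hypothetical path; the fixture above writes under pytest's tmp_path instead.
file_path = "/tmp/0000-data.orc"

table = pa.table({"foo": ["a", "b", "c"], "bar": [1, 2, 3], "baz": [True, False, None]})
orc.write_table(table, file_path)

# Read it back and confirm the schema and row count round-trip.
read_back = orc.read_table(file_path)
assert read_back.num_rows == 3
assert read_back.schema.names == ["foo", "bar", "baz"]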
73 changes: 73 additions & 0 deletions tests/integration/test_writes/test_writes.py
@@ -46,6 +46,7 @@
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _dataframe_to_data_files
from pyiceberg.manifest import FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties
@@ -709,6 +710,78 @@ def test_write_parquet_unsupported_properties(
tbl.append(arrow_table_with_null)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_spark_writes_orc_pyiceberg_reads(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
"""Test that ORC files written by Spark can be read by PyIceberg."""
identifier = f"default.spark_writes_orc_pyiceberg_reads_v{format_version}"

# Create test data
test_data = [
(1, "Alice", 25, True),
(2, "Bob", 30, False),
(3, "Charlie", 35, True),
(4, "David", 28, True),
(5, "Eve", 32, False),
]

# Create Spark DataFrame
spark_df = spark.createDataFrame(test_data, ["id", "name", "age", "is_active"])

# Ensure a clean slate to avoid replacing a v2 table with v1
spark.sql(f"DROP TABLE IF EXISTS {identifier}")

# Create table with Spark using ORC format and desired format-version
spark_df.writeTo(identifier).using("iceberg").tableProperty("write.format.default", "orc").tableProperty(
"format-version", str(format_version)
).createOrReplace()

# Write data with ORC format using Spark
spark_df.writeTo(identifier).using("iceberg").append()

# Read with PyIceberg - this is the main focus of our validation
tbl = session_catalog.load_table(identifier)
pyiceberg_df = tbl.scan().to_pandas()

# Verify PyIceberg results have the expected number of rows
assert len(pyiceberg_df) == 10 # 5 rows from create + 5 rows from append

# Verify PyIceberg column names
assert list(pyiceberg_df.columns) == ["id", "name", "age", "is_active"]

# Verify PyIceberg data integrity - check the actual data values
expected_data = [
(1, "Alice", 25, True),
(2, "Bob", 30, False),
(3, "Charlie", 35, True),
(4, "David", 28, True),
(5, "Eve", 32, False),
]

# Verify PyIceberg results contain the expected data (appears twice due to create + append)
pyiceberg_data = list(zip(pyiceberg_df["id"], pyiceberg_df["name"], pyiceberg_df["age"], pyiceberg_df["is_active"]))
assert pyiceberg_data == expected_data + expected_data # Data should appear twice

# Verify PyIceberg data types are correct
assert pyiceberg_df["id"].dtype == "int64"
assert pyiceberg_df["name"].dtype == "object" # string
assert pyiceberg_df["age"].dtype == "int64"
assert pyiceberg_df["is_active"].dtype == "bool"

# Cross-validate with Spark to ensure consistency (ensure deterministic ordering)
spark_result = spark.sql(f"SELECT * FROM {identifier}").toPandas()
sort_cols = ["id", "name", "age", "is_active"]
spark_result = spark_result.sort_values(by=sort_cols).reset_index(drop=True)
pyiceberg_df = pyiceberg_df.sort_values(by=sort_cols).reset_index(drop=True)
pandas.testing.assert_frame_equal(spark_result, pyiceberg_df, check_dtype=False)

# Verify the files are actually ORC format
files = list(tbl.scan().plan_files())
assert len(files) > 0
for file_task in files:
assert file_task.file.file_format == FileFormat.ORC


@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"
Expand Down