diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 5c70636e64..a1fa696f38 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1777,7 +1777,7 @@ def struct( field_arrays.append(array) fields.append(self._construct_field(field, array.type)) elif field.optional: - arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False) + arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids) field_arrays.append(pa.nulls(len(struct_array), type=arrow_type)) fields.append(self._construct_field(field, arrow_type)) else: diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 59c795cf75..8575b588b8 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -52,6 +52,7 @@ DateType, DoubleType, IntegerType, + ListType, LongType, NestedField, StringType, @@ -1647,3 +1648,38 @@ def test_abort_table_transaction_on_exception( # Validate the transaction is aborted and no partial update is applied assert len(tbl.scan().to_pandas()) == table_size # type: ignore + + +@pytest.mark.integration +def test_write_optional_list(session_catalog: Catalog) -> None: + identifier = "default.test_write_optional_list" + schema = Schema( + NestedField(field_id=1, name="name", field_type=StringType(), required=False), + NestedField( + field_id=3, + name="my_list", + field_type=ListType(element_id=45, element=StringType(), element_required=False), + required=False, + ), + ) + session_catalog.create_table_if_not_exists(identifier, schema) + + df_1 = pa.Table.from_pylist( + [ + {"name": "one", "my_list": ["test"]}, + {"name": "another", "my_list": ["test"]}, + ] + ) + session_catalog.load_table(identifier).append(df_1) + + assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 2 + + df_2 = pa.Table.from_pylist( + [ + {"name": "one"}, + {"name": "another"}, + ] + ) + session_catalog.load_table(identifier).append(df_2) + + assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4