Accept the dict form {"type": "null"} (in addition to the string "null") as
the leading null branch of an Avro union in _resolve_union, and include the
offending first element in the "Only null-unions are supported" error.
The union tests are parametrized to cover both spellings and the failure cases.

diff --git a/pyiceberg/utils/schema_conversion.py b/pyiceberg/utils/schema_conversion.py
index ec2fccd509..c3b3f9912e 100644
--- a/pyiceberg/utils/schema_conversion.py
+++ b/pyiceberg/utils/schema_conversion.py
@@ -171,11 +171,11 @@ def _resolve_union(
         # This means that null has to come first:
         # https://avro.apache.org/docs/current/spec.html
         # type of the default value must match the first element of the union.
-        if "null" != avro_types[0]:
-            raise TypeError("Only null-unions are supported")
+        if avro_types[0] != "null" and avro_types[0] != {"type": "null"}:
+            raise TypeError(f"Only null-unions are supported, not: {avro_types[0]}")
 
         # Filter the null value and return the type
-        return list(filter(lambda t: t != "null", avro_types))[0], False
+        return list(filter(lambda t: t != "null" and t != {"type": "null"}, avro_types))[0], False
 
     def _convert_schema(self, avro_type: Union[str, Dict[str, Any]]) -> IcebergType:
         """
diff --git a/tests/utils/test_schema_conversion.py b/tests/utils/test_schema_conversion.py
index e60a89563f..1a1ba67cdb 100644
--- a/tests/utils/test_schema_conversion.py
+++ b/tests/utils/test_schema_conversion.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=W0212
-from typing import Any, Dict
+from typing import Any, Dict, List, Union
 
 import pytest
 
@@ -232,13 +232,72 @@ def test_avro_list_required_record() -> None:
     assert expected_iceberg_schema == iceberg_schema
 
 
-def test_resolve_union() -> None:
+@pytest.mark.parametrize(
+    "union_type, expected_type, is_required",
+    [
+        # Primitive type without null (should be required)
+        ("string", "string", True),
+        ("int", "int", True),
+        ({"type": "string"}, {"type": "string"}, True),
+        ({"type": "int"}, {"type": "int"}, True),
+        # Null as string, followed by a primitive type
+        (["null", "string"], "string", False),
+        (["null", "int"], "int", False),
+        (["null", "long"], "long", False),
+        (["null", {"type": "bytes"}], {"type": "bytes"}, False),
+        # Null as dict, followed by a primitive type
+        ([{"type": "null"}, "string"], "string", False),
+        ([{"type": "null"}, "int"], "int", False),
+        ([{"type": "null"}, "long"], "long", False),
+        ([{"type": "null"}, {"type": "boolean"}], {"type": "boolean"}, False),
+    ],
+)
+def test_resolve_union(
+    union_type: Union[Dict[str, str], List[Union[str, Dict[str, str]]], str],
+    expected_type: Union[str, Dict[str, Any]],
+    is_required: bool,
+) -> None:
+    converter = AvroSchemaConversion()
+    resolved_type, required_status = converter._resolve_union(union_type)
+    assert resolved_type == expected_type
+    assert required_status == is_required
+
+
+@pytest.mark.parametrize(
+    "union_type",
+    [
+        (["null", "string", "long"]),
+        (["null", "int", "float"]),
+        (["string", "int", "long"]),
+        (["null", {"type": "string"}, {"type": "int"}]),
+    ],
+)
+def test_resolve_union_too_many_types_fails(union_type: Union[Dict[str, str], List[Union[str, Dict[str, str]]], str]) -> None:
+    converter = AvroSchemaConversion()
     with pytest.raises(TypeError) as exc_info:
-        AvroSchemaConversion()._resolve_union(["null", "string", "long"])
-
+        converter._resolve_union(union_type)
     assert "Non-optional types aren't part of the Iceberg specification" in str(exc_info.value)
 
 
+@pytest.mark.parametrize(
+    "union_type",
+    [
+        (["string", "null"]),
+        ([{"type": "string"}, "null"]),
+        ([{"type": "string"}, {"type": "string"}]),
+        (["int", {"type": "null"}]),
+        (["long", {"type": "null"}]),
+        ([{"type": "string"}, {"type": "null"}]),
+        (["float", "double"]),
+    ],
+)
+def test_resolve_union_non_null_first_fails(union_type: Union[Dict[str, str], List[Union[str, Dict[str, str]]], str]) -> None:
+    converter = AvroSchemaConversion()
+    with pytest.raises(TypeError) as exc_info:
+        converter._resolve_union(union_type)
+    assert "Only null-unions are supported" in str(exc_info.value)
+
+
 def test_nested_type() -> None:
     # In the case a primitive field is nested
     assert AvroSchemaConversion()._convert_schema({"type": {"type": "string"}}) == StringType()