diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 8e83b011bf..a33e56581a 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -47,6 +47,7 @@ Field, PrivateAttr, SerializeAsAny, + field_validator, model_serializer, model_validator, ) @@ -310,6 +311,14 @@ class NestedField(IcebergType): ... doc="Just a long" ... )) '2: bar: required long (Just a long)' + >>> str(NestedField( + ... field_id=3, + ... name='baz', + ... field_type="string", + ... required=True, + ... doc="A string field" + ... )) + '3: baz: required string (A string field)' """ field_id: int = Field(alias="id") @@ -320,11 +329,21 @@ class NestedField(IcebergType): initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore + @field_validator("field_type", mode="before") + def convert_field_type(cls, v: Any) -> IcebergType: + """Convert string values into IcebergType instances.""" + if isinstance(v, str): + try: + return IcebergType.handle_primitive_type(v, None) + except ValueError as e: + raise ValueError(f"Unsupported field type: '{v}'") from e + return v + def __init__( self, field_id: Optional[int] = None, name: Optional[str] = None, - field_type: Optional[IcebergType] = None, + field_type: Optional[IcebergType | str] = None, required: bool = False, doc: Optional[str] = None, initial_default: Optional[Any] = None, diff --git a/tests/test_types.py b/tests/test_types.py index b19df17e08..e14ec9dd6c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -62,6 +62,21 @@ (12, BinaryType), ] +primitive_types = { + "boolean": BooleanType, + "int": IntegerType, + "long": LongType, + "float": FloatType, + "double": DoubleType, + "date": DateType, + "time": TimeType, + "timestamp": TimestampType, + "timestamptz": TimestamptzType, + "string": StringType, + "uuid": UUIDType, + "binary": BinaryType, +} + @pytest.mark.parametrize("input_index, input_type", non_parameterized_types) def test_repr_primitive_types(input_index: int, input_type: Type[PrimitiveType]) -> None: @@ -231,6 +246,32 @@ def test_nested_field() -> None: assert "validation errors for NestedField" in str(exc_info.value) +def test_nested_field_complex_type_as_str_unsupported() -> None: + unsupported_types = ["list", "map", "struct"] + for type_str in unsupported_types: + with pytest.raises(ValueError) as exc_info: + _ = NestedField(1, "field", type_str, required=True) + assert f"Unsupported field type: '{type_str}'" in str(exc_info.value) + + +def test_nested_field_primitive_type_as_str() -> None: + for type_str, type_class in primitive_types.items(): + field_var = NestedField( + 1, + "field", + type_str, + required=True, + ) + assert isinstance( + field_var.field_type, type_class + ), f"Expected {type_class.__name__}, got {field_var.field_type.__class__.__name__}" + + # Test that passing 'bool' raises a ValueError, as it should be 'boolean' + with pytest.raises(ValueError) as exc_info: + _ = NestedField(1, "field", "bool", required=True) + assert "Unsupported field type: 'bool'" in str(exc_info.value) + + @pytest.mark.parametrize("input_index,input_type", non_parameterized_types) @pytest.mark.parametrize("check_index,check_type", non_parameterized_types) def test_non_parameterized_type_equality(