diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py index b29d0d9e48..e52ed5e0ab 100644 --- a/pyiceberg/expressions/literals.py +++ b/pyiceberg/expressions/literals.py @@ -603,6 +603,26 @@ def _(self, type_var: BooleanType) -> Literal[bool]: else: raise ValueError(f"Could not convert {self.value} into a {type_var}") + @to.register(FloatType) + def _(self, type_var: FloatType) -> Literal[float]: + try: + number = float(self.value) + if FloatType.max < number: + return FloatAboveMax() + elif FloatType.min > number: + return FloatBelowMin() + return FloatLiteral(number) + except ValueError as e: + raise ValueError(f"Could not convert {self.value} into a {type_var}") from e + + @to.register(DoubleType) + def _(self, type_var: DoubleType) -> Literal[float]: + try: + number = float(self.value) + return DoubleLiteral(number) + except ValueError as e: + raise ValueError(f"Could not convert {self.value} into a {type_var}") from e + def __repr__(self) -> str: """Return the string representation of the StringLiteral class.""" return f"literal({repr(self.value)})" diff --git a/tests/expressions/test_literals.py b/tests/expressions/test_literals.py index 6144e32776..4d8f5557f6 100644 --- a/tests/expressions/test_literals.py +++ b/tests/expressions/test_literals.py @@ -393,6 +393,22 @@ def test_string_to_boolean_literal() -> None: assert literal("FALSE").to(BooleanType()) == literal(False) +def test_string_to_float_literal() -> None: + assert literal("3.141").to(FloatType()) == literal(3.141).to(FloatType()) + + +def test_string_to_float_outside_bound() -> None: + big_lit_str = literal(str(FloatType.max + 1.0e37)) + assert big_lit_str.to(FloatType()) == FloatAboveMax() + + small_lit_str = literal(str(FloatType.min - 1.0e37)) + assert small_lit_str.to(FloatType()) == FloatBelowMin() + + +def test_string_to_double_literal() -> None: + assert literal("3.141").to(DoubleType()) == literal(3.141) + + @pytest.mark.parametrize( "val", ["unknown", "off", "on", "0", "1", "y", "yes", "n", "no", "t", "f"], @@ -744,7 +760,7 @@ def test_invalid_decimal_conversions() -> None: def test_invalid_string_conversions() -> None: assert_invalid_conversions( literal("abc"), - [FloatType(), DoubleType(), FixedType(1), BinaryType()], + [FixedType(1), BinaryType()], )