diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 84478f24cf..07c836fb32 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -3,6 +3,7 @@ import builtins import ipaddress import uuid +import warnings import weakref from collections.abc import Mapping, Sequence, Set from datetime import date, datetime, time, timedelta @@ -214,6 +215,7 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, + coerce_numbers_to_str: Optional[bool] = None, gt: Optional[float] = None, ge: Optional[float] = None, lt: Optional[float] = None, @@ -226,9 +228,12 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, + fail_fast: Optional[bool] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, + validate_default: Optional[bool] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, @@ -257,6 +262,7 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, + coerce_numbers_to_str: Optional[bool] = None, gt: Optional[float] = None, ge: Optional[float] = None, lt: Optional[float] = None, @@ -269,9 +275,12 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, + fail_fast: Optional[bool] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, + validate_default: Optional[bool] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: str, @@ -309,6 +318,7 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, + coerce_numbers_to_str: Optional[bool] = None, gt: Optional[float] = None, ge: Optional[float] = None, lt: Optional[float] = None, @@ -321,9 +331,12 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, + fail_fast: Optional[bool] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, + validate_default: Optional[bool] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, schema_extra: Optional[dict[str, Any]] = None, @@ -342,6 +355,7 @@ def Field( exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None, const: Optional[bool] = None, + coerce_numbers_to_str: Optional[bool] = None, gt: Optional[float] = None, ge: Optional[float] = None, lt: Optional[float] = None, @@ -354,9 +368,12 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, + fail_fast: Optional[bool] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, + validate_default: Optional[bool] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, @@ -371,9 +388,27 @@ def Field( schema_extra: Optional[dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + + for param_name in ( + "coerce_numbers_to_str", + "validate_default", + "union_mode", + "fail_fast", + ): + if param_name in current_schema_extra: + msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`" + warnings.warn(msg, UserWarning, stacklevel=2) + # Extract possible alias settings from schema_extra so we can control precedence schema_validation_alias = current_schema_extra.pop("validation_alias", None) schema_serialization_alias = current_schema_extra.pop("serialization_alias", None) + current_coerce_numbers_to_str = coerce_numbers_to_str or current_schema_extra.pop( + "coerce_numbers_to_str", None + ) + current_validate_default = validate_default or current_schema_extra.pop( + "validate_default", None + ) + current_fail_fast = fail_fast or current_schema_extra.pop("fail_fast", None) field_info_kwargs = { "alias": alias, "title": title, @@ -381,6 +416,8 @@ def Field( "exclude": exclude, "include": include, "const": const, + "coerce_numbers_to_str": current_coerce_numbers_to_str, + "validate_default": current_validate_default, "gt": gt, "ge": ge, "lt": lt, @@ -393,6 +430,7 @@ def Field( "unique_items": unique_items, "min_length": min_length, "max_length": max_length, + "fail_fast": current_fail_fast, "allow_mutation": allow_mutation, "regex": regex, "discriminator": discriminator, @@ -418,6 +456,10 @@ def Field( serialization_alias or schema_serialization_alias or alias ) + current_union_mode = union_mode or current_schema_extra.pop("union_mode", None) + if current_union_mode is not None: + field_info_kwargs["union_mode"] = current_union_mode + field_info = FieldInfo( default, default_factory=default_factory, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 140b02fd9b..284cb81eb4 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -from sqlmodel import Field, SQLModel +from sqlmodel import Field, Session, SQLModel, create_engine def test_decimal(): @@ -54,3 +54,192 @@ class Model(SQLModel): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +def test_coerce_numbers_to_str_true(): + class Model(SQLModel): + val: str = Field(coerce_numbers_to_str=True) + + assert Model.model_validate({"val": 123}).val == "123" + assert Model.model_validate({"val": 45.67}).val == "45.67" + + +@pytest.mark.parametrize("coerce_numbers_to_str", [None, False]) +def test_coerce_numbers_to_str_false(coerce_numbers_to_str: Optional[bool]): + class Model2(SQLModel): + val: str = Field(coerce_numbers_to_str=coerce_numbers_to_str) + + with pytest.raises(ValidationError): + Model2.model_validate({"val": 123}) + + +def test_coerce_numbers_to_str_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + UserWarning, + match=( + "Pass `coerce_numbers_to_str` parameter directly to Field instead of passing " + "it via `schema_extra`" + ), + ): + + class Model(SQLModel): + val: str = Field(schema_extra={"coerce_numbers_to_str": True}) + + assert Model.model_validate({"val": 123}).val == "123" + assert Model.model_validate({"val": 45.67}).val == "45.67" + + +def test_validate_default_true(): + class Model(SQLModel): + val: int = Field(default="123", validate_default=True) + + assert Model.model_validate({}).val == 123 + + class Model2(SQLModel): + val: int = Field(default=None, validate_default=True) + + with pytest.raises(ValidationError): + Model2.model_validate({}) + + +def test_validate_default_table_model(): + class Model(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + val: int = Field(default="123", validate_default=True) + + class ModelDB(Model, table=True): + pass + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + model = ModelDB() + with Session(engine) as session: + session.add(model) + session.commit() + session.refresh(model) + + assert model.val == 123 + + +@pytest.mark.parametrize("validate_default", [None, False]) +def test_validate_default_false(validate_default: Optional[bool]): + class Model3(SQLModel): + val: int = Field(default="123", validate_default=validate_default) + + assert Model3().val == "123" + + +def test_validate_default_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + UserWarning, + match=( + "Pass `validate_default` parameter directly to Field instead of passing " + "it via `schema_extra`" + ), + ): + + class Model(SQLModel): + val: int = Field(default="123", schema_extra={"validate_default": True}) + + assert Model.model_validate({}).val == 123 + + +@pytest.mark.parametrize("union_mode", [None, "smart"]) +def test_union_mode_smart(union_mode: Optional[Literal["smart"]]): + class Model(SQLModel): + val: Union[float, int] = Field(union_mode=union_mode) + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, int) # float is first, but int is more precise + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float) + + +def test_union_mode_left_to_right(): + class Model(SQLModel): + val: Union[float, int] = Field(union_mode="left_to_right") + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, float) + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float) + + +def test_union_mode_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + UserWarning, + match=( + "Pass `union_mode` parameter directly to Field instead of passing " + "it via `schema_extra`" + ), + ): + + class Model(SQLModel): + val: Union[float, int] = Field(schema_extra={"union_mode": "smart"}) + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, int) # float is first, but int is more precise + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float) + + +def test_fail_fast_true(): + class Model(SQLModel): + val: list[int] = Field(fail_fast=True) + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({"val": [1.1, "not an int"]}) + + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "int_from_float" + + +@pytest.mark.parametrize("fail_fast", [None, False]) +def test_fail_fast_false(fail_fast: Optional[bool]): + class Model(SQLModel): + val: list[int] = Field(fail_fast=fail_fast) + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({"val": [1.1, "not an int"]}) + + errors = exc_info.value.errors() + assert len(errors) == 2 + error_types = {error["type"] for error in errors} + + assert "int_from_float" in error_types + assert "int_parsing" in error_types + + +def test_fail_fast_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + UserWarning, + match=( + "Pass `fail_fast` parameter directly to Field instead of passing " + "it via `schema_extra`" + ), + ): + + class Model(SQLModel): + val: list[int] = Field(schema_extra={"fail_fast": True}) + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({"val": [1.1, "not an int"]}) + + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "int_from_float"