Skip to content

Commit dfe0347

Browse files
norabelroselebrice
andauthored
Literal types no longer cause a warning on decoding (#221)
* Literal types no longer cause a warning on decoding * Python 3.7 compatibility Co-authored-by: Fabrice Normandin <fabrice.normandin@gmail.com> * Add test_literal_decoding * Apply suggestions from code review --------- Co-authored-by: Fabrice Normandin <fabrice.normandin@gmail.com>
1 parent e2c0819 commit dfe0347

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

simple_parsing/helpers/serialization/decoding.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from logging import getLogger
1313
from pathlib import Path
1414
from typing import Any, Callable, TypeVar
15+
from typing_extensions import Literal
1516

1617
from simple_parsing.annotation_utils.get_field_annotations import (
1718
evaluate_string_annotation,
@@ -25,6 +26,7 @@
2526
is_enum,
2627
is_forward_ref,
2728
is_list,
29+
is_literal,
2830
is_set,
2931
is_tuple,
3032
is_typevar,
@@ -222,6 +224,11 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
222224
logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
223225
if bound is not None:
224226
return get_decoding_fn(bound)
227+
228+
if is_literal(t):
229+
logger.debug(f"Decoding a Literal field: {t}")
230+
possible_vals = get_type_arguments(t)
231+
return decode_literal(*possible_vals)
225232

226233
# Unknown type.
227234
warnings.warn(
@@ -396,6 +403,26 @@ def _decode_enum(val: str) -> Enum:
396403
return _decode_enum
397404

398405

406+
def decode_literal(*possible_vals: Any) -> Callable[[Any], Any]:
407+
"""Creates a decoding function for a Literal type.
408+
409+
Args:
410+
*possible_vals (Any): The permissible values for the Literal type.
411+
412+
Returns:
413+
Callable[[Any], Any]: A function that checks if a given value is one of the
414+
permissible values for the Literal. If not, raises a TypeError.
415+
"""
416+
417+
def _decode_literal(val: Any) -> Any:
418+
if val not in possible_vals:
419+
raise TypeError(f"Expected one of {possible_vals} for Literal, got {val}")
420+
421+
return val
422+
423+
return _decode_literal
424+
425+
399426
def no_op(v: T) -> T:
400427
"""Decoding function that gives back the value as-is.
401428

test/test_decoding.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from pathlib import Path
55
from test.testutils import Generic, TypeVar
66
from typing import Any, Dict, List, Optional, Tuple, Type
7+
from typing_extensions import Literal
78

89
import pytest
10+
import warnings
911

1012
from simple_parsing.helpers import Serializable, dict_field, list_field
1113
from simple_parsing.helpers.serialization.decoding import (
@@ -38,6 +40,22 @@ class SomeClass(Serializable):
3840
assert SomeClass.loads(b.dumps()) == b
3941

4042

43+
def test_literal_decoding():
44+
@dataclass
45+
class SomeClass(Serializable):
46+
x: Literal["a", "b", "c"] = "a"
47+
48+
# This test should fail if there's a warning on decoding- previous versions
49+
# have raised a UserWarning when decoding a literal, of the form:
50+
# Unable to find a decoding function for annotation typing.Literal['a', 'b', 'c']
51+
# with pytest.warns(UserWarning, match="Unable to find a decoding function"):
52+
# assert SomeClass.loads('{"x": "a"}') == SomeClass()
53+
54+
# Make sure that we can't decode a value that's not in the literal
55+
with pytest.raises(TypeError):
56+
SomeClass.loads('{"x": "d"}')
57+
58+
4159
def test_typevar_decoding(simple_attribute):
4260
@dataclass
4361
class Item(Serializable, decode_into_subclasses=True):

0 commit comments

Comments
 (0)