33from __future__ import annotations
44
55import inspect
6+ import sys
67import warnings
78from collections import OrderedDict
89from collections .abc import Mapping
1213from logging import getLogger
1314from pathlib import Path
1415from typing import Any , Callable , TypeVar
15- from typing_extensions import Literal
1616
1717from simple_parsing .annotation_utils .get_field_annotations import (
1818 evaluate_string_annotation ,
# Registry mapping a type to the function used to decode a raw (serialized) value into it.
_decoding_fns: dict[type[T], Callable[[Any], T]] = {
    # the 'primitive' types are decoded using the type fn as a constructor.
    str: str,
    bytes: bytes,
}
4949
5050
51- def decode_bool (v : Any ) -> bool :
52- if isinstance (v , str ):
53- return str2bool (v )
54- return bool (v )
def register_decoding_fn(
    some_type: type[T], function: Callable[[Any], T], overwrite: bool = False
) -> None:
    """Associate `function` with `some_type` as its decoding function.

    Args:
        some_type: The type whose serialized values `function` decodes.
        function: Callable taking the raw value and returning the decoded value.
        overwrite: Forwarded to `_register`; controls whether an existing entry is replaced.
    """
    _register(t=some_type, func=function, overwrite=overwrite)
56+
57+
def _register(t: type, func: Callable, overwrite: bool = False) -> None:
    """Store `func` as the decoder for `t` in the `_decoding_fns` registry.

    An entry that already exists for `t` is left untouched unless `overwrite` is set.
    """
    already_registered = t in _decoding_fns
    if already_registered and not overwrite:
        return
    _decoding_fns[t] = func
62+
5563
# Lets the decorator below hand back the decorated callable with its exact type preserved.
C = TypeVar("C", bound=Callable[[Any], Any])


def decoding_fn_for_type(some_type: type) -> Callable[[C], C]:
    """Decorator factory that registers the decorated function as the decoder for `some_type`.

    The decorated function should accept one argument (the serialized value) and return the
    decoded value. Registration always replaces any decoder previously registered for
    `some_type`, and the function itself is returned unchanged.
    """

    def _decorator(fn: C) -> C:
        register_decoding_fn(some_type, function=fn, overwrite=True)
        return fn

    return _decorator
78+
79+
@decoding_fn_for_type(int)
def _decode_int(v: Any) -> int:
    """Decode a raw value into an `int`, warning when the conversion is lossy.

    An `UnsafeCastingWarning` is issued when `v` is a bool, or when `v` is a non-integral
    value (e.g. a float such as 1.5) whose numeric value changes once truncated to an int.

    Args:
        v: The raw (serialized) value.

    Returns:
        The result of `int(v)`.

    Raises:
        Whatever `int(v)` raises for values that cannot be converted (e.g. `ValueError` for
        the string "1.5").
    """
    int_v = int(v)
    if isinstance(v, bool):
        # bool is a subclass of int, so the cast "works", but it is almost never intended.
        unsafe = True
    elif isinstance(v, (int, str, bytes, bytearray)):
        # `int()` is exact for these inputs. Comparing against `float(v)` here would only add
        # false positives above 2**53 (float precision) and raises OverflowError for ints too
        # large to be represented as a float.
        unsafe = False
    else:
        unsafe = int_v != float(v)
    if unsafe:
        warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=int_v))
    return int_v
88+
89+
@decoding_fn_for_type(float)
def _decode_float(v: Any) -> float:
    """Decode a raw value into a `float` via `float(v)`.

    Emits an `UnsafeCastingWarning` when the raw value is a bool.
    """
    result = float(v)
    if not isinstance(v, bool):
        return result
    warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=result))
    return result
96+
97+
@decoding_fn_for_type(bool)
def _decode_bool(v: Any) -> bool:
    """Decode a raw value into a `bool`.

    Strings are parsed with `str2bool`; every other value goes through `bool(v)`. An
    `UnsafeCastingWarning` is emitted for numbers other than 0 and 1.
    """
    if isinstance(v, str):
        # Strings never reach the numeric check below, so return right away.
        return str2bool(v)

    result = bool(v)
    is_number = isinstance(v, (int, float))
    if is_number and v not in (0, 1, 0.0, 1.0):
        warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=result))
    return result
58107
59108
60109def decode_field (
@@ -93,11 +142,36 @@ def decode_field(
93142
94143 decoding_function = get_decoding_fn (field_type )
95144
96- if is_dataclass_type (field_type ) and drop_extra_fields is not None :
97- # Pass the drop_extra_fields argument to the decoding function.
98- return decoding_function (raw_value , drop_extra_fields = drop_extra_fields )
145+ _kwargs = dict (category = UnsafeCastingWarning ) if sys .version_info >= (3 , 11 ) else {}
99146
100- return decoding_function (raw_value )
147+ with warnings .catch_warnings (record = True , ** _kwargs ) as warning_messages :
148+ if is_dataclass_type (field_type ) and drop_extra_fields is not None :
149+ # Pass the drop_extra_fields argument to the decoding function.
150+ decoded_value = decoding_function (raw_value , drop_extra_fields = drop_extra_fields )
151+ else :
152+ decoded_value = decoding_function (raw_value )
153+
154+ for warning_message in warning_messages .copy ():
155+ if not isinstance (warning_message .message , UnsafeCastingWarning ):
156+ warnings .warn_explicit (
157+ message = warning_message .message ,
158+ category = warning_message .category ,
159+ filename = warning_message .filename ,
160+ lineno = warning_message .lineno ,
161+ # module=warning_message.module,
162+ # registry=warning_message.registry,
163+ # module_globals=warning_message.module_globals,
164+ )
165+ warning_messages .remove (warning_message )
166+
167+ if warning_messages :
168+ warnings .warn (
169+ RuntimeWarning (
170+ f"Unsafe casting occurred when deserializing field '{ name } ' of type { field_type } : "
171+ f"raw value: { raw_value !r} , decoded value: { decoded_value !r} ."
172+ )
173+ )
174+ return decoded_value
101175
102176
103177# NOTE: Disabling the caching here might help avoid some bugs, and it's unclear if this has that
@@ -224,7 +298,7 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
224298 logger .debug (f"Decoding a typevar: { t } , bound type is { bound } ." )
225299 if bound is not None :
226300 return get_decoding_fn (bound )
227-
301+
228302 if is_literal (t ):
229303 logger .debug (f"Decoding a Literal field: { t } " )
230304 possible_vals = get_type_arguments (t )
@@ -241,19 +315,6 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
241315 return try_constructor (t )
242316
243317
244- def _register (t : type , func : Callable , overwrite : bool = False ) -> None :
245- if t not in _decoding_fns or overwrite :
246- # logger.debug(f"Registering the type {t} with decoding function {func}")
247- _decoding_fns [t ] = func
248-
249-
250- def register_decoding_fn (
251- some_type : type [T ], function : Callable [[Any ], T ], overwrite : bool = False
252- ) -> None :
253- """Register a decoding function for the type `some_type`."""
254- _register (some_type , function , overwrite = overwrite )
255-
256-
257318def decode_optional (t : type [T ]) -> Callable [[Any | None ], T | None ]:
258319 decode = get_decoding_fn (t )
259320
@@ -281,15 +342,21 @@ def _try_functions(val: Any) -> T | Any:
281342
282343
def decode_union(*types: type[T]) -> Callable[[Any], T | Any]:
    """Build a decoding function for a Union of the given types.

    `None` members are stripped from the union. If any were present, each remaining type is
    decoded with `decode_optional`; otherwise with `get_decoding_fn`. The resulting decoders
    are combined with `try_functions`.
    """
    none_type = type(None)
    optional = none_type in types

    # Partition the Union into None and non-None types.
    non_none_types = [t for t in types if t != none_type]

    make_decoder = decode_optional if optional else get_decoding_fn
    decoding_fns: list[Callable[[Any], T]] = [make_decoder(t) for t in non_none_types]

    # TODO: We could be a bit smarter about the order in which we try the functions, but for now,
    # we just try the functions in the same order as the annotation, and return the result from the
    # first function that doesn't raise an exception.

    # Try using each of the non-None types, in succession. Worst case, return the value.
    return try_functions(*decoding_fns)
295362
@@ -455,3 +522,10 @@ def constructor(val):
455522
456523
457524register_decoding_fn (Path , Path )
525+
526+
class UnsafeCastingWarning(RuntimeWarning):
    """Warning signalling that decoding a raw value into another type may have lost information.

    Args:
        raw_value: The value before decoding.
        decoded_value: The value obtained after decoding.
    """

    def __init__(self, raw_value: Any, decoded_value: Any) -> None:
        # Pass a message to the base class so the warning is not blank when it is displayed
        # directly (an argument-less `super().__init__()` makes `str(warning)` empty).
        super().__init__(
            f"Unsafe casting: raw value {raw_value!r} was decoded to {decoded_value!r}."
        )
        # Kept as attributes so handlers can inspect both sides of the conversion.
        self.raw_value = raw_value
        self.decoded_value = decoded_value
0 commit comments