@@ -2,7 +2,7 @@ use pyo3::{types::*, Bound};
22use serde:: de:: { self , DeserializeOwned , IntoDeserializer } ;
33use serde:: Deserialize ;
44
5- use crate :: error:: { PythonizeError , Result } ;
5+ use crate :: error:: { ErrorImpl , PythonizeError , Result } ;
66
77/// Attempt to convert a Python object to an instance of `T`
88pub fn depythonize < ' a , ' py , T > ( obj : & ' a Bound < ' py , PyAny > ) -> Result < T >
@@ -44,6 +44,19 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
4444 }
4545 }
4646
47+ fn set_access ( & self ) -> Result < PySetAsSequence < ' py > > {
48+ match self . input . downcast :: < PySet > ( ) {
49+ Ok ( set) => Ok ( PySetAsSequence :: from_set ( & set) ) ,
50+ Err ( e) => {
51+ if let Ok ( f) = self . input . downcast :: < PyFrozenSet > ( ) {
52+ Ok ( PySetAsSequence :: from_frozenset ( & f) )
53+ } else {
54+ Err ( e. into ( ) )
55+ }
56+ }
57+ }
58+ }
59+
4760 fn dict_access ( & self ) -> Result < PyMappingAccess < ' py > > {
4861 PyMappingAccess :: new ( self . input . downcast ( ) ?)
4962 }
@@ -122,10 +135,9 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
122135 self . deserialize_bytes ( visitor)
123136 } else if obj. is_instance_of :: < PyFloat > ( ) {
124137 self . deserialize_f64 ( visitor)
125- } else if obj. is_instance_of :: < PyFrozenSet > ( )
126- || obj. is_instance_of :: < PySet > ( )
127- || obj. downcast :: < PySequence > ( ) . is_ok ( )
128- {
138+ } else if obj. is_instance_of :: < PyFrozenSet > ( ) || obj. is_instance_of :: < PySet > ( ) {
139+ self . deserialize_seq ( visitor)
140+ } else if obj. downcast :: < PySequence > ( ) . is_ok ( ) {
129141 self . deserialize_tuple ( obj. len ( ) ?, visitor)
130142 } else if obj. downcast :: < PyMapping > ( ) . is_ok ( ) {
131143 self . deserialize_map ( visitor)
@@ -238,7 +250,18 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
238250 where
239251 V : de:: Visitor < ' de > ,
240252 {
241- visitor. visit_seq ( self . sequence_access ( None ) ?)
253+ match self . sequence_access ( None ) {
254+ Ok ( seq) => visitor. visit_seq ( seq) ,
255+ Err ( e) => {
256+ // we allow sets to be deserialized as sequences, so try that
257+ if matches ! ( * e. inner, ErrorImpl :: UnexpectedType ( _) ) {
258+ if let Ok ( set) = self . set_access ( ) {
259+ return visitor. visit_seq ( set) ;
260+ }
261+ }
262+ Err ( e)
263+ }
264+ }
242265 }
243266
244267 fn deserialize_tuple < V > ( self , len : usize , visitor : V ) -> Result < V :: Value >
@@ -357,6 +380,40 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> {
357380 }
358381}
359382
383+ struct PySetAsSequence < ' py > {
384+ iter : Bound < ' py , PyIterator > ,
385+ }
386+
387+ impl < ' py > PySetAsSequence < ' py > {
388+ fn from_set ( set : & Bound < ' py , PySet > ) -> Self {
389+ Self {
390+ iter : PyIterator :: from_bound_object ( & set) . expect ( "set is always iterable" ) ,
391+ }
392+ }
393+
394+ fn from_frozenset ( set : & Bound < ' py , PyFrozenSet > ) -> Self {
395+ Self {
396+ iter : PyIterator :: from_bound_object ( & set) . expect ( "frozenset is always iterable" ) ,
397+ }
398+ }
399+ }
400+
401+ impl < ' de > de:: SeqAccess < ' de > for PySetAsSequence < ' _ > {
402+ type Error = PythonizeError ;
403+
404+ fn next_element_seed < T > ( & mut self , seed : T ) -> Result < Option < T :: Value > >
405+ where
406+ T : de:: DeserializeSeed < ' de > ,
407+ {
408+ match self . iter . next ( ) {
409+ Some ( item) => seed
410+ . deserialize ( & mut Depythonizer :: from_object ( & item?) )
411+ . map ( Some ) ,
412+ None => Ok ( None ) ,
413+ }
414+ }
415+ }
416+
360417struct PyMappingAccess < ' py > {
361418 keys : Bound < ' py , PySequence > ,
362419 values : Bound < ' py , PySequence > ,
@@ -606,6 +663,22 @@ mod test {
606663 test_de ( code, & expected, & expected_json) ;
607664 }
608665
666+ #[ test]
667+ fn test_vec_from_pyset ( ) {
668+ let expected = vec ! [ "foo" . to_string( ) ] ;
669+ let expected_json = json ! ( [ "foo" ] ) ;
670+ let code = "{'foo'}" ;
671+ test_de ( code, & expected, & expected_json) ;
672+ }
673+
674+ #[ test]
675+ fn test_vec_from_pyfrozenset ( ) {
676+ let expected = vec ! [ "foo" . to_string( ) ] ;
677+ let expected_json = json ! ( [ "foo" ] ) ;
678+ let code = "frozenset({'foo'})" ;
679+ test_de ( code, & expected, & expected_json) ;
680+ }
681+
609682 #[ test]
610683 fn test_vec ( ) {
611684 let expected = vec ! [ 3 , 2 , 1 ] ;
0 commit comments