1313from types import ModuleType
1414from typing import IO , Any , Callable , ClassVar , TypeVar , Union
1515
16+ from typing_extensions import Protocol
17+
1618from simple_parsing .utils import (
1719 DataclassT ,
1820 all_subclasses ,
@@ -57,6 +59,97 @@ def ordered_dict_representer(dumper: yaml.Dumper, instance: OrderedDict) -> yaml
5759 pass
5860
5961
62+ class FormatExtension (Protocol ):
63+ binary : ClassVar [bool ] = False
64+
65+ @staticmethod
66+ def load (fp : IO ) -> Any :
67+ ...
68+
69+ @staticmethod
70+ def dump (obj : Any , io : IO ) -> None :
71+ ...
72+
73+
74+ class JSONExtension (FormatExtension ):
75+ load = staticmethod (json .load )
76+ dump = staticmethod (json .dump )
77+
78+
79+ class PickleExtension (FormatExtension ):
80+ binary : ClassVar [bool ] = True
81+ load : ClassVar [Callable [[IO ], Any ]] = staticmethod (pickle .load )
82+ dump : ClassVar [Callable [[Any , IO [bytes ]], None ]] = staticmethod (pickle .dump )
83+
84+
85+ class YamlExtension (FormatExtension ):
86+ def load (self , io : IO ) -> Any :
87+ import yaml
88+
89+ return yaml .safe_load (io )
90+
91+ def dump (self , obj : Any , io : IO , ** kwargs ) -> None :
92+ import yaml
93+
94+ return yaml .dump (obj , io , ** kwargs )
95+
96+
97+ class NumpyExtension (FormatExtension ):
98+ binary : bool = True
99+
100+ def load (self , io : IO ) -> Any :
101+ import numpy
102+
103+ obj = numpy .load (io , allow_pickle = True )
104+ if isinstance (obj , numpy .ndarray ) and obj .dtype == object :
105+ obj = obj .item ()
106+ return obj
107+
108+ def dump (self , obj : Any , io : IO [bytes ], ** kwargs ) -> None :
109+ import numpy
110+
111+ return numpy .save (io , obj , ** kwargs )
112+
113+
114+ class TorchExtension (FormatExtension ):
115+ binary : bool = True
116+
117+ def load (self , io : IO ) -> None :
118+ import torch # type: ignore
119+
120+ return torch .load (io )
121+
122+ def dump (self , obj : Any , io : IO , ** kwargs ) -> None :
123+ import torch # type: ignore
124+
125+ return torch .save (obj , io , ** kwargs )
126+
127+
128+ json_extension = JSONExtension ()
129+ yaml_extension = YamlExtension ()
130+
131+
132+ extensions : dict [str , FormatExtension ] = {
133+ ".json" : JSONExtension (),
134+ ".pkl" : PickleExtension (),
135+ ".yaml" : YamlExtension (),
136+ ".yml" : YamlExtension (),
137+ ".npy" : NumpyExtension (),
138+ ".pth" : TorchExtension (),
139+ }
140+
141+
142+ def get_extension (path : str | Path ) -> FormatExtension :
143+ path = Path (path )
144+ if path .suffix in extensions :
145+ return extensions [path .suffix ]
146+ else :
147+ raise RuntimeError (
148+ f"Cannot load to/save from a { path .suffix } file because "
149+ "this extension is not registered in the extensions dictionary."
150+ )
151+
152+
60153class SerializableMixin :
61154 """Makes a dataclass serializable to and from dictionaries.
62155
@@ -248,17 +341,17 @@ def load_yaml(
248341 """
249342 return load_yaml (cls , path , load_fn = load_fn , drop_extra_fields = drop_extra_fields , ** kwargs )
250343
251- def save (self , path : str | Path , dump_fn = None ) -> None :
252- save (self , path = path , dump_fn = dump_fn )
344+ def save (self , path : str | Path , format : FormatExtension | None = None ) -> None :
345+ save (self , path = path , format = format )
253346
254- def _save (self , path : str | Path , dump_fn : DumpFn = json . dump , ** kwargs ) -> None :
255- save (self , path = path , dump_fn = partial ( dump_fn , ** kwargs ) )
347+ def _save (self , path : str | Path , format : FormatExtension = json_extension , ** kwargs ) -> None :
348+ save (self , path = path , format = format , ** kwargs )
256349
257350 def save_yaml (self , path : str | Path , dump_fn : DumpFn | None = None , ** kwargs ) -> None :
258- save_yaml (self , path , dump_fn = dump_fn , ** kwargs )
351+ save_yaml (self , path , ** kwargs )
259352
260- def save_json (self , path : str | Path , dump_fn = json . dump , ** kwargs ) -> None :
261- save_json (self , path , dump_fn = dump_fn , ** kwargs )
353+ def save_json (self , path : str | Path , ** kwargs ) -> None :
354+ save_json (self , path , ** kwargs )
262355
263356 @classmethod
264357 def loads (
@@ -484,51 +577,6 @@ def loads_yaml(
484577 return loads (cls , s , drop_extra_fields = drop_extra_fields , load_fn = partial (load_fn , ** kwargs ))
485578
486579
487- extensions_to_loading_fn : dict [str , Callable [[IO ], Any ]] = {
488- ".json" : json .load ,
489- ".pkl" : pickle .load ,
490- }
491- extensions_to_read_mode : dict [str , str ] = {".pkl" : "rb" }
492-
493- extensions_to_write_mode : dict [str , str ] = {".pkl" : "wb" }
494- extensions_to_dump_fn : dict [str , Callable [[Any , IO ], None ]] = {
495- ".json" : json .dump ,
496- ".pkl" : pickle .dump ,
497- }
498- try :
499- import yaml
500-
501- extensions_to_loading_fn [".yaml" ] = yaml .safe_load
502- extensions_to_loading_fn [".yml" ] = yaml .safe_load
503- extensions_to_dump_fn [".yaml" ] = yaml .dump
504- extensions_to_dump_fn [".yml" ] = yaml .dump
505-
506-
507- except ImportError :
508- pass
509-
510- try :
511- import numpy # type: ignore
512-
513- extensions_to_loading_fn [".npy" ] = numpy .load
514- extensions_to_dump_fn [".npy" ] = numpy .save
515- extensions_to_read_mode [".npy" ] = "rb"
516- extensions_to_write_mode [".npy" ] = "wb"
517-
518- except ImportError :
519- pass
520-
521- try :
522- import torch # type: ignore
523-
524- extensions_to_loading_fn [".pth" ] = torch .load
525- extensions_to_dump_fn [".pth" ] = torch .save
526- extensions_to_read_mode [".pth" ] = "rb"
527- extensions_to_write_mode [".pth" ] = "wb"
528- except ImportError :
529- pass
530-
531-
532580def read_file (path : str | Path ) -> dict :
533581 """Returns the contents of the given file as a dictionary.
534582 Uses the right function depending on `path.suffix`:
@@ -540,66 +588,33 @@ def read_file(path: str | Path) -> dict:
540588 ".pkl": pickle.load,
541589 }
542590 """
543- path = Path (path )
544- if path .suffix in extensions_to_loading_fn :
545- load_fn = extensions_to_loading_fn [path .suffix ]
546- else :
547- raise RuntimeError (
548- f"Unable to determine what function to use in order to load "
549- f"path { path } into a dictionary since the path's extension isn't registered in the "
550- f"`extensions_to_loading_fn` dictionary..."
551- )
552- mode = extensions_to_read_mode .get (path .suffix , "r" )
553- with open (path , mode = mode ) as f :
554- return load_fn (f )
591+ format = get_extension (path )
592+ with open (path , mode = "rb" if format .binary else "r" ) as f :
593+ return format .load (f )
555594
556595
557596def save (
558597 obj : Any ,
559598 path : str | Path ,
560- dump_fn : Callable [[ dict , IO ], None ] | None = None ,
599+ format : FormatExtension | None = None ,
561600 save_dc_types : bool = False ,
601+ ** kwargs ,
562602) -> None :
563- """Save the given dataclass or dictionary to the given file.
564-
565- Note: The `encode` function is applied to all the object fields to get serializable values,
566- like so:
567- - obj -> encode -> "raw" values (dicts, strings, ints, etc) -> `dump_fn` ([json/yaml/etc].dumps) -> string
568- """
569- path = Path (path )
570-
603+ """Save the given dataclass or dictionary to the given file."""
571604 if not isinstance (obj , dict ):
572605 obj = to_dict (obj , save_dc_types = save_dc_types )
606+ if format is None :
607+ format = get_extension (path )
608+ with open (path , mode = "wb" if format .binary else "w" ) as f :
609+ return format .dump (obj , f , ** kwargs )
573610
574- if dump_fn :
575- save_fn = dump_fn
576- elif path .suffix in extensions_to_dump_fn :
577- save_fn = extensions_to_dump_fn [path .suffix ]
578- else :
579- raise RuntimeError (
580- f"Unable to determine what function to use in order to save obj { obj } to path { path } ,"
581- f"since the path's extension isn't registered in the "
582- f"`extensions_to_dump_fn` dictionary..."
583- )
584- mode = extensions_to_write_mode .get (path .suffix , "w" )
585- with open (path , mode = mode ) as f :
586- return save_fn (obj , f )
587611
612+ def save_yaml (obj , path : str | Path , ** kwargs ) -> None :
613+ save (obj , path , format = yaml_extension , ** kwargs )
588614
589- def save_yaml (
590- obj , path : str | Path , dump_fn : DumpFn | None = None , save_dc_types : bool = False , ** kwargs
591- ) -> None :
592- import yaml
593615
594- if dump_fn is None :
595- dump_fn = yaml .dump
596- save (obj , path , dump_fn = partial (dump_fn , ** kwargs ), save_dc_types = save_dc_types )
597-
598-
599- def save_json (
600- obj , path : str | Path , dump_fn : DumpFn = json .dump , save_dc_types : bool = False , ** kwargs
601- ) -> None :
602- save (obj , path , dump_fn = partial (dump_fn , ** kwargs ), save_dc_types = save_dc_types )
616+ def save_json (obj , path : str | Path , ** kwargs ) -> None :
617+ save (obj , path , format = json_extension , ** kwargs )
603618
604619
605620def load_yaml (
0 commit comments