Skip to content

Commit edb9f25

Browse files
breuleuxlebrice
andauthored
Lazy import of numpy/torch etc. (#253)
* Rewrite extensions in serializable to load torch/numpy lazily * Add file with test requirements * Add a dumb test for save functions Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add numpy and pytorch as test dependencies Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix small typing issues and bug with save fn Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Remove duplicate tests from test_yaml.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --------- Signed-off-by: Fabrice Normandin <normandf@mila.quebec> Co-authored-by: Fabrice Normandin <normandf@mila.quebec>
1 parent ce027d2 commit edb9f25

File tree

5 files changed

+175
-128
lines changed

5 files changed

+175
-128
lines changed

requirements-test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
matplotlib
2+
numpy
3+
orion

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"pytest",
2121
"pytest-xdist",
2222
"pytest-regressions",
23+
"numpy",
24+
"torch",
2325
],
2426
"yaml": ["pyyaml"],
2527
}

simple_parsing/helpers/serialization/serializable.py

Lines changed: 114 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from types import ModuleType
1414
from typing import IO, Any, Callable, ClassVar, TypeVar, Union
1515

16+
from typing_extensions import Protocol
17+
1618
from simple_parsing.utils import (
1719
DataclassT,
1820
all_subclasses,
@@ -57,6 +59,97 @@ def ordered_dict_representer(dumper: yaml.Dumper, instance: OrderedDict) -> yaml
5759
pass
5860

5961

62+
class FormatExtension(Protocol):
63+
binary: ClassVar[bool] = False
64+
65+
@staticmethod
66+
def load(fp: IO) -> Any:
67+
...
68+
69+
@staticmethod
70+
def dump(obj: Any, io: IO) -> None:
71+
...
72+
73+
74+
class JSONExtension(FormatExtension):
75+
load = staticmethod(json.load)
76+
dump = staticmethod(json.dump)
77+
78+
79+
class PickleExtension(FormatExtension):
80+
binary: ClassVar[bool] = True
81+
load: ClassVar[Callable[[IO], Any]] = staticmethod(pickle.load)
82+
dump: ClassVar[Callable[[Any, IO[bytes]], None]] = staticmethod(pickle.dump)
83+
84+
85+
class YamlExtension(FormatExtension):
86+
def load(self, io: IO) -> Any:
87+
import yaml
88+
89+
return yaml.safe_load(io)
90+
91+
def dump(self, obj: Any, io: IO, **kwargs) -> None:
92+
import yaml
93+
94+
return yaml.dump(obj, io, **kwargs)
95+
96+
97+
class NumpyExtension(FormatExtension):
98+
binary: bool = True
99+
100+
def load(self, io: IO) -> Any:
101+
import numpy
102+
103+
obj = numpy.load(io, allow_pickle=True)
104+
if isinstance(obj, numpy.ndarray) and obj.dtype == object:
105+
obj = obj.item()
106+
return obj
107+
108+
def dump(self, obj: Any, io: IO[bytes], **kwargs) -> None:
109+
import numpy
110+
111+
return numpy.save(io, obj, **kwargs)
112+
113+
114+
class TorchExtension(FormatExtension):
115+
binary: bool = True
116+
117+
def load(self, io: IO) -> None:
118+
import torch # type: ignore
119+
120+
return torch.load(io)
121+
122+
def dump(self, obj: Any, io: IO, **kwargs) -> None:
123+
import torch # type: ignore
124+
125+
return torch.save(obj, io, **kwargs)
126+
127+
128+
json_extension = JSONExtension()
129+
yaml_extension = YamlExtension()
130+
131+
132+
extensions: dict[str, FormatExtension] = {
133+
".json": JSONExtension(),
134+
".pkl": PickleExtension(),
135+
".yaml": YamlExtension(),
136+
".yml": YamlExtension(),
137+
".npy": NumpyExtension(),
138+
".pth": TorchExtension(),
139+
}
140+
141+
142+
def get_extension(path: str | Path) -> FormatExtension:
143+
path = Path(path)
144+
if path.suffix in extensions:
145+
return extensions[path.suffix]
146+
else:
147+
raise RuntimeError(
148+
f"Cannot load to/save from a {path.suffix} file because "
149+
"this extension is not registered in the extensions dictionary."
150+
)
151+
152+
60153
class SerializableMixin:
61154
"""Makes a dataclass serializable to and from dictionaries.
62155
@@ -248,17 +341,17 @@ def load_yaml(
248341
"""
249342
return load_yaml(cls, path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs)
250343

251-
def save(self, path: str | Path, dump_fn=None) -> None:
252-
save(self, path=path, dump_fn=dump_fn)
344+
def save(self, path: str | Path, format: FormatExtension | None = None) -> None:
345+
save(self, path=path, format=format)
253346

254-
def _save(self, path: str | Path, dump_fn: DumpFn = json.dump, **kwargs) -> None:
255-
save(self, path=path, dump_fn=partial(dump_fn, **kwargs))
347+
def _save(self, path: str | Path, format: FormatExtension = json_extension, **kwargs) -> None:
348+
save(self, path=path, format=format, **kwargs)
256349

257350
def save_yaml(self, path: str | Path, dump_fn: DumpFn | None = None, **kwargs) -> None:
258-
save_yaml(self, path, dump_fn=dump_fn, **kwargs)
351+
save_yaml(self, path, **kwargs)
259352

260-
def save_json(self, path: str | Path, dump_fn=json.dump, **kwargs) -> None:
261-
save_json(self, path, dump_fn=dump_fn, **kwargs)
353+
def save_json(self, path: str | Path, **kwargs) -> None:
354+
save_json(self, path, **kwargs)
262355

263356
@classmethod
264357
def loads(
@@ -484,51 +577,6 @@ def loads_yaml(
484577
return loads(cls, s, drop_extra_fields=drop_extra_fields, load_fn=partial(load_fn, **kwargs))
485578

486579

487-
extensions_to_loading_fn: dict[str, Callable[[IO], Any]] = {
488-
".json": json.load,
489-
".pkl": pickle.load,
490-
}
491-
extensions_to_read_mode: dict[str, str] = {".pkl": "rb"}
492-
493-
extensions_to_write_mode: dict[str, str] = {".pkl": "wb"}
494-
extensions_to_dump_fn: dict[str, Callable[[Any, IO], None]] = {
495-
".json": json.dump,
496-
".pkl": pickle.dump,
497-
}
498-
try:
499-
import yaml
500-
501-
extensions_to_loading_fn[".yaml"] = yaml.safe_load
502-
extensions_to_loading_fn[".yml"] = yaml.safe_load
503-
extensions_to_dump_fn[".yaml"] = yaml.dump
504-
extensions_to_dump_fn[".yml"] = yaml.dump
505-
506-
507-
except ImportError:
508-
pass
509-
510-
try:
511-
import numpy # type: ignore
512-
513-
extensions_to_loading_fn[".npy"] = numpy.load
514-
extensions_to_dump_fn[".npy"] = numpy.save
515-
extensions_to_read_mode[".npy"] = "rb"
516-
extensions_to_write_mode[".npy"] = "wb"
517-
518-
except ImportError:
519-
pass
520-
521-
try:
522-
import torch # type: ignore
523-
524-
extensions_to_loading_fn[".pth"] = torch.load
525-
extensions_to_dump_fn[".pth"] = torch.save
526-
extensions_to_read_mode[".pth"] = "rb"
527-
extensions_to_write_mode[".pth"] = "wb"
528-
except ImportError:
529-
pass
530-
531-
532580
def read_file(path: str | Path) -> dict:
533581
"""Returns the contents of the given file as a dictionary.
534582
Uses the right function depending on `path.suffix`:
@@ -540,66 +588,33 @@ def read_file(path: str | Path) -> dict:
540588
".pkl": pickle.load,
541589
}
542590
"""
543-
path = Path(path)
544-
if path.suffix in extensions_to_loading_fn:
545-
load_fn = extensions_to_loading_fn[path.suffix]
546-
else:
547-
raise RuntimeError(
548-
f"Unable to determine what function to use in order to load "
549-
f"path {path} into a dictionary since the path's extension isn't registered in the "
550-
f"`extensions_to_loading_fn` dictionary..."
551-
)
552-
mode = extensions_to_read_mode.get(path.suffix, "r")
553-
with open(path, mode=mode) as f:
554-
return load_fn(f)
591+
format = get_extension(path)
592+
with open(path, mode="rb" if format.binary else "r") as f:
593+
return format.load(f)
555594

556595

557596
def save(
558597
obj: Any,
559598
path: str | Path,
560-
dump_fn: Callable[[dict, IO], None] | None = None,
599+
format: FormatExtension | None = None,
561600
save_dc_types: bool = False,
601+
**kwargs,
562602
) -> None:
563-
"""Save the given dataclass or dictionary to the given file.
564-
565-
Note: The `encode` function is applied to all the object fields to get serializable values,
566-
like so:
567-
- obj -> encode -> "raw" values (dicts, strings, ints, etc) -> `dump_fn` ([json/yaml/etc].dumps) -> string
568-
"""
569-
path = Path(path)
570-
603+
"""Save the given dataclass or dictionary to the given file."""
571604
if not isinstance(obj, dict):
572605
obj = to_dict(obj, save_dc_types=save_dc_types)
606+
if format is None:
607+
format = get_extension(path)
608+
with open(path, mode="wb" if format.binary else "w") as f:
609+
return format.dump(obj, f, **kwargs)
573610

574-
if dump_fn:
575-
save_fn = dump_fn
576-
elif path.suffix in extensions_to_dump_fn:
577-
save_fn = extensions_to_dump_fn[path.suffix]
578-
else:
579-
raise RuntimeError(
580-
f"Unable to determine what function to use in order to save obj {obj} to path {path},"
581-
f"since the path's extension isn't registered in the "
582-
f"`extensions_to_dump_fn` dictionary..."
583-
)
584-
mode = extensions_to_write_mode.get(path.suffix, "w")
585-
with open(path, mode=mode) as f:
586-
return save_fn(obj, f)
587611

612+
def save_yaml(obj, path: str | Path, **kwargs) -> None:
613+
save(obj, path, format=yaml_extension, **kwargs)
588614

589-
def save_yaml(
590-
obj, path: str | Path, dump_fn: DumpFn | None = None, save_dc_types: bool = False, **kwargs
591-
) -> None:
592-
import yaml
593615

594-
if dump_fn is None:
595-
dump_fn = yaml.dump
596-
save(obj, path, dump_fn=partial(dump_fn, **kwargs), save_dc_types=save_dc_types)
597-
598-
599-
def save_json(
600-
obj, path: str | Path, dump_fn: DumpFn = json.dump, save_dc_types: bool = False, **kwargs
601-
) -> None:
602-
save(obj, path, dump_fn=partial(dump_fn, **kwargs), save_dc_types=save_dc_types)
616+
def save_json(obj, path: str | Path, **kwargs) -> None:
617+
save(obj, path, format=json_extension, **kwargs)
603618

604619

605620
def load_yaml(

test/helpers/test_save.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from pathlib import Path
2+
3+
from ..nesting.example_use_cases import HyperParameters
4+
5+
6+
def test_save_yaml(tmpdir: Path):
7+
hparams = HyperParameters.setup("")
8+
tmp_path = Path(tmpdir / "temp.yml")
9+
hparams.save_yaml(tmp_path)
10+
11+
_hparams = HyperParameters.load_yaml(tmp_path)
12+
assert hparams == _hparams
13+
14+
15+
def test_save_json(tmpdir: Path):
16+
hparams = HyperParameters.setup("")
17+
tmp_path = Path(tmpdir / "temp.json")
18+
hparams.save_yaml(tmp_path)
19+
_hparams = HyperParameters.load_yaml(tmp_path)
20+
assert hparams == _hparams
21+
22+
23+
def test_save_yml(tmpdir: Path):
24+
hparams = HyperParameters.setup("")
25+
tmp_path = Path(tmpdir / "temp.yml")
26+
hparams.save(tmp_path)
27+
28+
_hparams = HyperParameters.load(tmp_path)
29+
assert hparams == _hparams
30+
31+
32+
def test_save_pickle(tmpdir: Path):
33+
hparams = HyperParameters.setup("")
34+
tmp_path = Path(tmpdir / "temp.pkl")
35+
hparams.save(tmp_path)
36+
37+
_hparams = HyperParameters.load(tmp_path)
38+
assert hparams == _hparams
39+
40+
41+
def test_save_numpy(tmpdir: Path):
42+
hparams = HyperParameters.setup("")
43+
tmp_path = Path(tmpdir / "temp.npy")
44+
hparams.save(tmp_path)
45+
46+
_hparams = HyperParameters.load(tmp_path)
47+
assert hparams == _hparams
48+
49+
50+
def test_save_torch(tmpdir: Path):
51+
hparams = HyperParameters.setup("")
52+
tmp_path = Path(tmpdir / "temp.pth")
53+
hparams.save(tmp_path)
54+
55+
_hparams = HyperParameters.load(tmp_path)
56+
assert hparams == _hparams

test/utils/test_yaml.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
""" Tests for serialization to/from yaml files. """
22
import textwrap
33
from dataclasses import dataclass
4-
from pathlib import Path
54
from typing import List
65

76
from simple_parsing import list_field
87
from simple_parsing.helpers.serialization import YamlSerializable
98

10-
from ..nesting.example_use_cases import HyperParameters
11-
129

1310
@dataclass
1411
class Point(YamlSerializable):
@@ -65,32 +62,6 @@ def test_dumps_loads():
6562
)
6663

6764

68-
def test_save_yaml(tmpdir: Path):
69-
hparams = HyperParameters.setup("")
70-
tmp_path = Path(tmpdir / "temp.yml")
71-
hparams.save_yaml(tmp_path)
72-
73-
_hparams = HyperParameters.load_yaml(tmp_path)
74-
assert hparams == _hparams
75-
76-
77-
def test_save_json(tmpdir: Path):
78-
hparams = HyperParameters.setup("")
79-
tmp_path = Path(tmpdir / "temp.json")
80-
hparams.save_yaml(tmp_path)
81-
_hparams = HyperParameters.load_yaml(tmp_path)
82-
assert hparams == _hparams
83-
84-
85-
def test_save_yml(tmpdir: Path):
86-
hparams = HyperParameters.setup("")
87-
tmp_path = Path(tmpdir / "temp.yml")
88-
hparams.save(tmp_path)
89-
90-
_hparams = HyperParameters.load(tmp_path)
91-
assert hparams == _hparams
92-
93-
9465
# def test_save_yml(HyperParameters, tmpdir: Path):
9566
# hparams = HyperParameters.setup("")
9667
# tmp_path = Path(tmpdir / "temp.pth")

0 commit comments

Comments
 (0)