Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.9
+++++

* :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
* :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``,

Expand Down
78 changes: 78 additions & 0 deletions _unittests/ut_export/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,84 @@ def test_add_empty_cache_if_needed_args_kwargs(self):
self.string_type(with_empty[0][1], with_shape=True),
)

@requires_experimental_experiment("0.1")
def test_method_to_onnx_expand_batch_dimension(self):
class Model(torch.nn.Module):
def forward(self, x=None, y=None):
return x + y

filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
inputs = [
dict(x=torch.randn((1, 5, 6)), y=torch.randn((1, 1, 6))),
dict(x=torch.randn((1, 7, 7)), y=torch.randn((1, 1, 7))),
]
model = Model()
method_to_call = method_to_onnx(
model, exporter="custom", filename=filename, expand_batch_for={"x", "y"}
)
expecteds = []
for kwargs in inputs:
expecteds.append(method_to_call(**kwargs))
self.assertExists(filename)
sess = self.check_ort(filename)
input_names = [i.name for i in sess.get_inputs()]
self.assertEqual(["x", "y"], input_names)
input_shapes = [i.shape for i in sess.get_inputs()]
self.assertEqual([[2, "channel", "D0"], [2, 1, "D0_1"]], input_shapes)
self.clean_dump()

@requires_experimental_experiment("0.1")
@requires_transformers("4.57")
def test_method_to_onnx_expand_batch_dimension_dynamic_cache(self):
class Model(torch.nn.Module):
def forward(self, x=None, cache=None):
return x + cache.layers[0].keys

filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
inputs = [
dict(
x=torch.randn((1, 1, 3, 4)),
cache=make_dynamic_cache(
[(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
),
),
dict(
x=torch.randn((1, 1, 5, 4)),
cache=make_dynamic_cache(
[(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4)) for i in range(2)]
),
),
]
model = Model()
method_to_call = method_to_onnx(
model,
exporter="custom",
filename=filename,
expand_batch_for={"x", "cache"},
patch_kwargs=dict(patch_transformers=True),
)
expecteds = []
for kwargs in inputs:
expecteds.append(method_to_call(**kwargs))
self.assertExists(filename)
sess = self.check_ort(filename)
input_names = [i.name for i in sess.get_inputs()]
self.assertEqual(
["x", "cache_key_0", "cache_value_0", "cache_key_1", "cache_value_1"], input_names
)
input_shapes = [i.shape for i in sess.get_inputs()]
self.assertEqual(
[
[2, 1, "D0", 4],
[2, 2, "D0_1", 4],
[2, 2, "D0_2", 4],
[2, 2, "D0_3", 4],
[2, 2, "D0_4", 4],
],
input_shapes,
)
self.clean_dump()


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 4 additions & 3 deletions onnx_diagnostic/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
raise_exception = False
output = ""
while True:
output = p.stdout.readline().decode(errors="ignore")
output = p.stdout.readline().decode(errors="ignore") # type: ignore[union-attr]
if output == "" and p.poll() is not None:
break
if output:
Expand All @@ -191,8 +191,9 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
):
raise_exception = True
p.poll()
error = p.stderr.readline().decode(errors="ignore")
p.stdout.close()
error = p.stderr.readline().decode(errors="ignore") # type: ignore[union-attr]
p.stdout.close() # type: ignore[union-attr]
p.stderr.close() # type: ignore[union-attr]
if error and raise_exception:
raise RuntimeError(
f"An error was found in the output. The build is stopped."
Expand Down
81 changes: 63 additions & 18 deletions onnx_diagnostic/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def __init__(
patch_kwargs: Optional[Dict[str, Any]] = None,
skip_kwargs_names: Optional[Set[str]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
):
super().__init__()
self._model_to_call = mod
Expand All @@ -366,6 +367,7 @@ def __init__(
self.verbose = verbose
self.skip_kwargs_names = skip_kwargs_names
self.dynamic_shapes = dynamic_shapes
self.expand_batch_for = expand_batch_for
self._to_onnx_kwargs = dict(
input_names=input_names,
target_opset=target_opset,
Expand All @@ -386,9 +388,11 @@ def __init__(
self._serialization_classes: Set[type] = set()

def __str__(self) -> str:
"usual"
return self.__repr__()

def __repr__(self) -> str:
"usual"
return (
f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
f"{self._method_name})"
Expand All @@ -415,22 +419,17 @@ def _collect_classes(self, obj):

def forward(self, *args, **kwargs):
if not self._export_done:
self._inputs.append(
torch_deepcopy(
(
args,
(
kwargs
if not kwargs or not self.skip_kwargs_names
else {
k: v
for k, v in kwargs.items()
if k not in self.skip_kwargs_names
}
),
)
)
inp_args = args
inp_kwargs = (
kwargs
if not kwargs or not self.skip_kwargs_names
else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
)
if self.expand_batch_for:
inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
self._inputs.append((inp_args, inp_kwargs))
if self.verbose:
print(
f"[method_to_onnx] input[{len(self._inputs)-1}]: "
Expand Down Expand Up @@ -543,7 +542,7 @@ def __init__(self, parent):
)

@classmethod
def make_empty_cache_from_others(cls, examples: List["Cache"]) -> "Cache": # noqa: F821
def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
"""Builds an empty cache based on existing one."""
unique_types = {type(t) for t in examples}
assert (
Expand Down Expand Up @@ -616,9 +615,50 @@ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
input_set_copy[miss] = cls.make_empty_cache_from_others(
[sub[miss] for sub in inputs if miss in sub]
)
new_inputs.append({k: input_set_copy[k] for k in ordered})
new_inputs.append({k: input_set_copy[k] for k in ordered}) # type: ignore[union-attr]
return new_inputs

@classmethod
def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
expand_for_args = {i for i in expand_for if isinstance(i, int)}
expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
if isinstance(obj, tuple):
return tuple(
o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
for i, o in enumerate(obj)
)
assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
return {
k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
for k, v in obj.items()
}

@classmethod
def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
if isinstance(obj, torch.Tensor):
assert obj.shape[0] == 1, (
f"Are you sure to expoand input {msg!r}, "
f"batch size is not 1 and shape={obj.shape}"
)
sizes = [2, *obj.shape[1:]]
return obj.expand(*sizes)
if isinstance(obj, list):
return [
cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
]
if obj.__class__.__name__ == "DynamicCache":
dc = CacheKeyValue(obj)
flat = dc.aslist()
flat = cls._expand_batch_dimension_input(flat, msg)
return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
# This might end up in an infinite loop if no registration is done.
flat, _spec = torch.utils._pytree.tree_flatten(obj)
assert (
not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
), f"class {type(obj)} was is not registered for serialization."
flat = cls._expand_batch_dimension_input(flat, msg)
return torch.utils._pytree.tree_unflatten(flat, _spec)

def check_discrepancies(
self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
) -> List[Dict[str, Union[str, int, float]]]:
Expand All @@ -633,7 +673,6 @@ def check_discrepancies(
:param verbose: verbosity
:return: results, a list of dictionaries, ready to be consumed by a dataframe
"""

assert self._export_done, "The onnx export was not done."
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
assert os.path.exists(
Expand Down Expand Up @@ -753,6 +792,7 @@ def method_to_onnx(
patch_kwargs: Optional[Dict[str, Any]] = None,
skip_kwargs_names: Optional[Set[str]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
) -> Callable:
"""
Exports one method into ONNX for a module into ONNX.
Expand Down Expand Up @@ -782,6 +822,10 @@ def method_to_onnx(
:param skip_kwargs_names: use default values for these parameters part of
the signature of the method to export
:param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
:param expand_batch_for: LLM are usually called with a batch size equal to 1,
but the export may benefit from having another value for the batch size,
this parameter forces the input specified in this set to be expanded
to 2 if the batch size is one
:return: the output of the selected exporter, usually a structure including
an onnx model

Expand All @@ -808,5 +852,6 @@ def method_to_onnx(
patch_kwargs=patch_kwargs,
skip_kwargs_names=skip_kwargs_names,
dynamic_shapes=dynamic_shapes,
expand_batch_for=expand_batch_for,
)
return wrapped_model
Loading