1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.8.9
+++++

* :pr:`382`: make the ordering of the inferred dynamic shapes more robust
* :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
* :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``,
44 changes: 28 additions & 16 deletions _doc/examples/plot_export_tiny_llm_method_generate.py
@@ -88,25 +88,37 @@ def generate_text(
# these ones are filled with default values we don't want in
# the onnx model
skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
# dynamic shapes can be inferred from at least two calls to the forward method,
# 3 is better for LLMs, you can see the inference results with ``verbose=1``,
# this parameter is used to overwrite the inferred values,
# this is usually needed because the inferred dynamic shapes contain
# fewer dynamic dimensions than requested.
dynamic_shapes={
"cache_position": {0: "total_sequence_length"},
"past_key_values": [
{0: "batch_size", 2: "past_sequence_length"},
{0: "batch_size", 2: "past_sequence_length"},
],
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
},
# The input used in the example has a batch size equal to 1, so all
# inputs going through the forward method have the same batch size.
# To force this dimension to be dynamic, we need to indicate
# which inputs have a batch dimension (see the sketch after this call).
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
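
# %%
# As an illustrative sketch (not part of the original script), this is the
# effect of *dynamic_batch_for* on the guessed shapes: a dimension that never
# varies across the recorded calls stays static, and the parameter then forces
# dimension 0 of the listed inputs to become a named ``batch`` dimension.
#
# .. code-block:: python
#
#     # before: the batch stays static because every recorded call used batch=1
#     {"input_ids": {1: DYN}, "attention_mask": {1: DYN}}
#     # after: a "batch" name is inserted at dimension 0
#     {"input_ids": {0: "batch", 1: DYN}, "attention_mask": {0: "batch", 1: DYN}}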

# %%
# The lambda function cannot be skipped as
# forward_replacement is a module.
# dynamic shapes can be inferred from at least two calls to the forward method,
# 3 is better for LLMs (the first call is the prefill step, the cache is missing),
# you can see the inference results with ``verbose=1``.
# If the values are not the expected ones (to change the dimension names,
# for example), they can be overwritten.
#
# .. code-block:: python
#
# dynamic_shapes={
# "cache_position": {0: "total_sequence_length"},
# "past_key_values": [
# {0: "batch_size", 2: "past_sequence_length"},
# {0: "batch_size", 2: "past_sequence_length"},
# ],
# "input_ids": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
# }
#
# Finally, we need to replace the forward method.
# As ``forward_replacement`` is a module of type
# :class:`onnx_diagnostic.export.api.WrapperToExportMethodToOnnx`,
# a lambda function must be used to avoid registering it
# as a submodule (which would create an infinite loop).

print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
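
# %%
# Once enough calls have been recorded and the export has run, the stored
# inputs can be replayed through onnxruntime and compared with the stored
# torch outputs. A sketch; the thresholds are the defaults of
# ``check_discrepancies`` added in this PR:
#
# .. code-block:: python
#
#     data = forward_replacement.check_discrepancies(atol=1e-4, rtol=0.1, verbose=1)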
112 changes: 112 additions & 0 deletions _unittests/ut_export/test_dynamic_shapes.py
@@ -7,6 +7,7 @@
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx


class TestDynamicShapes(ExtTestCase):
@@ -995,6 +996,117 @@ def forward(self, cache=None):
expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
self.assertEqual(expected, ds)

def test_dynamic_shape_order(self):
inputs = [
(
tuple(),
dict(
cache_position=torch.arange(8),
input_ids=torch.randint(10, size=(1, 8)),
attention_mask=torch.ones((1, 8), dtype=torch.int64),
),
),
(
tuple(),
dict(
cache_position=torch.arange(1),
input_ids=torch.randint(10, size=(1, 1)),
past_key_values=make_dynamic_cache(
[(torch.rand((1, 1, 8, 96)), torch.rand((1, 1, 8, 96)))]
),
attention_mask=torch.ones((1, 9), dtype=torch.int64),
),
),
(
tuple(),
dict(
cache_position=torch.arange(1),
input_ids=torch.randint(10, size=(1, 1)),
past_key_values=make_dynamic_cache(
[(torch.rand((1, 1, 9, 96)), torch.rand((1, 1, 9, 96)))]
),
attention_mask=torch.ones((1, 10), dtype=torch.int64),
),
),
]
mi = ModelInputs(None, inputs)
ds = mi.guess_dynamic_shapes()
DYN = torch.export.Dim.DYNAMIC
self.assertEqual(
(
(),
{
"attention_mask": {1: DYN},
"past_key_values": [{2: DYN}, {2: DYN}],
"input_ids": {1: DYN},
"cache_position": {0: DYN},
},
),
ds,
)
ordered = list(ds[1])
self.assertEqual(
["cache_position", "input_ids", "past_key_values", "attention_mask"], ordered
)

def test_dynamic_batch_dynamic(self):
inputs = [
(
tuple(),
dict(
cache_position=torch.arange(8),
input_ids=torch.randint(10, size=(1, 8)),
attention_mask=torch.ones((1, 8), dtype=torch.int64),
),
),
(
tuple(),
dict(
cache_position=torch.arange(1),
input_ids=torch.randint(10, size=(1, 1)),
past_key_values=make_dynamic_cache(
[(torch.rand((1, 1, 8, 96)), torch.rand((1, 1, 8, 96)))]
),
attention_mask=torch.ones((1, 9), dtype=torch.int64),
),
),
(
tuple(),
dict(
cache_position=torch.arange(1),
input_ids=torch.randint(10, size=(1, 1)),
past_key_values=make_dynamic_cache(
[(torch.rand((1, 1, 9, 96)), torch.rand((1, 1, 9, 96)))]
),
attention_mask=torch.ones((1, 10), dtype=torch.int64),
),
),
]
mi = ModelInputs(None, inputs)
ds = mi.guess_dynamic_shapes()[1]
DYN = torch.export.Dim.DYNAMIC
self.assertEqual(
{
"cache_position": {0: DYN},
"input_ids": {1: DYN},
"past_key_values": [{2: DYN}, {2: DYN}],
"attention_mask": {1: DYN},
},
ds,
)
ds = WrapperToExportMethodToOnnx._dynamic_batch_dimension(
ds, {"input_ids", "past_key_values", "attention_mask"}
)
self.assertEqual(
{
"cache_position": {0: DYN},
"input_ids": {0: "batch", 1: DYN},
"past_key_values": [{0: "batch", 2: DYN}, {0: "batch", 2: DYN}],
"attention_mask": {0: "batch", 1: DYN},
},
ds,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
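
As a standalone sketch of the rewrite exercised by ``test_dynamic_batch_dynamic``
above, mirroring ``WrapperToExportMethodToOnnx._dynamic_batch_dimension`` in a
simplified form (illustrative, not the library implementation):

```python
from typing import Any, Collection, Dict


def add_dynamic_batch(ds: Dict[str, Any], names: Collection[str]) -> Dict[str, Any]:
    """Forces dimension 0 of the selected inputs to be a named batch axis."""

    def _one(spec: Any) -> Any:
        if isinstance(spec, dict):
            # inserts the named batch axis at position 0, keeps axes sorted
            return dict(sorted({**spec, 0: "batch"}.items()))
        if isinstance(spec, list):
            # a cache is flattened into a list of per-layer axis specs
            return [_one(s) for s in spec]
        raise NotImplementedError(f"unexpected spec {spec!r}")

    return {k: (_one(v) if k in names else v) for k, v in ds.items()}


ds = {"input_ids": {1: "seq"}, "past_key_values": [{2: "past"}, {2: "past"}]}
print(add_dynamic_batch(ds, {"input_ids", "past_key_values"}))
# {'input_ids': {0: 'batch', 1: 'seq'},
#  'past_key_values': [{0: 'batch', 2: 'past'}, {0: 'batch', 2: 'past'}]}
```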
78 changes: 78 additions & 0 deletions onnx_diagnostic/export/api.py
@@ -11,6 +11,7 @@
from ..helpers.cache_helper import CacheKeyValue
from ..helpers.torch_helper import torch_deepcopy
from ..helpers.rt_helper import make_feeds
from ..helpers.onnx_helper import pretty_onnx
from ..reference import OnnxruntimeEvaluator


@@ -349,6 +350,7 @@ def __init__(
patch_kwargs: Optional[Dict[str, Any]] = None,
skip_kwargs_names: Optional[Set[str]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
):
super().__init__()
@@ -359,6 +361,7 @@
if method_name == "forward"
else getattr(mod, method_name)
)
self._signature = inspect.signature(self._method_call)
self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
self._outputs: List[Any] = []
self._convert_after_n_calls = convert_after_n_calls
@@ -368,6 +371,7 @@
self.skip_kwargs_names = skip_kwargs_names
self.dynamic_shapes = dynamic_shapes
self.expand_batch_for = expand_batch_for
self.dynamic_batch_for = dynamic_batch_for
self._to_onnx_kwargs = dict(
input_names=input_names,
target_opset=target_opset,
@@ -417,25 +421,40 @@ def _collect_classes(self, obj):
self._collect_classes(v)
return

def _reorder_kwargs(self, kwargs):
new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
for k, v in kwargs.items():
if k not in new_kwargs:
new_kwargs[k] = v
return new_kwargs

def forward(self, *args, **kwargs):
if not self._export_done:
inp_args = args
# filters out the unwanted inputs
inp_kwargs = (
kwargs
if not kwargs or not self.skip_kwargs_names
else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
)
if self.expand_batch_for:
# extends the inputs to artificially create a batch dimension != 1.
inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
# reorders the parameters to follow the method signature.
inp_kwargs = self._reorder_kwargs(inp_kwargs)
# stores the inputs
self._inputs.append((inp_args, inp_kwargs))

if self.verbose:
print(
f"[method_to_onnx] input[{len(self._inputs)-1}]: "
f"{string_type(self._inputs[-1], with_shape=True)}"
)

if len(self._inputs) >= self._convert_after_n_calls:
# conversion starts after _convert_after_n_calls calls to the forward method
name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
input_file = f"{name}.inputs.pt"
self._input_file = input_file
@@ -447,11 +466,13 @@ def forward(self, *args, **kwargs):
self._convert_method_to_onnx()
self._export_done = True

# calls the inner method (no change here)
begin = time.perf_counter()
res = self._method_call(*args, **kwargs)
duration = time.perf_counter() - begin
self._collect_classes([args, kwargs, res])
if self._inputs:
# stores the outputs if discrepancies need to be checked
self._outputs.append((torch_deepcopy(res), duration))
assert len(self._inputs) == len(self._outputs), (
f"Number of inputs {len(self._inputs)} and "
@@ -514,6 +535,14 @@ def __init__(self, parent):
if self.verbose:
print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
if self.dynamic_batch_for:
nds = (
self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
)
if self.verbose:
print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
print(f"[method_to_onnx] dynamic_shapes with batch={nds}")
else:
a, kw = self._inputs[-1]
nds = [self.dynamic_shapes]
@@ -659,6 +688,31 @@ def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
flat = cls._expand_batch_dimension_input(flat, msg)
return torch.utils._pytree.tree_unflatten(flat, _spec)

@classmethod
def _dynamic_batch_dimension(
cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]]
) -> Union[Tuple[Any, ...], Dict[str, Any]]:
if isinstance(ds, tuple):
return tuple(
(v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i))
for i, v in enumerate(ds)
)
return {
k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k))
for k, v in ds.items()
}

@classmethod
def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any:
if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
ds[0] = "batch"
return {k: v for k, v in sorted(ds.items())} # noqa: C416
if isinstance(ds, list):
return [
cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds)
]
raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}")

def check_discrepancies(
self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
) -> List[Dict[str, Union[str, int, float]]]:
@@ -707,6 +761,10 @@ def check_discrepancies(
input_names = sess.input_names
if verbose:
print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
print(
f"[method_to_onnx.check_discrepancies] onnx_shapes="
f"{', '.join(pretty_onnx(i) for i in sess.input_types)}"
)
data = []
for i, (input, (output, latency)) in enumerate(
zip(self.add_empty_cache_if_needed(inputs), outputs)
@@ -728,6 +786,15 @@
)

flat_inputs = flatten_object(input, drop_keys=True)
if verbose > 1:
print(
f"[method_to_onnx.check_discrepancies] "
f"input={string_type(input, with_shape=True)}"
)
print(
f"[method_to_onnx.check_discrepancies] "
f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
)
if len(flat_inputs) < len(input_names):
# not implemented yet; this is caused by a missing cache,
# which must be replaced by an empty cache instead
@@ -738,6 +805,11 @@
f"{len(flat_inputs)} flat torch inputs"
)
feeds = make_feeds(input_names, flat_inputs)
if verbose > 1:
print(
f"[method_to_onnx.check_discrepancies] "
f"feeds={string_type(feeds, with_shape=True)}"
)
begin = time.perf_counter()
ort_outputs = sess.run(None, feeds)
duration = time.perf_counter() - begin
@@ -792,6 +864,7 @@ def method_to_onnx(
patch_kwargs: Optional[Dict[str, Any]] = None,
skip_kwargs_names: Optional[Set[str]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
) -> Callable:
"""
@@ -822,6 +895,10 @@
:param skip_kwargs_names: use default values for these parameters part of
the signature of the method to export
:param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
:param dynamic_batch_for: LLMs are usually called with a batch size equal to 1,
but the export may benefit from a dynamic batch size;
this parameter forces the inputs specified in this set to have
a dynamic first dimension
:param expand_batch_for: LLMs are usually called with a batch size equal to 1,
but the export may benefit from another value for the batch size;
this parameter forces the inputs specified in this set to be expanded
@@ -852,6 +929,7 @@
patch_kwargs=patch_kwargs,
skip_kwargs_names=skip_kwargs_names,
dynamic_shapes=dynamic_shapes,
dynamic_batch_for=dynamic_batch_for,
expand_batch_for=expand_batch_for,
)
return wrapped_model
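
A minimal usage sketch of the new parameter, following the example script
earlier in this PR; the model is a placeholder, and the exact signature of
``method_to_onnx`` (positional model argument, ``filename`` keyword) is
assumed from the surrounding diff:

```python
import torch
from onnx_diagnostic.export.api import method_to_onnx

model = ...  # any torch.nn.Module whose forward method should be captured

forward_replacement = method_to_onnx(
    model,
    filename="model.onnx",  # hypothetical output path
    skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
    # forces a dynamic batch dimension even though every call uses batch=1
    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
# the lambda avoids registering the wrapper as a submodule of the model
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)

# ... run generation enough times to trigger the export, then:
# forward_replacement.check_discrepancies(atol=1e-4, rtol=0.1, verbose=1)
```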