1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.8.9
+++++

* :pr:`383`: removed bool, int, float, and None values from the dummy inputs given to the exporter in ``method_to_onnx``
* :pr:`382`: make the ordering of the inferred dynamic shapes more robust
* :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
* :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
10 changes: 3 additions & 7 deletions _doc/examples/plot_export_tiny_llm_method_generate.py
@@ -84,10 +84,6 @@ def generate_text(
# the others are used to infer the dynamic shapes if they are not
# specified below
convert_after_n_calls=3,
# skips the following inputs even though they are captured,
# these ones are filled with default values we don't want in
# the onnx model
skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
# The input used in the example has a batch size equal to 1; all
# inputs going through method forward will have the same batch size.
# To force the dynamism of this dimension, we need to indicate
@@ -105,20 +101,20 @@
# .. code-block:: python
#
# dynamic_shapes={
# "cache_position": {0: "total_sequence_length"},
# "cache_position": {0: "sequence_length"},
# "past_key_values": [
# {0: "batch_size", 2: "past_sequence_length"},
# {0: "batch_size", 2: "past_sequence_length"},
# ],
# "input_ids": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
# }
#
# Finally, we need to replace the forward method.
# As ``forward_replacement`` is a module of type
# :class:`onnx_diagnostic.export.api.WrapperToExportMethodToOnnx`,
# a lambda function must be used to prevent it from being
# included as a submodule (and an infinite loop).
# included as a submodule (and creating an infinite loop).

print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
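For context, a minimal sketch of the pitfall the lambda avoids (``ForwardWrapper`` below is a hypothetical stand-in for ``WrapperToExportMethodToOnnx``, not the library class): assigning an ``nn.Module`` to an attribute registers it as a submodule, and a wrapper that refers back to the model would create a cycle.

import torch


class ForwardWrapper(torch.nn.Module):
    # hypothetical stand-in: keeps the original forward and delegates to it
    def __init__(self, original_forward):
        super().__init__()
        self._original_forward = original_forward

    def forward(self, *args, **kwargs):
        return self._original_forward(*args, **kwargs)


model = torch.nn.Linear(4, 4)
wrapper = ForwardWrapper(model.forward)

# ``model.forward = wrapper`` would go through nn.Module.__setattr__,
# which registers any nn.Module value as a submodule; a wrapper holding
# a reference back to the model then turns the module tree into a cycle.
# A lambda is not an nn.Module, so nothing is registered:
model.forward = lambda *args, **kwargs: wrapper(*args, **kwargs)
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])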
92 changes: 88 additions & 4 deletions onnx_diagnostic/export/api.py
@@ -431,11 +431,19 @@ def _reorder_kwargs(self, kwargs):
def forward(self, *args, **kwargs):
if not self._export_done:
inp_args = args
# filters out the inputs not desired
# filters out the undesired inputs; int, float, bool, and None
# values are treated as constants by the exporter, so they are
# removed from the named arguments.
inp_kwargs = (
kwargs
if not kwargs or not self.skip_kwargs_names
else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
if not kwargs
else {
k: v
for k, v in kwargs.items()
if v is not None
and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
and not isinstance(v, (bool, int, float))
}
)
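# e.g. forward(input_ids=ids, inputs_embeds=None, use_cache=True)
# keeps only {"input_ids": ids}: the None value, the bool, and any
# name listed in skip_kwargs_names are dropped before export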
if self.expand_batch_for:
# extends the inputs to artificially create a batch dimension != 1.
@@ -538,7 +546,10 @@ def __init__(self, parent):
if self.dynamic_batch_for:
nds = (
self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
self.rename_dynamic_shapes(
self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
verbose=self.verbose,
),
)
if self.verbose:
print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
@@ -842,6 +853,79 @@ def check_discrepancies(
print("[method_to_onnx.check_discrepancies] done")
return data

@classmethod
def _apply_known_shape_pattern(
cls, shape: Dict[int, Any], pattern: Dict[int, str]
) -> Dict[int, Any]:
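# e.g. shape={0: "d0", 2: "d2"} and pattern={0: "batch"}
# gives {0: "batch", 2: "d2"}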
return {k: pattern.get(k, v) for k, v in shape.items()}

@classmethod
def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
"""
Returns the known patterns for the dynamic shapes.

.. runpython::
:showcode:

import pprint
from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
"""
return {
"LLM.text": {
"cache_position": {0: "seqlength"},
"past_key_values": {0: "batch", 2: "pastlength"},
"input_ids": {0: "batch", 1: "seqlength"},
"attention_mask": {0: "batch", 1: "totallength"}, # pastlength+seqlength
}
}

@classmethod
def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
"""
Renames the dimensions in the dynamic shapes with meaningful names.
Tries to rename every dynamic dimension before export.
It is not very clever: it just tries to recognize a known
configuration based on the input names. The dimension names are
renamed if *ds* contains all the named arguments of one of the
patterns returned by :meth:`get_dynamic_shape_patterns
<onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
"""
is_shape = lambda s: isinstance(s, dict) and all( # noqa: E731
isinstance(_, int) for _ in s
)
llm_patterns = cls.get_dynamic_shape_patterns()
for pattern_name, pattern_shape in llm_patterns.items():
if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
if verbose:
print(
f"[method_to_onnx.rename_dynamic_shapes] "
f"apply pattern shapes {pattern_name!r}"
)
new_ds = {}
for k, v in ds.items():
if k not in pattern_shape:
new_ds[k] = v
continue
if is_shape(v):
# A shape
new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
elif isinstance(v, list):
# A cache
new_ds[k] = [
(
cls._apply_known_shape_pattern(s, pattern_shape[k])
if is_shape(s)
else s
)
for s in v
]
return new_ds

# no known pattern matched: return the dynamic shapes unchanged
return ds


def method_to_onnx(
mod: "torch.nn.Module",
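A minimal usage sketch of the new renaming logic (the input dictionary and its generic ``dim_*`` names are made up for illustration):

from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx

# dynamic shapes as automatic inference may produce them, with
# generic dimension names (hypothetical values)
ds = {
    "cache_position": {0: "dim_0"},
    "past_key_values": [
        {0: "dim_1", 2: "dim_2"},
        {0: "dim_1", 2: "dim_2"},
    ],
    "input_ids": {0: "dim_1", 1: "dim_0"},
    "attention_mask": {0: "dim_1", 1: "dim_3"},
}

renamed = WrapperToExportMethodToOnnx.rename_dynamic_shapes(ds, verbose=1)
# all four names of the "LLM.text" pattern are present, so the
# dimensions receive the pattern names "batch", "seqlength",
# "pastlength" and "totallength"
print(renamed)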