From d1a570b2467945b4280c3b3303a84faef61fb851 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 12 Jan 2026 11:28:20 +0100
Subject: [PATCH 1/2] removed bool, int, float, None as input dummies for the
 exporter in method_to_onnx

---
 CHANGELOGS.rst                                |  1 +
 .../plot_export_tiny_llm_method_generate.py   | 10 +-
 onnx_diagnostic/export/api.py                 | 92 ++++++++++++++++++-
 3 files changed, 92 insertions(+), 11 deletions(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index fc6c53c4..26c9fc34 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.9
 +++++
 
+* :pr:`383`: removed bool, int, float, None as input dummies for the exporter in ``method_to_onnx``
 * :pr:`382`: make the ordering of the inferred dynamic shapes more robust
 * :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
 * :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``

diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index cf0559a8..9de38a26 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -84,10 +84,6 @@ def generate_text(
     # the others are used to infer the dynamic shapes if they are not
     # specified below
     convert_after_n_calls=3,
-    # skips the following inputs even though they are captured,
-    # these ones are filled with default values we don't want in
-    # the onnx model
-    skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
     # The input used in the example has a batch size equal to 1, all
     # inputs going through method forward will have the same batch size.
     # To force the dynamism of this dimension, we need to indicate
@@ -105,20 +101,20 @@
 # .. code-block:: python
 #
 #     dynamic_shapes={
-#         "cache_position": {0: "total_sequence_length"},
+#         "cache_position": {0: "sequence_length"},
 #         "past_key_values": [
 #             {0: "batch_size", 2: "past_sequence_length"},
 #             {0: "batch_size", 2: "past_sequence_length"},
 #         ],
 #         "input_ids": {0: "batch_size", 1: "sequence_length"},
-#         "attention_mask": {0: "batch_size", 1: "sequence_length"},
+#         "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
 #     }
 #
 # Finally, we need to replace the forward method.
 # As ``forward_replacement`` is a module of type
 # :class:`onnx_diagnostic.export.api.WrapperToExportMethodToOnnx`,
 # a lambda function must be used to avoid this one to be
-# included as a submodule (and an infinite loop).
+# included as a submodule (which would create an infinite loop).
 
 print(f"type(forward_replacement)={type(forward_replacement)}")
 model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)

diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index bd0dcd7b..44160bf6 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -431,11 +431,19 @@ def _reorder_kwargs(self, kwargs):
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
-            # filters out the inputs not desired
+            # filters out undesired inputs; int, float, bool, and None values
+            # are treated as constants by the exporter, so they are removed
+            # from the named arguments
             inp_kwargs = (
                 kwargs
-                if not kwargs or not self.skip_kwargs_names
-                else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
+                if not kwargs
+                else {
+                    k: v
+                    for k, v in kwargs.items()
+                    if v is not None
+                    and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
+                    and not isinstance(v, (bool, int, float))
+                }
             )
             if self.expand_batch_for:
                 # extends the inputs to artificially create a batch dimension != 1.
@@ -538,7 +546,10 @@ def __init__(self, parent):
         if self.dynamic_batch_for:
             nds = (
                 self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
-                self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+                self.rename_dynamic_shapes(
+                    self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+                    verbose=self.verbose,
+                ),
             )
             if self.verbose:
                 print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
@@ -842,6 +853,79 @@ check_discrepancies(
             print("[method_to_onnx.check_discrepancies] done")
         return data
 
+    @classmethod
+    def _apply_known_shape_pattern(
+        cls, shape: Dict[int, Any], pattern: Dict[int, str]
+    ) -> Dict[int, Any]:
+        return {k: pattern.get(k, v) for k, v in shape.items()}
+
+    @classmethod
+    def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
+        """
+        Returns the known patterns for the dynamic shapes.
+
+        .. runpython::
+            :showcode:
+
+            import pprint
+            from onnx_diagnostic.export.api import WrappertoExportMethodToOnnx
+            pprint.pprint(WrappertoExportMethodToOnnx.get_dynamic_shape_patterns())
+        """
+        return {
+            "LLM.text": {
+                "cache_position": {0: "seqlength"},
+                "past_key_values": {0: "batch", 2: "pastlength"},
+                "input_ids": {0: "batch", 1: "seqlength"},
+                "attention_mask": {0: "batch", 1: "totallength"},  # pastlength+seqlength
+            }
+        }
+
+    @classmethod
+    def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
+        """
+        Renames the dimensions in the dynamic shapes.
+        Tries to rename any dynamic dimension
+        before export. It is not very clever; it just tries
+        to recognize a known configuration based on input names.
+        Dimension names in dynamic shapes are renamed if *ds*
+        contains all the named arguments of one of the patterns
+        returned by :meth:`get_dynamic_shape_patterns
+        <onnx_diagnostic.export.api.WrappertoExportMethodToOnnx.get_dynamic_shape_patterns>`.
+        """
+        is_shape = lambda s: isinstance(s, dict) and all(  # noqa: E731
+            isinstance(_, int) for _ in s
+        )
+        llm_patterns = cls.get_dynamic_shape_patterns()
+        for pattern_name, pattern_shape in llm_patterns.items():
+            if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
+                if verbose:
+                    print(
+                        f"[method_to_onnx.rename_dynamic_shapes] "
+                        f"apply pattern shapes {pattern_name!r}"
+                    )
+                new_ds = {}
+                for k, v in ds.items():
+                    if k not in pattern_shape:
+                        new_ds[k] = v
+                        continue
+                    if is_shape(v):
+                        # A shape
+                        new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
+                    elif isinstance(v, list):
+                        # A cache
+                        new_ds[k] = [
+                            (
+                                cls._apply_known_shape_pattern(s, pattern_shape[k])
+                                if is_shape(s)
+                                else s
+                            )
+                            for s in v
+                        ]
+                return new_ds
+
+        # unchanged
+        return ds
+
 
 def method_to_onnx(
     mod: "torch.nn.Module",
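
To make the new filtering rule in ``forward`` concrete, the following self-contained
sketch mirrors the comprehension added above; the ``skip_kwargs_names`` value and
the kwargs are invented for the illustration, and nothing from onnx_diagnostic is
imported:

    skip_kwargs_names = {"use_cache"}  # hypothetical skip set
    kwargs = {
        "input_ids": [[1, 2, 3]],   # stands in for a tensor input
        "use_cache": True,          # skipped by name
        "return_dict": True,        # bool -> treated as a constant, dropped
        "top_k": 50,                # int -> dropped
        "temperature": 0.7,         # float -> dropped
        "inputs_embeds": None,      # None -> dropped
    }
    inp_kwargs = (
        kwargs
        if not kwargs
        else {
            k: v
            for k, v in kwargs.items()
            if v is not None
            and (not skip_kwargs_names or k not in skip_kwargs_names)
            and not isinstance(v, (bool, int, float))
        }
    )
    assert set(inp_kwargs) == {"input_ids"}

Only the tensor-like argument survives, which is why the example no longer needs
to spell out ``skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"}``.
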
From 2d2d3a6814320e8a7497b517560ce776a8453cce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 12 Jan 2026 11:35:53 +0100
Subject: [PATCH 2/2] doc

---
 onnx_diagnostic/export/api.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 44160bf6..3e89f13a 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -868,8 +868,8 @@ def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
             :showcode:
 
             import pprint
-            from onnx_diagnostic.export.api import WrappertoExportMethodToOnnx
-            pprint.pprint(WrappertoExportMethodToOnnx.get_dynamic_shape_patterns())
+            from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
+            pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
         """
         return {
             "LLM.text": {
@@ -890,7 +890,7 @@ def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str
         Dimension names in dynamic shapes are renamed if *ds*
         contains all the named arguments of one of the patterns
         returned by :meth:`get_dynamic_shape_patterns
-        <onnx_diagnostic.export.api.WrappertoExportMethodToOnnx.get_dynamic_shape_patterns>`.
+        <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
         """
         is_shape = lambda s: isinstance(s, dict) and all(  # noqa: E731
             isinstance(_, int) for _ in s
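
To illustrate the pattern matching introduced in PATCH 1/2, here is a
self-contained sketch of what ``rename_dynamic_shapes`` computes for the
``LLM.text`` pattern; ``apply_pattern`` mirrors ``_apply_known_shape_pattern``,
and the dimension names ``d0``..``d3`` are invented placeholders for the
automatically inferred names:

    from typing import Any, Dict

    # the "LLM.text" pattern returned by get_dynamic_shape_patterns
    pattern = {
        "cache_position": {0: "seqlength"},
        "past_key_values": {0: "batch", 2: "pastlength"},
        "input_ids": {0: "batch", 1: "seqlength"},
        "attention_mask": {0: "batch", 1: "totallength"},
    }

    def apply_pattern(shape: Dict[int, Any], pat: Dict[int, str]) -> Dict[int, Any]:
        # keeps the axes of the shape, overwrites the names the pattern knows about
        return {k: pat.get(k, v) for k, v in shape.items()}

    ds = {
        "cache_position": {0: "d0"},
        "past_key_values": [{0: "d1", 2: "d2"}, {0: "d1", 2: "d2"}],
        "input_ids": {0: "d1", 1: "d0"},
        "attention_mask": {0: "d1", 1: "d3"},
    }
    # the pattern applies because ds contains all of its named arguments
    assert set(pattern) <= set(ds)
    renamed = {
        k: (
            [apply_pattern(s, pattern[k]) for s in v]  # a cache: list of shapes
            if isinstance(v, list)
            else apply_pattern(v, pattern[k])          # a plain shape
        )
        for k, v in ds.items()
    }
    print(renamed["attention_mask"])  # {0: 'batch', 1: 'totallength'}

If no pattern matches, ``rename_dynamic_shapes`` returns *ds* unchanged, so the
renaming is a best-effort cosmetic step before export.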