From a935315b1eb468027b51178490866a1982b2aa7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 17:17:41 +0100 Subject: [PATCH 1/8] simplifies an example --- .../plot_export_tiny_llm_method_generate.py | 20 +++---- _unittests/ut_export/test_dynamic_shapes.py | 53 +++++++++++++++++++ onnx_diagnostic/export/api.py | 32 +++++++++++ onnx_diagnostic/export/dynamic_shapes.py | 25 +++++++-- onnx_diagnostic/helpers/helper.py | 11 ++-- onnx_diagnostic/helpers/rt_helper.py | 8 ++- 6 files changed, 128 insertions(+), 21 deletions(-) diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py index a4e3d841..ef41deac 100644 --- a/_doc/examples/plot_export_tiny_llm_method_generate.py +++ b/_doc/examples/plot_export_tiny_llm_method_generate.py @@ -93,15 +93,15 @@ def generate_text( # this parameter is used to overwrite the inferred values, # this is usually needed because the inferred dynamic shapes contains # less dynamic dimension than requested. - dynamic_shapes={ - "cache_position": {0: "total_sequence_length"}, - "past_key_values": [ - {0: "batch_size", 2: "past_sequence_length"}, - {0: "batch_size", 2: "past_sequence_length"}, - ], - "input_ids": {0: "batch_size", 1: "sequence_length"}, - "attention_mask": {0: "batch_size", 1: "sequence_length"}, - }, + # dynamic_shapes={ + # "cache_position": {0: "total_sequence_length"}, + # "past_key_values": [ + # {0: "batch_size", 2: "past_sequence_length"}, + # {0: "batch_size", 2: "past_sequence_length"}, + # ], + # "input_ids": {0: "batch_size", 1: "sequence_length"}, + # "attention_mask": {0: "batch_size", 1: "sequence_length"}, + # }, ) # %% @@ -127,7 +127,7 @@ def generate_text( # It is done after because the model may not hold twice in memory # (torch and onnxruntime). # verbose=2 shows more information about expected outputs. 
-data = forward_replacement.check_discrepancies(verbose=1)
+data = forward_replacement.check_discrepancies(verbose=2)
 df = pandas.DataFrame(data)
 print(df)
 
diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py
index 9381da0c..5765d5cc 100644
--- a/_unittests/ut_export/test_dynamic_shapes.py
+++ b/_unittests/ut_export/test_dynamic_shapes.py
@@ -995,6 +995,59 @@ def forward(self, cache=None):
         expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
         self.assertEqual(expected, ds)
 
+    def test_dynamic_shape_order(self):
+        inputs = [
+            (
+                tuple(),
+                dict(
+                    cache_position=torch.arange(8),
+                    input_ids=torch.randint(10, size=(1, 8)),
+                    attention_mask=torch.ones((1, 8), dtype=torch.int64),
+                ),
+            ),
+            (
+                tuple(),
+                dict(
+                    cache_position=torch.arange(1),
+                    input_ids=torch.randint(10, size=(1, 1)),
+                    past_key_values=make_dynamic_cache(
+                        [(torch.rand((1, 1, 8, 96)), torch.rand((1, 1, 8, 96)))]
+                    ),
+                    attention_mask=torch.ones((1, 9), dtype=torch.int64),
+                ),
+            ),
+            (
+                tuple(),
+                dict(
+                    cache_position=torch.arange(1),
+                    input_ids=torch.randint(10, size=(1, 1)),
+                    past_key_values=make_dynamic_cache(
+                        [(torch.rand((1, 1, 9, 96)), torch.rand((1, 1, 9, 96)))]
+                    ),
+                    attention_mask=torch.ones((1, 10), dtype=torch.int64),
+                ),
+            ),
+        ]
+        mi = ModelInputs(None, inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        self.assertEqual(
+            (
+                (),
+                {
+                    "attention_mask": {1: DYN},
+                    "past_key_values": [{2: DYN}, {2: DYN}],
+                    "input_ids": {1: DYN},
+                    "cache_position": {0: DYN},
+                },
+            ),
+            ds,
+        )
+        ordered = list(ds[1])
+        self.assertEqual(
+            ["cache_position", "input_ids", "past_key_values", "attention_mask"], ordered
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 2c642dfe..3eff1e91 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -359,6 +359,7 @@ def __init__(
             if method_name == "forward"
             else getattr(mod, method_name)
         )
+        self._signature = inspect.signature(self._method_call)
         self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
         self._outputs: List[Any] = []
         self._convert_after_n_calls = convert_after_n_calls
@@ -417,25 +418,40 @@ def _collect_classes(self, obj):
                 self._collect_classes(v)
         return
 
+    def _reorder_kwargs(self, kwargs):
+        new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
+        for k, v in kwargs.items():
+            if k not in new_kwargs:
+                new_kwargs[k] = v
+        return new_kwargs
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
+            # filters out the unwanted inputs
             inp_kwargs = (
                 kwargs
                 if not kwargs or not self.skip_kwargs_names
                 else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
             )
             if self.expand_batch_for:
+                # extends the inputs to artificially create a batch dimension != 1.
                 inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
                 inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
+            # reorders the parameters following the method signature.
+            inp_kwargs = self._reorder_kwargs(inp_kwargs)
+            # stores the inputs
             self._inputs.append((inp_args, inp_kwargs))
+
             if self.verbose:
                 print(
                     f"[method_to_onnx] input[{len(self._inputs)-1}]: "
                     f"{string_type(self._inputs[-1], with_shape=True)}"
                 )
+
             if len(self._inputs) >= self._convert_after_n_calls:
+                # conversion starts after _convert_after_n_calls calls to the forward method
                 name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
                 input_file = f"{name}.inputs.pt"
                 self._input_file = input_file
@@ -447,11 +463,13 @@ def forward(self, *args, **kwargs):
                 self._convert_method_to_onnx()
                 self._export_done = True
 
+        # calls the inner method (no change here)
         begin = time.perf_counter()
         res = self._method_call(*args, **kwargs)
         duration = time.perf_counter() - begin
         self._collect_classes([args, kwargs, res])
         if self._inputs:
+            # stores the outputs if discrepancies need to be checked
             self._outputs.append((torch_deepcopy(res), duration))
             assert len(self._inputs) == len(self._outputs), (
                 f"Number of inputs {len(self._inputs)} and "
@@ -728,6 +746,15 @@ def check_discrepancies(
             )
 
             flat_inputs = flatten_object(input, drop_keys=True)
+            if verbose > 1:
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"input={string_type(input, with_shape=True)}"
+                )
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
+                )
             if len(flat_inputs) < len(input_names):
                 # not implemented yet, it is caused by a missing cache,
                 # which requires an empty cache instead
@@ -738,6 +765,11 @@ def check_discrepancies(
                     f"{len(flat_inputs)} flat torch inputs"
                 )
             feeds = make_feeds(input_names, flat_inputs)
+            if verbose > 1:
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"feeds={string_type(feeds, with_shape=True)}"
+                )
             begin = time.perf_counter()
             ort_outputs = sess.run(None, feeds)
             duration = time.perf_counter() - begin
diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py
index 7eaac732..a8c3aadd 100644
--- a/onnx_diagnostic/export/dynamic_shapes.py
+++ b/onnx_diagnostic/export/dynamic_shapes.py
@@ -967,6 +967,8 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
        """
        Guesses the dynamic shapes for that module from two executions.
        If there is only one execution, then the dimensions would be static.
+        If the model signature is available, the kwargs are reordered following
+        the signature order, otherwise they follow the order of the given inputs.
 
        :param auto: if auto is True, use ``torch.export.Dim.AUTO``
            for any dimension if the number of inputs is one,
@@ -1026,11 +1028,24 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
                     msg=lambda name=name: f"  failing input {name!r}",
                 )
         # reordering
-        if kwargs is not None and self.forward_ordered_parameter_names:
-            kwargs1 = {
-                p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
-            }
-            kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+        if kwargs:
+            if self.forward_ordered_parameter_names:
+                kwargs1 = {
+                    p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+                }
+                kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+            else:
+                # We reorder them the same way the inputs were given.
+ use = None + params = set(kwargs) + for _args, kws in self.inputs: + if set(kws) == params: + use = kws + break + if use: + ordered = list(use) + kwargs = {k: kwargs[k] for k in ordered} + return tuple(args), kwargs def move_to_kwargs( diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 923c92ef..3d0110cb 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1,7 +1,6 @@ import ast import enum import inspect -import itertools import json from dataclasses import is_dataclass, fields from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -991,15 +990,17 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any: if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}: from .cache_helper import CacheKeyValue - kc = CacheKeyValue(x) - return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache))) + return CacheKeyValue(x).aslist() if x.__class__.__name__ == "EncoderDecoderCache": - res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache) + res = [ + *flatten_object(x.self_attention_cache), + *flatten_object(x.cross_attention_cache), + ] return tuple(res) if x.__class__.__name__ == "MambaCache": if isinstance(x.conv_states, list): - res = flatten_object(x.conv_states) + flatten_object(x.ssm_states) + res = [*flatten_object(x.conv_states), *flatten_object(x.ssm_states)] return tuple(res) return (x.conv_states, x.ssm_states) if hasattr(x, "to_tuple"): diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 6d874b89..75e30eba 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -41,7 +41,13 @@ def make_feeds( """ # NOTE: position_ids is a special case because ModelBuilder does not usually use it, # because it's fued into rotary embedding in GQA. - if is_modelbuilder and isinstance(inputs, dict): + if is_modelbuilder and isinstance(inputs, dict) and "positions_ids" in inputs: + position_ids = input["position_ids"] + assert ( + (position_ids == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)) + .max() + .item() + ), f"ModelBuilder does not support position_ids={position_ids}" inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing. flat = flatten_object(inputs, drop_keys=True) From f902ff72e86b12b77c627fa34e64bcd23872fde3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 17:23:33 +0100 Subject: [PATCH 2/8] fix example --- .../plot_export_tiny_llm_method_generate.py | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py index ef41deac..0e6d6e61 100644 --- a/_doc/examples/plot_export_tiny_llm_method_generate.py +++ b/_doc/examples/plot_export_tiny_llm_method_generate.py @@ -88,25 +88,32 @@ def generate_text( # these ones are filled with default values we don't want in # the onnx model skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"}, - # dynamic shapes can be inferred from at least two calls to the forward method, - # 3 is better for LLMs, you can see the inference results with ``verbose=1``, - # this parameter is used to overwrite the inferred values, - # this is usually needed because the inferred dynamic shapes contains - # less dynamic dimension than requested. 
-    # dynamic_shapes={
-    #     "cache_position": {0: "total_sequence_length"},
-    #     "past_key_values": [
-    #         {0: "batch_size", 2: "past_sequence_length"},
-    #         {0: "batch_size", 2: "past_sequence_length"},
-    #     ],
-    #     "input_ids": {0: "batch_size", 1: "sequence_length"},
-    #     "attention_mask": {0: "batch_size", 1: "sequence_length"},
-    # },
 )
 
 # %%
-# The lambda function cannot be skipped as
-# forward_replacement is a module.
+# Dynamic shapes can be inferred from at least two calls to the forward method,
+# 3 is better for LLMs (the first call is prefill, the cache is still missing),
+# you can see the inference results with ``verbose=1``.
+# If an inferred value is not the expected one (for example, to change
+# the dimension names), it can be overwritten:
+#
+# .. code-block:: python
+#
+#     dynamic_shapes={
+#         "cache_position": {0: "total_sequence_length"},
+#         "past_key_values": [
+#             {0: "batch_size", 2: "past_sequence_length"},
+#             {0: "batch_size", 2: "past_sequence_length"},
+#         ],
+#         "input_ids": {0: "batch_size", 1: "sequence_length"},
+#         "attention_mask": {0: "batch_size", 1: "sequence_length"},
+#     }
+#
+# Finally, we need to replace the forward method.
+# As ``forward_replacement`` is a module of type
+# :class:`onnx_diagnostic.export.api.WrapperToExportMethodToOnnx`,
+# a lambda function must be used to avoid registering it as a submodule
+# (which would create an infinite loop).
 print(f"type(forward_replacement)={type(forward_replacement)}")
 model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)

From 291262bf725c4529d2a2f70aae0f691a91c6e5b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sun, 11 Jan 2026 17:25:29 +0100
Subject: [PATCH 3/8] fix position ids

---
 onnx_diagnostic/helpers/rt_helper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py
index 75e30eba..88ddf104 100644
--- a/onnx_diagnostic/helpers/rt_helper.py
+++ b/onnx_diagnostic/helpers/rt_helper.py
@@ -43,7 +43,7 @@ def make_feeds(
     # because it's fued into rotary embedding in GQA.
if is_modelbuilder and isinstance(inputs, dict) and "positions_ids" in inputs: position_ids = input["position_ids"] - assert ( + assert isinstance(position_ids, torch.Tensor) and ( (position_ids == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)) .max() .item() From 6e2e3b188d1b627e2e8efee30cc9608026c284d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 17:26:44 +0100 Subject: [PATCH 4/8] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 7c166329..fc6c53c4 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.9 +++++ +* :pr:`382`: make the ordering of the inferred dynamic shapes more robust * :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx`` * :pr:`378`: implements the computation of discrepancies in ``method_to_onnx`` * :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``, From 9ba87fad670cbafca055004c7bd03bd6f086ad2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 17:38:27 +0100 Subject: [PATCH 5/8] fix --- onnx_diagnostic/export/shape_helper.py | 8 ----- onnx_diagnostic/helpers/rt_helper.py | 2 +- onnx_diagnostic/tasks/text_generation.py | 34 +++++++++---------- .../serialization/transformers_impl.py | 12 +++---- 4 files changed, 23 insertions(+), 33 deletions(-) diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py index 9b96a0e6..5bc45de1 100644 --- a/onnx_diagnostic/export/shape_helper.py +++ b/onnx_diagnostic/export/shape_helper.py @@ -47,7 +47,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: make_dynamic_cache, make_encoder_decoder_cache, make_mamba_cache, - make_sliding_window_cache, make_static_cache, ) from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs @@ -77,13 +76,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: ] ), ), - make_sliding_window_cache( - [ - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - ] - ), make_static_cache( [ (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 88ddf104..75420e98 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -42,7 +42,7 @@ def make_feeds( # NOTE: position_ids is a special case because ModelBuilder does not usually use it, # because it's fued into rotary embedding in GQA. 
if is_modelbuilder and isinstance(inputs, dict) and "positions_ids" in inputs: - position_ids = input["position_ids"] + position_ids = input["position_ids"] # type: ignore[valid-type] assert isinstance(position_ids, torch.Tensor) and ( (position_ids == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)) .max() diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index b9336879..42be9721 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -1,11 +1,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from ..helpers.cache_helper import ( - make_dynamic_cache, - make_mamba_cache, - make_sliding_window_cache, - make_static_cache, -) +from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache from ..helpers.config_helper import ( update_config, check_hasattr, @@ -187,17 +182,22 @@ def get_inputs( if cls_cache is None or isinstance(cls_cache, str) else cls_cache.__name__ ) - make_caches = { - "DynamicCache": make_dynamic_cache, - "SlidingWindowCache": make_sliding_window_cache, - "StaticCache": make_static_cache, - } - assert cache_name is None or cache_name in make_caches, ( - f"Unable to handle cls_cache={cache_name!r}, it should be in " - f"{sorted(make_caches)}" - ) - make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] - is_static = cache_name == "StaticCache" + if cache_name == "DynamicSlidingWindowCache": + from ..helpers.cache_helper import make_sliding_window_cache + + make_cache = make_sliding_window_cache + is_static = False + else: + make_caches = { + "DynamicCache": make_dynamic_cache, + "StaticCache": make_static_cache, + } + assert cache_name is None or cache_name in make_caches, ( + f"Unable to handle cls_cache={cache_name!r}, it should be in " + f"{sorted(make_caches)}" + ) + make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] + is_static = cache_name == "StaticCache" if is_static: # static diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 58fbceee..e08c4a99 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -19,13 +19,7 @@ except ImportError: from transformers.cache_utils import MambaCache from transformers.modeling_outputs import BaseModelOutput -from ...helpers.cache_helper import ( - make_dynamic_cache, - make_hybrid_cache, - make_sliding_window_cache, - make_static_cache, - CacheKeyValue, -) +from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue from . import make_serialization_function_for_dataclass @@ -132,6 +126,8 @@ def unflatten_hybrid_cache( values: List[Any], context: torch.utils._pytree.Context, output_type=None ) -> HybridCache: """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" + from ...helpers.cache_helper import make_hybrid_cache + return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) @@ -204,6 +200,8 @@ def unflatten_sliding_window_cache( Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects. 
""" + from ...helpers.cache_helper import make_sliding_window_cache + return _unflatten_cache( make_sliding_window_cache, values, context, output_type=output_type ) From 1a6f0073491631cba664c0b2f4a68c06ba107f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 12 Jan 2026 01:31:50 +0100 Subject: [PATCH 6/8] mb --- onnx_diagnostic/helpers/rt_helper.py | 15 +++++++++++---- onnx_diagnostic/tasks/text_generation.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 75420e98..7ace788c 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -41,13 +41,20 @@ def make_feeds( """ # NOTE: position_ids is a special case because ModelBuilder does not usually use it, # because it's fued into rotary embedding in GQA. - if is_modelbuilder and isinstance(inputs, dict) and "positions_ids" in inputs: - position_ids = input["position_ids"] # type: ignore[valid-type] + if is_modelbuilder and isinstance(inputs, dict) and "position_ids" in inputs: + position_ids = inputs["position_ids"] # type: ignore[valid-type] + # We just check position_ids are contiguous. assert isinstance(position_ids, torch.Tensor) and ( - (position_ids == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)) + ( + (position_ids - position_ids.min()) + == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0) + ) .max() .item() - ), f"ModelBuilder does not support position_ids={position_ids}" + ), ( + f"ModelBuilder does not support position_ids={position_ids}, " + f"inputs={string_type(inputs, with_shape=True)}" + ) inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing. flat = flatten_object(inputs, drop_keys=True) diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 42be9721..25b4d29c 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -196,7 +196,7 @@ def get_inputs( f"Unable to handle cls_cache={cache_name!r}, it should be in " f"{sorted(make_caches)}" ) - make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] + make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] # type: ignore[assignment] is_static = cache_name == "StaticCache" if is_static: From a5ec30b2092dc3f93d4289b88afb465c8466fdc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 12 Jan 2026 01:47:34 +0100 Subject: [PATCH 7/8] fix --- _unittests/ut_export/test_dynamic_shapes.py | 59 +++++++++++++++++++++ onnx_diagnostic/export/api.py | 30 +++++++++++ 2 files changed, 89 insertions(+) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 5765d5cc..e8b2a178 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -7,6 +7,7 @@ from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs +from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx class TestDynamicShapes(ExtTestCase): @@ -1048,6 +1049,64 @@ def test_dynamic_shape_order(self): ["cache_position", "input_ids", "past_key_values", "attention_mask"], ordered ) + def test_dynamic_batch_dynamic(self): + inputs = [ + ( + tuple(), + dict( + 
cache_position=torch.arange(8), + input_ids=torch.randint(10, size=(1, 8)), + attention_mask=torch.ones((1, 8), dtype=torch.int64), + ), + ), + ( + tuple(), + dict( + cache_position=torch.arange(1), + input_ids=torch.randint(10, size=(1, 1)), + past_key_values=make_dynamic_cache( + [(torch.rand((1, 1, 8, 96)), torch.rand((1, 1, 8, 96)))] + ), + attention_mask=torch.ones((1, 9), dtype=torch.int64), + ), + ), + ( + tuple(), + dict( + cache_position=torch.arange(1), + input_ids=torch.randint(10, size=(1, 1)), + past_key_values=make_dynamic_cache( + [(torch.rand((1, 1, 9, 96)), torch.rand((1, 1, 9, 96)))] + ), + attention_mask=torch.ones((1, 10), dtype=torch.int64), + ), + ), + ] + mi = ModelInputs(None, inputs) + ds = mi.guess_dynamic_shapes()[1] + DYN = torch.export.Dim.DYNAMIC + self.assertEqual( + { + "cache_position": {0: DYN}, + "input_ids": {1: DYN}, + "past_key_values": [{2: DYN}, {2: DYN}], + "attention_mask": {1: DYN}, + }, + ds, + ) + ds = WrapperToExportMethodToOnnx._dynamic_batch_dimension( + ds, {"input_ids", "past_key_values", "attention_mask"} + ) + self.assertEqual( + { + "cache_position": {0: DYN}, + "input_ids": {0: "batch", 1: DYN}, + "past_key_values": [{0: "batch", 2: DYN}, {0: "batch", 2: DYN}], + "attention_mask": {0: "batch", 1: DYN}, + }, + ds, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 3eff1e91..b501408f 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -349,6 +349,7 @@ def __init__( patch_kwargs: Optional[Dict[str, Any]] = None, skip_kwargs_names: Optional[Set[str]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None, expand_batch_for: Optional[Sequence[Union[int, str]]] = None, ): super().__init__() @@ -369,6 +370,7 @@ def __init__( self.skip_kwargs_names = skip_kwargs_names self.dynamic_shapes = dynamic_shapes self.expand_batch_for = expand_batch_for + self.dynamic_batch_for = (dynamic_batch_for,) self._to_onnx_kwargs = dict( input_names=input_names, target_opset=target_opset, @@ -532,6 +534,8 @@ def __init__(self, parent): if self.verbose: print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}") a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds) + if self.dynamic_batch_for: + nds = self._dynamic_batch_dimension(nds, self.dynamic_batch_for) else: a, kw = self._inputs[-1] nds = [self.dynamic_shapes] @@ -677,6 +681,26 @@ def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any: flat = cls._expand_batch_dimension_input(flat, msg) return torch.utils._pytree.tree_unflatten(flat, _spec) + @classmethod + def _dynamic_batch_dimension( + cls, ds: Dict[str, Any], dynamic_for: Sequence[Union[int, str]] + ) -> Dict[str, Any]: + return { + k: v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k) + for k, v in ds.items() + } + + @classmethod + def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any: + if isinstance(ds, dict) and all(isinstance(k, int) for k in ds): + ds[0] = "batch" + return {k: v for k, v in sorted(ds.items())} # noqa: C416 + if isinstance(ds, list): + return [ + cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds) + ] + raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}") + def check_discrepancies( self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0 ) -> List[Dict[str, Union[str, int, 
float]]]:
@@ -824,6 +848,7 @@ def method_to_onnx(
     patch_kwargs: Optional[Dict[str, Any]] = None,
     skip_kwargs_names: Optional[Set[str]] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
     expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
 ) -> Callable:
     """
@@ -854,6 +879,10 @@ def method_to_onnx(
     :param skip_kwargs_names: use default values for these parameters
         part of the signature of the method to export
     :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+    :param dynamic_batch_for: LLMs are usually called with a batch size equal to 1,
+        but the export may benefit from having a dynamic batch size;
+        this parameter forces the inputs specified in this set
+        to have a dynamic first dimension
     :param expand_batch_for: LLMs are usually called with a batch size equal to 1,
         but the export may benefit from having another value for the batch size,
         this parameter forces the inputs specified in this set to be expanded
@@ -884,6 +913,7 @@ def method_to_onnx(
         patch_kwargs=patch_kwargs,
         skip_kwargs_names=skip_kwargs_names,
         dynamic_shapes=dynamic_shapes,
+        dynamic_batch_for=dynamic_batch_for,
         expand_batch_for=expand_batch_for,
     )
     return wrapped_model

From 069539fa2b2d943ab6e39972aeb5c8eb96d8f7cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 12 Jan 2026 09:58:55 +0100
Subject: [PATCH 8/8] fix dynamism

---
 .../plot_export_tiny_llm_method_generate.py   |  7 ++++-
 onnx_diagnostic/export/api.py                 | 26 +++++++++++++++----
 onnx_diagnostic/helpers/onnx_helper.py        |  7 +++++
 3 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index 0e6d6e61..cf0559a8 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -88,6 +88,11 @@ def generate_text(
     # these ones are filled with default values we don't want in
     # the onnx model
     skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
+    # The input used in the example has a batch size equal to 1, and all
+    # inputs going through the forward method will have the same batch size.
+    # To make this dimension dynamic, we need to indicate
+    # which inputs have a batch dimension.
+    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
 )
 
 # %%
@@ -134,7 +139,7 @@ def generate_text(
 # It is done after because the model may not hold twice in memory
 # (torch and onnxruntime).
 # verbose=2 shows more information about expected outputs.
-data = forward_replacement.check_discrepancies(verbose=2) +data = forward_replacement.check_discrepancies(verbose=1) df = pandas.DataFrame(data) print(df) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index b501408f..bd0dcd7b 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -11,6 +11,7 @@ from ..helpers.cache_helper import CacheKeyValue from ..helpers.torch_helper import torch_deepcopy from ..helpers.rt_helper import make_feeds +from ..helpers.onnx_helper import pretty_onnx from ..reference import OnnxruntimeEvaluator @@ -370,7 +371,7 @@ def __init__( self.skip_kwargs_names = skip_kwargs_names self.dynamic_shapes = dynamic_shapes self.expand_batch_for = expand_batch_for - self.dynamic_batch_for = (dynamic_batch_for,) + self.dynamic_batch_for = dynamic_batch_for self._to_onnx_kwargs = dict( input_names=input_names, target_opset=target_opset, @@ -535,7 +536,13 @@ def __init__(self, parent): print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}") a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds) if self.dynamic_batch_for: - nds = self._dynamic_batch_dimension(nds, self.dynamic_batch_for) + nds = ( + self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for), + self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for), + ) + if self.verbose: + print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}") + print(f"[method_to_onnx] dynamic_shapes with batch={nds}") else: a, kw = self._inputs[-1] nds = [self.dynamic_shapes] @@ -683,10 +690,15 @@ def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any: @classmethod def _dynamic_batch_dimension( - cls, ds: Dict[str, Any], dynamic_for: Sequence[Union[int, str]] - ) -> Dict[str, Any]: + cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]] + ) -> Union[Tuple[Any, ...], Dict[str, Any]]: + if isinstance(ds, tuple): + return tuple( + (v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i)) + for i, v in enumerate(ds) + ) return { - k: v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k) + k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k)) for k, v in ds.items() } @@ -749,6 +761,10 @@ def check_discrepancies( input_names = sess.input_names if verbose: print(f"[method_to_onnx.check_discrepancies] input_names={input_names}") + print( + f"[method_to_onnx.check_discrepancies] onnx_shapes=" + f"{', '.join(pretty_onnx(i) for i in sess.input_types)}" + ) data = [] for i, (input, (output, latency)) in enumerate( zip(self.add_empty_cache_if_needed(inputs), outputs) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index fd5a7b23..6e7eede9 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -28,6 +28,7 @@ NodeProto, OperatorSetIdProto, TensorProto, + TypeProto, ValueInfoProto, load as onnx_load, ) @@ -385,6 +386,12 @@ def pretty_onnx( shape_str = ",".join(map(str, shape)) return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}" + if isinstance(onx, TypeProto): + itype = onx.tensor_type.elem_type + shape = tuple((d.dim_param or d.dim_value) for d in onx.tensor_type.shape.dim) + shape_str = ",".join(map(str, shape)) + return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}]" + if isinstance(onx, AttributeProto): att = onx if att.type == AttributeProto.INT:
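
A minimal usage sketch of the new ``dynamic_batch_for`` parameter, assembled
from the hunks in this series and from
_doc/examples/plot_export_tiny_llm_method_generate.py. The ``model`` object,
the ``filename`` keyword (suggested by ``self._to_onnx_kwargs["filename"]``),
and the output path are assumptions, not confirmed by the diff:

    from onnx_diagnostic.export.api import method_to_onnx

    # wraps the forward method; the onnx conversion is triggered after a few
    # calls (see convert_after_n_calls in WrapperToExportMethodToOnnx)
    forward_replacement = method_to_onnx(
        model,  # assumed: the torch.nn.Module whose forward method is exported
        filename="model.onnx",  # hypothetical output path
        # parameters filled with default values we do not want in the onnx model
        skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
        # the example runs with batch size 1; these inputs get a dynamic first
        # dimension named "batch" (see _dynamic_batch_dimension in PATCH 7/8)
        dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
    )
    # a lambda avoids registering the wrapper as a submodule of the model
    model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)

    # ... call model.generate(...) a few times to record inputs and outputs,
    # then measure the discrepancies between torch and onnxruntime
    data = forward_replacement.check_discrepancies(verbose=1)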