From c4526ccf6548ab2d092b0ebbbeb19952fab391aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 10 Jan 2026 15:14:52 +0100
Subject: [PATCH 1/9] add mask

---
 .../plot_export_tiny_llm_method_generate.py |  8 +++--
 _unittests/ut_export/test_dynamic_shapes.py | 33 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index 6568ffe4..d89ccf87 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -31,10 +31,13 @@
 def generate_text(
     prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
 ):
-    inputs = tokenizer.encode(prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
 
     outputs = model.generate(
-        inputs,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
         max_length=max_length,
         temperature=temperature,
         top_k=top_k,
@@ -97,6 +100,7 @@ def generate_text(
             {0: "batch_size", 2: "past_sequence_length"},
         ],
         "input_ids": {0: "batch_size", 1: "sequence_length"},
+        "attention_mask": {0: "batch_size", 1: "sequence_length"},
     },
 )
 
diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py
index f59d46c0..9381da0c 100644
--- a/_unittests/ut_export/test_dynamic_shapes.py
+++ b/_unittests/ut_export/test_dynamic_shapes.py
@@ -962,6 +962,39 @@ def forward(self, x, y=None):
         _a, _kw, ds = mi.move_to_kwargs(*mi.inputs[-1], ds)
         self.assertEqual(ds, (tuple(), {"x": {0: DYN}, "y": {0: DYN}}))
 
+    def test_guess_dynamic_shapes_with_none(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache=None):
+                return cache
+
+        cache = make_dynamic_cache(
+            [(torch.randn(2, 3, 5, 6), torch.randn((2, 3, 5, 6))) for i in range(2)]
+        )
+        cache2 = make_dynamic_cache(
+            [(torch.randn(2, 3, 1, 6), torch.randn((2, 3, 6, 6))) for i in range(2)]
+        )
+
+        inputs = [dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(Model(), inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
+        inputs = [{}, dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(Model(), inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
+        inputs = [{}, dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(None, inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

From f8bcbc7d9635fe5f0bd5ab76c694f3032aa9700d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 10 Jan 2026 15:16:20 +0100
Subject: [PATCH 2/9] doc

---
 README.rst                                            | 2 +-
 _doc/examples/plot_export_tiny_llm_method_generate.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.rst b/README.rst
index 87b14efc..1f0af03f 100644
--- a/README.rst
+++ b/README.rst
@@ -73,7 +73,7 @@ Enlightening Examples
 * `Export microsoft/phi-2
   `_
-* `Export a model through method generate (with Tiny-LLM)
+* `Export an LLM through method generate (with Tiny-LLM)
   `_
 
 **Torch Export**
diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index d89ccf87..97f913dc 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -1,8 +1,8 @@
 """
 .. _l-plot-tiny-llm-export-method-generate:
 
-Export a model through method generate (with Tiny-LLM)
-======================================================
+Export an LLM through method generate (with Tiny-LLM)
+=====================================================
 
-The main issue when exporting a LLM is the example on HuggingFace
-is based on method generate but we only need to export the forward method.
+The main issue when exporting an LLM is that the example on HuggingFace
+is based on method generate while we only need to export the forward method.

From b68dfcf532663e77121a0910b287a2bdf3995bfc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 10 Jan 2026 18:46:53 +0100
Subject: [PATCH 3/9] fix

---
 .../plot_dump_intermediate_results.py       |   3 +-
 .../plot_export_tiny_llm_method_generate.py |   3 +-
 .../plot_layer_norm_discrepancies.py        |   3 +-
 _unittests/ut_export/test_api.py            |  58 +++-
 onnx_diagnostic/doc.py                      | 265 +++++++++++++++++-
 onnx_diagnostic/export/api.py               | 109 ++++++-
 onnx_diagnostic/helpers/cache_helper.py     |  29 +-
 7 files changed, 447 insertions(+), 23 deletions(-)

diff --git a/_doc/examples/plot_dump_intermediate_results.py b/_doc/examples/plot_dump_intermediate_results.py
index bc3090ef..6ce8f8fb 100644
--- a/_doc/examples/plot_dump_intermediate_results.py
+++ b/_doc/examples/plot_dump_intermediate_results.py
@@ -26,7 +26,6 @@
 import onnx
 import torch
 import onnxruntime
-from onnx_array_api.plotting.graphviz_helper import plot_dot
 from onnx_diagnostic import doc
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
@@ -203,7 +202,7 @@
 # ++++++++++++++++++++
 
 onx = onnx.load("plot_dump_intermediate_results.onnx")
-plot_dot(onx)
+doc.plot_dot(onx)
 
 # %%
 doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")

diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index 97f913dc..a4e3d841 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -133,5 +133,4 @@
 
 
 # %%
-
-doc.plot_legend("Tiny-LLM\nforward inputs\nthrough generate", "onnx export", "tomato")
+doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)

diff --git a/_doc/technical/plot_layer_norm_discrepancies.py b/_doc/technical/plot_layer_norm_discrepancies.py
index de7a6e71..4a40576d 100644
--- a/_doc/technical/plot_layer_norm_discrepancies.py
+++ b/_doc/technical/plot_layer_norm_discrepancies.py
@@ -24,8 +24,7 @@
 import onnx.helper as oh
 import onnxruntime
 import torch
-from onnx_array_api.plotting.graphviz_helper import plot_dot
-from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
+from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title, plot_dot
 from onnx_diagnostic.ext_test_case import unit_test_going
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype

diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py
index 4a47f362..547ac589 100644
--- a/_unittests/ut_export/test_api.py
+++ b/_unittests/ut_export/test_api.py
@@ -15,7 +15,7 @@
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
-from onnx_diagnostic.export.api import to_onnx, method_to_onnx
+from onnx_diagnostic.export.api import to_onnx, method_to_onnx, WrapperToExportMethodToOnnx
 
 
 class TestValidate(ExtTestCase):
@@ -250,6 +250,62 @@ def forward(self, x, y=None):
         method_to_call.check_discrepancies(verbose=1)
         self.clean_dump()
 
+    def test_add_empty_cache_if_needed_dict(self):
+        inputs = [
+            dict(x=torch.rand((1, 3))),
+            dict(
+                x=torch.rand((1, 3)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                ),
+            ),
+            dict(
+                x=torch.rand((1, 3)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)) for i in range(2)]
+                ),
+            ),
+        ]
+        with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)
+        self.assertEqual(
+            (
+                "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
+                "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
+            ),
+            self.string_type(with_empty[0], with_shape=True),
+        )
+
+    def test_add_empty_cache_if_needed_args_kwargs(self):
+        inputs = [
+            (tuple(), dict(x=torch.rand((1, 3)))),
+            (
+                tuple(),
+                dict(
+                    x=torch.rand((1, 3)),
+                    cache=make_dynamic_cache(
+                        [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                    ),
+                ),
+            ),
+            (
+                tuple(),
+                dict(
+                    x=torch.rand((1, 3)),
+                    cache=make_dynamic_cache(
+                        [(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)) for i in range(2)]
+                    ),
+                ),
+            ),
+        ]
+        with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)
+        self.assertEqual(
+            (
+                "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
+                "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
+            ),
+            self.string_type(with_empty[0][1], with_shape=True),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

diff --git a/onnx_diagnostic/doc.py b/onnx_diagnostic/doc.py
index 3391543c..cf0e22ad 100644
--- a/onnx_diagnostic/doc.py
+++ b/onnx_diagnostic/doc.py
@@ -1,5 +1,11 @@
-from typing import Optional
+import os
+import subprocess
+import sys
+import tempfile
+from typing import List, Optional, Tuple, Union
 import numpy as np
+import onnx
+from .helpers.dot_helper import to_dot
 
 
 def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
@@ -36,6 +42,15 @@ def reset_torch_transformers(gallery_conf, fname):
 def plot_legend(
     text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
 ) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Plots a graph with only text (for :epkg:`sphinx-gallery`).
+
+    :param text: legend
+    :param text_bottom: text at the bottom
+    :param color: color
+    :param fontsize: font size
+    :return: axis
+    """
     import matplotlib.pyplot as plt
 
     fig = plt.figure(figsize=(2, 2))
@@ -66,17 +81,14 @@ def rotate_align(ax, angle=15, align="right"):
     return ax
 
 
-def save_fig(ax, name: str):
-    """Applies ``tight_layout`` and saves the figures. Returns ax."""
-    import matplotlib.pyplot as plt
-
-    plt.tight_layout()
+def save_fig(ax, name: str, **kwargs) -> "matplotlib.axis.Axis":  # noqa: F821
+    """Saves the figure. Returns ax."""
     fig = ax.get_figure()
-    fig.savefig(name)
+    fig.savefig(name, **kwargs)
     return ax
 
 
-def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+def title(ax: "plt.axes", title: str) -> "matplotlib.axis.Axis":  # noqa: F821
     "Adds a title to axes and returns them."
     ax.set_title(title)
     return ax
 
@@ -88,7 +100,7 @@ def plot_histogram(
     bins: int = 30,
     color: str = "orange",
     alpha: float = 0.7,
-) -> "plt.axes":  # noqa: F821
+) -> "matplotlib.axis.Axis":  # noqa: F821
     "Computes the distribution for a tensor."
     if ax is None:
         import matplotlib.pyplot as plt
@@ -98,3 +110,240 @@ def plot_histogram(
-    ax.hist(tensor, bins=30, color="orange", alpha=0.7)
+    ax.hist(tensor, bins=bins, color=color, alpha=alpha)
     ax.set_yscale("log")
     return ax
+
+
+def _find_in_PATH(prog: str) -> Optional[str]:
+    """
+    Looks for a specific file in every path mentioned in ``%PATH%``,
+    returns ``None`` if it cannot be found.
+
+    :param prog: program to look for
+    :return: path containing the program or None
+    """
+    sep = ";" if sys.platform.startswith("win") else ":"
+    path = os.environ["PATH"]
+    for p in path.split(sep):
+        f = os.path.join(p, prog)
+        if os.path.exists(f):
+            return p
+    return None
+
+
+def _find_graphviz_dot(exc: bool = True) -> Optional[str]:
+    """
+    Determines the path to graphviz (on Windows),
+    the function tests the existence of versions 34 to 59
+    assuming it was installed in a standard folder:
+    ``C:\\Program Files (x86)\\Graphviz2.{v}\\bin``.
+
+    :param exc: raise an exception or be silent
+    :return: path to dot, None if not found and *exc* is False
+    :raises FileNotFoundError: if graphviz not found
+    """
+    if sys.platform.startswith("win"):
+        version = list(range(34, 60))
+        version.extend([f"{v}.1" for v in version])
+        for v in version:
+            graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
+            if os.path.exists(graphviz_dot):
+                return graphviz_dot
+        extra = ["build/update_modules/Graphviz/bin"]
+        for ext in extra:
+            graphviz_dot = os.path.join(ext, "dot.exe")
+            if os.path.exists(graphviz_dot):
+                return graphviz_dot
+        p = _find_in_PATH("dot.exe")
+        if p is None:
+            if exc:
+                raise FileNotFoundError(
+                    f"Unable to find graphviz, look into paths such as {graphviz_dot}."
+                )
+            return None
+        return os.path.join(p, "dot.exe")
+    # linux
+    return "dot"
+
+
+def _run_subprocess(args: List[str], cwd: Optional[str] = None):
+    assert not isinstance(args, str), "args should be a sequence of strings, not a string."
+
+    p = subprocess.Popen(
+        args,
+        cwd=cwd,
+        shell=False,
+        env=os.environ,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    raise_exception = False
+    output = ""
+    while True:
+        output = p.stdout.readline().decode(errors="ignore")
+        if output == "" and p.poll() is not None:
+            break
+        if output:
+            if (
+                "fatal error" in output
+                or "CMake Error" in output
+                or "gmake: ***" in output
+                or "): error C" in output
+                or ": error: " in output
+            ):
+                raise_exception = True
+    p.poll()
+    error = p.stderr.readline().decode(errors="ignore")
+    p.stdout.close()
+    if error and raise_exception:
+        raise RuntimeError(
+            f"An error was found in the output. The build is stopped."
+            f"\n{output}\n---\n{error}"
+        )
+    return output + "\n" + error
+
+
+def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
+    """
+    Runs :epkg:`Graphviz`.
+
+    :param filename: filename which contains the graph definition
+    :param image: output image
+    :param engine: *dot* or *neato*
+    :return: output of graphviz
+    """
+    ext = os.path.splitext(image)[-1]
+    assert ext in {
+        ".png",
+        ".bmp",
+        ".fig",
+        ".gif",
+        ".ico",
+        ".jpg",
+        ".jpeg",
+        ".pdf",
+        ".ps",
+        ".svg",
+        ".vrml",
+        ".tif",
+        ".tiff",
+        ".wbmp",
+    }, f"Unexpected extension {ext!r} for {image!r}."
+    if sys.platform.startswith("win"):
+        bin_ = os.path.dirname(_find_graphviz_dot())
+        # if bin not in os.environ["PATH"]:
+        #     os.environ["PATH"] = os.environ["PATH"] + ";" + bin
+        exe = os.path.join(bin_, engine)
+    else:
+        exe = engine
+    if os.path.exists(image):
+        os.remove(image)
+    cmd = [exe, f"-T{ext[1:]}", filename, "-o", image]
+    output = _run_subprocess(cmd)
+    assert os.path.exists(image), (
+        f"Unable to find {image!r}, command line is "
+        f"{' '.join(cmd)!r}, Graphviz failed due to\n{output}"
+    )
+    return output
+
+
+def draw_graph_graphviz(
+    dot: Union[str, onnx.ModelProto], image: str, engine: str = "dot"
+) -> str:
+    """
+    Draws a graph using :epkg:`Graphviz`.
+
+    :param dot: dot graph, ModelProto or path to a ``.onnx`` file
+    :param image: output image
+    :param engine: *dot* or *neato*
+    :return: :epkg:`Graphviz` output
+
+    The function creates a temporary file to store the dot graph.
+    """
+    if isinstance(dot, onnx.ModelProto):
+        sdot = to_dot(dot)
+    else:
+        if "{" not in dot:
+            assert dot.endswith(".onnx"), f"Unexpected file extension for {dot!r}"
+            proto = onnx.load(dot)
+            sdot = to_dot(proto)
+        else:
+            sdot = dot
+    assert "{" in sdot, f"This string is not a dot string\n{sdot}"
+    with tempfile.NamedTemporaryFile(delete=False) as fp:
+        fp.write(sdot.encode("utf-8"))
+        fp.close()
+
+    filename = fp.name
+    assert os.path.exists(
+        filename
+    ), f"File {filename!r} cannot be created to store the graph."
+    out = _run_graphviz(filename, image, engine=engine)
+    assert os.path.exists(
+        image
+    ), f"Graphviz failed for no obvious reason, {image!r} not found, output is {out}."
+    os.remove(filename)
+    return out
+
+
+def plot_dot(
+    dot: Union[str, onnx.ModelProto],
+    ax: Optional["matplotlib.axis.Axis"] = None,  # noqa: F821
+    engine: str = "dot",
+    figsize: Optional[Tuple[int, int]] = None,
+) -> "matplotlib.axis.Axis":  # noqa: F821
+    """
+    Draws a dot graph into a matplotlib graph.
+
+    :param dot: dot graph, ModelProto or path to a ``.onnx`` file
+    :param ax: existing matplotlib axis, a new one is created if None
+    :param engine: *dot* or *neato*
+    :param figsize: figure size if *ax* is None
+    :return: the matplotlib axis
+
+    .. plot::
+
+        import matplotlib.pyplot as plt
+        import onnx.parser
+        from onnx_diagnostic.doc import plot_dot
+
+        model = onnx.parser.parse_model(
+            '''
+            <ir_version: 8, opset_import: ["": 18]>
+            agraph (float[N] x) => (float[N] z) {
+                two = Constant <value_float=2.0> ()
+                four = Add(two, two)
+                z = Mul(four, four)
+            }
+            ''')
+
+        ax = plot_dot(model)
+        ax.set_title("Dummy graph")
+        plt.show()
+    """
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        _, ax = plt.subplots(1, 1, figsize=figsize)
+        clean = True
+    else:
+        clean = False
+
+    from PIL import Image
+
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
+        fp.close()
+
+    draw_graph_graphviz(dot, fp.name, engine=engine)
+    img = np.asarray(Image.open(fp.name))
+    os.remove(fp.name)
+
+    ax.imshow(img)
+
+    if clean:
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        ax.set_axis_off()
+        ax.get_figure().tight_layout()
+    return ax

diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 5d506bb8..c56a57f3 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -8,6 +8,7 @@
 from .dynamic_shapes import ModelInputs
 from .onnx_plug import EagerDirectReplacementWithOnnx
 from ..helpers import flatten_object, max_diff, string_diff, string_type
+from ..helpers.cache_helper import CacheKeyValue
 from ..helpers.torch_helper import torch_deepcopy
 from ..helpers.rt_helper import make_feeds
 from ..reference import OnnxruntimeEvaluator
@@ -541,6 +542,83 @@ def __init__(self, parent):
             **self._to_onnx_kwargs,
         )
 
+    @classmethod
+    def make_empty_cache_from_others(cls, examples: List["Cache"]) -> "Cache":  # noqa: F821
+        """Builds an empty cache based on existing ones."""
+        unique_types = {type(t) for t in examples}
+        assert (
+            len(unique_types) == 1
+        ), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
+        unique_type = unique_types.pop()
+        if unique_type == torch.Tensor:
+            shapes = [t.shape for t in examples]
+            assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
+            ranks = {len(s) for s in shapes}
+            assert len(ranks) == 1, f"Ranks are different in {shapes}"
+            rank = ranks.pop()
+            new_shape = []
+            for i in range(rank):
+                dims = [t.shape[i] for t in examples]
+                if len(set(dims)) == 1:
+                    new_shape.append(dims[0])
+                else:
+                    # The empty dimension.
+                    new_shape.append(0)
+            example = examples[0]
+            return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
+        assert (
+            unique_type.__name__ == "DynamicCache"
+        ), f"This is not implemented for class {unique_type}"
+        caches = [CacheKeyValue(dc) for dc in examples]
+        caches_list = [dc.aslist() for dc in caches]
+        empty = [
+            cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
+            for k in range(len(caches_list[0]))
+        ]
+        empty_cache = CacheKeyValue(
+            empty, cls_layers=caches[0].cls_layers
+        ).make_dynamic_cache()
+        return empty_cache
+
+    @classmethod
+    def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
+        """
+        Adds an empty cache when needed, as onnxruntime requires an empty cache,
+        not a missing one. It only works if inputs are defined as a dictionary.
+        """
+        if all(isinstance(t, tuple) for t in inputs) and all(
+            len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0]
+            for t in inputs
+        ):
+            dict_part = [t[1] for t in inputs]
+            res = cls.add_empty_cache_if_needed(dict_part)
+            return [(tuple(), d) for d in res]
+        if any(not isinstance(t, dict) for t in inputs):
+            return inputs
+        all_keys = set()
+        for input_set in inputs:
+            all_keys |= set(input_set)
+        # Even though the inputs are defined as a dictionary, it is better
+        # to keep the same order.
+        ordered = None
+        for input_set in inputs:
+            if set(input_set) == all_keys:
+                ordered = list(input_set)
+                break
+        new_inputs = []
+        for input_set in inputs:
+            if set(input_set) == all_keys:
+                new_inputs.append(input_set)
+                continue
+            missing = {k for k in all_keys if k not in input_set}
+            input_set_copy = input_set.copy()
+            for miss in missing:
+                input_set_copy[miss] = cls.make_empty_cache_from_others(
+                    [sub[miss] for sub in inputs if miss in sub]
+                )
+            new_inputs.append({k: input_set_copy[k] for k in ordered})
+        return new_inputs
+
     def check_discrepancies(
         self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
     ) -> List[Dict[str, Union[str, int, float]]]:
@@ -555,6 +633,18 @@ def check_discrepancies(
         :param verbose: verbosity
         :return: results, a list of dictionaries, ready to be consumed by a dataframe
         """
+
+        def _missing_classes():
+            try:
+                import transformers
+
+                return [
+                    transformers.modeling_outputs.CausalLMOutputWithPast,
+                    transformers.cache_utils.DynamicCache,
+                ]
+            except ImportError:
+                return []
+
         assert self._export_done, "The onnx export was not done."
         assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
         assert os.path.exists(
@@ -565,9 +655,13 @@ def check_discrepancies(
             filename
         ), f"onnx file {filename!r} not found"
         classes = [
-            cls
-            for cls in self._serialization_classes
-            if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
+            *_missing_classes(),
+            *[
+                cls
+                for cls in self._serialization_classes
+                if cls
+                not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
+            ],
         ]
         if verbose:
             print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
@@ -590,7 +684,9 @@ def check_discrepancies(
         if verbose:
             print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
         data = []
-        for i, (input, (output, latency)) in enumerate(zip(inputs, outputs)):
+        for i, (input, (output, latency)) in enumerate(
+            zip(self.add_empty_cache_if_needed(inputs), outputs)
+        ):
             if verbose:
                 if verbose > 1:
                     print(
@@ -602,7 +698,10 @@ def check_discrepancies(
                         f"{string_type(output, with_shape=True)}"
                     )
                 else:
-                    print(f"[method_to_onnx.check_discrepancies] process input {i}")
+                    print(
+                        f"[method_to_onnx.check_discrepancies] "
+                        f"process input {i} #inputs={len(input)}"
+                    )
 
             flat_inputs = flatten_object(input, drop_keys=True)
             if len(flat_inputs) < len(input_names):

diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 767c4f1f..0e705661 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -19,7 +19,7 @@ class CacheKeyValue:
         capi.value_cache
     """
 
-    def __init__(self, cache=None):
+    def __init__(self, cache=None, cls_layers=None):
         if hasattr(cache, "layers"):
             layers = [
                 layer
@@ -28,15 +28,26 @@ def __init__(self, cache=None):
             ]
             self.key_cache = [layer.keys for layer in layers]
             self.value_cache = [layer.values for layer in layers]
+            assert (
+                cls_layers is None
), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}" self.cls_layers = [type(lay) for lay in cache.layers] elif cache is not None and hasattr(cache, "key_cache"): self.key_cache = cache.key_cache self.value_cache = cache.value_cache - self.cls_layers = None + self.cls_layers = cls_layers + elif ( + cache is not None + and isinstance(cache, list) + and all(isinstance(t, torch.Tensor) for t in cache) + ): + self.key_cache = cache[::2] + self.value_cache = cache[1::2] + self.cls_layers = cls_layers elif cache is None: self.key_cache = None self.value_cache = None - self.cls_layers = None + self.cls_layers = cls_layers else: raise NotImplementedError(f"type(cache)={type(cache)}") @@ -51,6 +62,18 @@ def n_layers(self) -> int: """Returns the number of layers.""" return len(self.key_cache) if self.key_cache else 0 + def __len__(self) -> int: + "Returns the number of tensors." + return len(self.key_cache) + len(self.value_cache) + + def aslist(self) -> List[torch.Tensor]: + "Returns tensors in a list." + res = [] + for i in range(self.n_layers): + res.append(self.key_cache[i]) + res.append(self.value_cache[i]) + return res + def flatten_unflatten_for_dynamic_shapes( obj: Any, From d32419e6beaa539c44fe3573af9fb1713f8a9ffe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 10 Jan 2026 18:49:12 +0100 Subject: [PATCH 4/9] fix --- onnx_diagnostic/export/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index c56a57f3..f48ab660 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -699,8 +699,8 @@ def _missing_classes(): ) else: print( - f"[method_to_onnx.check_discrepancies] " - f"process input {i} #inputs={len(input)}" + f"[method_to_onnx.check_discrepancies] process input {i} " + f"#args={len(input[0])} #kwargs={len(input[1])}" ) flat_inputs = flatten_object(input, drop_keys=True) From 2872f707c0280d4153b76c33f079d1d12372fe41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 11:17:08 +0100 Subject: [PATCH 5/9] disable one test --- .../test_documentation_examples.py | 4 +++- onnx_diagnostic/export/api.py | 21 +++---------------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index beef33cc..556e6297 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -125,7 +125,6 @@ def add_test_methods(cls): "plot_export_locate_issue.py", "plot_export_with_auto.py", "plot_export_tiny_llm.py", - "plot_export_tiny_llm_method_generate.py", } and not has_torch("2.8") ): @@ -155,6 +154,9 @@ def add_test_methods(cls): if not reason and torch.__version__.startswith("2.9.0"): reason = "examples are failing for on CI for 2.9.0" + if not reason and name in {"plot_export_tiny_llm_method_generate.py"}: + reason = "does not work when called in a separate process" + if reason: @unittest.skip(reason) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index f48ab660..73fdc895 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -634,17 +634,6 @@ def check_discrepancies( :return: results, a list of dictionaries, ready to be consumed by a dataframe """ - def _missing_classes(): - try: - import transformers - - return [ - transformers.modeling_outputs.CausalLMOutputWithPast, - 
-                    transformers.cache_utils.DynamicCache,
-                ]
-            except ImportError:
-                return []
 
         assert self._export_done, "The onnx export was not done."
         assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
         assert os.path.exists(
@@ -655,13 +644,9 @@ def check_discrepancies(
             filename
         ), f"onnx file {filename!r} not found"
         classes = [
-            *_missing_classes(),
-            *[
-                cls
-                for cls in self._serialization_classes
-                if cls
-                not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
-            ],
+            cls
+            for cls in self._serialization_classes
+            if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
         ]
         if verbose:
             print(f"[method_to_onnx.check_discrepancies] register classes {classes}")

From 959006e5d02ff7a44f27a85814e0731e7dc51c5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sun, 11 Jan 2026 11:43:34 +0100
Subject: [PATCH 6/9] add test for expand_batch_for

---
 _unittests/ut_export/test_api.py | 29 ++++++++++++++++++++
 onnx_diagnostic/export/api.py    | 35 ++++++++++++++++---------------
 2 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py
index 547ac589..7ba8e597 100644
--- a/_unittests/ut_export/test_api.py
+++ b/_unittests/ut_export/test_api.py
@@ -306,6 +306,35 @@ def test_add_empty_cache_if_needed_args_kwargs(self):
             self.string_type(with_empty[0][1], with_shape=True),
         )
 
+    @requires_experimental_experiment("0.1")
+    def test_method_to_onnx_expand_batch(self):
+        class Model(torch.nn.Module):
+            def forward(self, x=None, y=None):
+                return x + y
+
+        filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
+        inputs = [
+            dict(x=torch.randn((1, 5, 6)), y=torch.randn((1, 1, 6))),
+            dict(x=torch.randn((1, 7, 7)), y=torch.randn((1, 1, 7))),
+        ]
+        model = Model()
+        method_to_call = method_to_onnx(
+            model, exporter="custom", filename=filename, expand_batch_for={"x", "y"}
+        )
+        expecteds = []
+        for kwargs in inputs:
+            expecteds.append(method_to_call(**kwargs))
+        self.assertExists(filename)
+        sess = self.check_ort(filename)
+        input_names = [i.name for i in sess.get_inputs()]
+        input_shapes = [i.shape for i in sess.get_inputs()]
+        print("***", input_shapes)
+        for expected, kwargs in zip(expecteds, inputs):
+            feeds = make_feeds(input_names, kwargs, use_numpy=True)
+            got = sess.run(None, feeds)
+            self.assertEqualArray(expected, got[0])
+        self.clean_dump()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 73fdc895..ed5967a0 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -349,6 +349,7 @@ def __init__(
         patch_kwargs: Optional[Dict[str, Any]] = None,
         skip_kwargs_names: Optional[Set[str]] = None,
         dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
     ):
         super().__init__()
         self._model_to_call = mod
@@ -384,11 +385,14 @@ def __init__(
         )
         self._export_done = False
         self._serialization_classes: Set[type] = set()
+        self._expand_batch_for = expand_batch_for
 
     def __str__(self) -> str:
+        "usual"
         return self.__repr__()
 
     def __repr__(self) -> str:
+        "usual"
        return (
            f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
f"{self._method_name})" @@ -415,22 +419,17 @@ def _collect_classes(self, obj): def forward(self, *args, **kwargs): if not self._export_done: - self._inputs.append( - torch_deepcopy( - ( - args, - ( - kwargs - if not kwargs or not self.skip_kwargs_names - else { - k: v - for k, v in kwargs.items() - if k not in self.skip_kwargs_names - } - ), - ) - ) + inp_args = args + inp_kwargs = ( + kwargs + if not kwargs + else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names} ) + if self._expand_batch_for: + inp_args = self._expand_batch_dimension(inp_args, self._expand_batch_for) + inp_kwargs = self._expand_batch_dimension(inp_kwargs, self._expand_batch_for) + inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs)) + self._inputs.append((inp_args, inp_kwargs)) if self.verbose: print( f"[method_to_onnx] input[{len(self._inputs)-1}]: " @@ -753,6 +752,7 @@ def method_to_onnx( patch_kwargs: Optional[Dict[str, Any]] = None, skip_kwargs_names: Optional[Set[str]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + expand_batch_for: Optional[Sequence[Union[int, str]]] = None, ) -> Callable: """ Exports one method into ONNX for a module into ONNX. @@ -782,6 +782,10 @@ def method_to_onnx( :param skip_kwargs_names: use default values for these parameters part of the signature of the method to export :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right + :param expand_batch_for: LLM are usually called with a batch size equal to 1, + but the export may benefit from having another value for the batch size, + this parameter forces the input specified in this set to be expanded + to 2 if the batch size is one :return: the output of the selected exporter, usually a structure including an onnx model @@ -808,5 +812,6 @@ def method_to_onnx( patch_kwargs=patch_kwargs, skip_kwargs_names=skip_kwargs_names, dynamic_shapes=dynamic_shapes, + expand_batch_for=expand_batch_for, ) return wrapped_model From 12c3150d8081c31fcd895eabfb6d687d03e1adad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 11 Jan 2026 12:56:41 +0100 Subject: [PATCH 7/9] fix --- _unittests/ut_export/test_api.py | 61 ++++++++++++++++++++++++++++---- onnx_diagnostic/export/api.py | 52 +++++++++++++++++++++++---- 2 files changed, 101 insertions(+), 12 deletions(-) diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index 7ba8e597..b1395cf3 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -307,7 +307,7 @@ def test_add_empty_cache_if_needed_args_kwargs(self): ) @requires_experimental_experiment("0.1") - def test_method_to_onnx_expand_batch(self): + def test_method_to_onnx_expand_batch_dimension(self): class Model(torch.nn.Module): def forward(self, x=None, y=None): return x + y @@ -327,12 +327,61 @@ def forward(self, x=None, y=None): self.assertExists(filename) sess = self.check_ort(filename) input_names = [i.name for i in sess.get_inputs()] + self.assertEqual(["x", "y"], input_names) input_shapes = [i.shape for i in sess.get_inputs()] - print("***", input_shapes) - for expected, kwargs in zip(expecteds, inputs): - feeds = make_feeds(input_names, kwargs, use_numpy=True) - got = sess.run(None, feeds) - self.assertEqualArray(expected, got[0]) + self.assertEqual([[2, "channel", "D0"], [2, 1, "D0_1"]], input_shapes) + self.clean_dump() + + @requires_experimental_experiment("0.1") + @requires_transformers("4.57") + def test_method_to_onnx_expand_batch_dimension_dynamic_cache(self): + class 
+            def forward(self, x=None, cache=None):
+                return x + cache.layers[0].keys
+
+        filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
+        inputs = [
+            dict(
+                x=torch.randn((1, 1, 3, 4)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                ),
+            ),
+            dict(
+                x=torch.randn((1, 1, 5, 4)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4)) for i in range(2)]
+                ),
+            ),
+        ]
+        model = Model()
+        method_to_call = method_to_onnx(
+            model,
+            exporter="custom",
+            filename=filename,
+            expand_batch_for={"x", "cache"},
+            patch_kwargs=dict(patch_transformers=True),
+        )
+        expecteds = []
+        for kwargs in inputs:
+            expecteds.append(method_to_call(**kwargs))
+        self.assertExists(filename)
+        sess = self.check_ort(filename)
+        input_names = [i.name for i in sess.get_inputs()]
+        self.assertEqual(
+            ["x", "cache_key_0", "cache_value_0", "cache_key_1", "cache_value_1"], input_names
+        )
+        input_shapes = [i.shape for i in sess.get_inputs()]
+        self.assertEqual(
+            [
+                [2, 1, "D0", 4],
+                [2, 2, "D0_1", 4],
+                [2, 2, "D0_2", 4],
+                [2, 2, "D0_3", 4],
+                [2, 2, "D0_4", 4],
+            ],
+            input_shapes,
+        )
+        self.clean_dump()
+

diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index ed5967a0..450b29f0 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -367,6 +367,7 @@ def __init__(
         self.verbose = verbose
         self.skip_kwargs_names = skip_kwargs_names
         self.dynamic_shapes = dynamic_shapes
+        self.expand_batch_for = expand_batch_for
         self._to_onnx_kwargs = dict(
             input_names=input_names,
             target_opset=target_opset,
@@ -385,7 +386,6 @@ def __init__(
         )
         self._export_done = False
         self._serialization_classes: Set[type] = set()
-        self._expand_batch_for = expand_batch_for
 
     def __str__(self) -> str:
         "usual"
@@ -422,12 +422,12 @@ def forward(self, *args, **kwargs):
             inp_args = args
             inp_kwargs = (
                 kwargs
-                if not kwargs
+                if not kwargs or not self.skip_kwargs_names
                 else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
             )
-            if self._expand_batch_for:
-                inp_args = self._expand_batch_dimension(inp_args, self._expand_batch_for)
-                inp_kwargs = self._expand_batch_dimension(inp_kwargs, self._expand_batch_for)
+            if self.expand_batch_for:
+                inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
+                inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
             self._inputs.append((inp_args, inp_kwargs))
             if self.verbose:
@@ -618,6 +618,47 @@ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
             new_inputs.append({k: input_set_copy[k] for k in ordered})
         return new_inputs
 
+    @classmethod
+    def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
+        expand_for_args = {i for i in expand_for if isinstance(i, int)}
+        expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
+        if isinstance(obj, tuple):
+            return tuple(
+                o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
+                for i, o in enumerate(obj)
+            )
+        assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
+        return {
+            k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
+            for k, v in obj.items()
+        }
+
+    @classmethod
+    def _expand_batch_dimension_input(cls, obj: Any, msg: str) -> Any:
+        if isinstance(obj, torch.Tensor):
+            assert obj.shape[0] == 1, (
+                f"Are you sure you want to expand input {msg!r}, "
+                f"batch size is not 1 and shape={obj.shape}"
+            )
+            sizes = [2, *obj.shape[1:]]
+            return obj.expand(*sizes)
+        if isinstance(obj, list):
+            return [
+                cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
+            ]
+        if obj.__class__.__name__ == "DynamicCache":
+            dc = CacheKeyValue(obj)
+            flat = dc.aslist()
+            flat = cls._expand_batch_dimension_input(flat, msg)
+            return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
+        # This might end up in an infinite loop if no registration is done.
+        flat, _spec = torch.utils._pytree.tree_flatten(obj)
+        assert (
+            not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
+        ), f"class {type(obj)} is not registered for serialization."
+        flat = cls._expand_batch_dimension_input(flat, msg)
+        return torch.utils._pytree.tree_unflatten(flat, _spec)
+
     def check_discrepancies(
         self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
     ) -> List[Dict[str, Union[str, int, float]]]:
@@ -632,7 +673,6 @@ def check_discrepancies(
         :param verbose: verbosity
         :return: results, a list of dictionaries, ready to be consumed by a dataframe
         """
-
         assert self._export_done, "The onnx export was not done."
         assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
         assert os.path.exists(

From dbbe5081b1b23a2ac48966906c6c3ad0a6f93299 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3%A9?=
Date: Sun, 11 Jan 2026 12:58:13 +0100
Subject: [PATCH 8/9] doc

---
 CHANGELOGS.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index d3671fb5..7c166329 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.9
 +++++
 
+* :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
 * :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
 * :pr:`379`: update the handling of cache after the removal of
   HybridCache, SlidingWindowCache in ``transformers>=5``,

From e05e27f177b3e89fcc5cce7b82a984bac966ae7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sun, 11 Jan 2026 13:02:25 +0100
Subject: [PATCH 9/9] mypy

---
 onnx_diagnostic/doc.py        | 7 ++++---
 onnx_diagnostic/export/api.py | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/onnx_diagnostic/doc.py b/onnx_diagnostic/doc.py
index cf0e22ad..4a19ca9a 100644
--- a/onnx_diagnostic/doc.py
+++ b/onnx_diagnostic/doc.py
@@ -178,7 +178,7 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
     raise_exception = False
     output = ""
     while True:
-        output = p.stdout.readline().decode(errors="ignore")
+        output = p.stdout.readline().decode(errors="ignore")  # type: ignore[union-attr]
         if output == "" and p.poll() is not None:
             break
         if output:
@@ -191,8 +191,9 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
             ):
                 raise_exception = True
     p.poll()
-    error = p.stderr.readline().decode(errors="ignore")
-    p.stdout.close()
+    error = p.stderr.readline().decode(errors="ignore")  # type: ignore[union-attr]
+    p.stdout.close()  # type: ignore[union-attr]
+    p.stderr.close()  # type: ignore[union-attr]
     if error and raise_exception:
         raise RuntimeError(
             f"An error was found in the output. The build is stopped."
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 450b29f0..2c642dfe 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -542,7 +542,7 @@ def __init__(self, parent):
         )
 
     @classmethod
-    def make_empty_cache_from_others(cls, examples: List["Cache"]) -> "Cache":  # noqa: F821
+    def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
         """Builds an empty cache based on existing ones."""
         unique_types = {type(t) for t in examples}
         assert (
@@ -615,7 +615,7 @@ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
                 input_set_copy[miss] = cls.make_empty_cache_from_others(
                     [sub[miss] for sub in inputs if miss in sub]
                 )
-            new_inputs.append({k: input_set_copy[k] for k in ordered})
+            new_inputs.append({k: input_set_copy[k] for k in ordered})  # type: ignore[union-attr]
         return new_inputs
 
     @classmethod
@@ -634,7 +634,7 @@ def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]
         }
 
     @classmethod
-    def _expand_batch_dimension_input(cls, obj: Any, msg: str) -> Any:
+    def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
         if isinstance(obj, torch.Tensor):
             assert obj.shape[0] == 1, (
                 f"Are you sure you want to expand input {msg!r}, "