Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Enlightening Examples

* `Export microsoft/phi-2
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
* `Export a model through method generate (with Tiny-LLM)
* `Export a LLM through method generate (with Tiny-LLM)
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_method_generate.html>`_

**Torch Export**
Expand Down
3 changes: 1 addition & 2 deletions _doc/examples/plot_dump_intermediate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import onnx
import torch
import onnxruntime
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
Expand Down Expand Up @@ -203,7 +202,7 @@
# ++++++++++++++++++++

onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)
doc.plot_dot(onx)

# %%
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")
15 changes: 9 additions & 6 deletions _doc/examples/plot_export_tiny_llm_method_generate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
.. _l-plot-tiny-llm-export-method-generate:

Export a model through method generate (with Tiny-LLM)
======================================================
Export a LLM through method generate (with Tiny-LLM)
====================================================

The main issue when exporting a LLM is that the example on HuggingFace is
based on method generate, but we only need to export the forward method.
Expand Down Expand Up @@ -31,10 +31,13 @@
def generate_text(
prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
):
inputs = tokenizer.encode(prompt, return_tensors="pt")
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

outputs = model.generate(
inputs,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
temperature=temperature,
top_k=top_k,
Expand Down Expand Up @@ -97,6 +100,7 @@ def generate_text(
{0: "batch_size", 2: "past_sequence_length"},
],
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
},
)

Expand Down Expand Up @@ -129,5 +133,4 @@ def generate_text(


# %%

doc.plot_legend("Tiny-LLM\nforward inputs\nthrough generate", "onnx export", "tomato")
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
3 changes: 1 addition & 2 deletions _doc/technical/plot_layer_norm_discrepancies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
import onnx.helper as oh
import onnxruntime
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title, plot_dot
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
Expand Down
58 changes: 57 additions & 1 deletion _unittests/ut_export/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.export.api import to_onnx, method_to_onnx
from onnx_diagnostic.export.api import to_onnx, method_to_onnx, WrapperToExportMethodToOnnx


class TestValidate(ExtTestCase):
Expand Down Expand Up @@ -250,6 +250,62 @@ def forward(self, x, y=None):
method_to_call.check_discrepancies(verbose=1)
self.clean_dump()

def test_add_empty_cache_if_needed_dict(self):
    """Checks ``add_empty_cache_if_needed`` on input sets given as dictionaries.

    The first input set has no ``cache`` entry while the two others carry
    a two-layer cache. After the call, the first set is expected to hold
    a cache as well, whose tensors have a third dimension equal to 0.
    """

    def _two_layer_cache(seq_length):
        # Key and value tensors share the same shape in every layer.
        return make_dynamic_cache(
            [
                (torch.randn(1, 2, seq_length, 4), torch.randn(1, 2, seq_length, 4))
                for _ in range(2)
            ]
        )

    without_cache = dict(x=torch.rand((1, 3)))
    long_cache = dict(x=torch.rand((1, 3)), cache=_two_layer_cache(3))
    short_cache = dict(x=torch.rand((1, 3)), cache=_two_layer_cache(1))
    inputs = [without_cache, long_cache, short_cache]

    with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)

    expected = (
        "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
        "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
    )
    self.assertEqual(expected, self.string_type(with_empty[0], with_shape=True))

def test_add_empty_cache_if_needed_args_kwargs(self):
    """Checks ``add_empty_cache_if_needed`` on input sets given as ``(args, kwargs)``.

    Every input set is a pair made of an empty tuple and a dictionary of
    named inputs. Only the first one has no ``cache``. After the call, the
    named inputs of the first set are expected to hold a cache whose tensors
    have a third dimension equal to 0.
    """

    def _two_layer_cache(seq_length):
        # Key and value tensors share the same shape in every layer.
        return make_dynamic_cache(
            [
                (torch.randn(1, 2, seq_length, 4), torch.randn(1, 2, seq_length, 4))
                for _ in range(2)
            ]
        )

    without_cache = (tuple(), dict(x=torch.rand((1, 3))))
    long_cache = (tuple(), dict(x=torch.rand((1, 3)), cache=_two_layer_cache(3)))
    short_cache = (tuple(), dict(x=torch.rand((1, 3)), cache=_two_layer_cache(1)))
    inputs = [without_cache, long_cache, short_cache]

    with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)

    expected = (
        "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
        "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
    )
    # Index 0 selects the first input set, index 1 its keyword arguments.
    self.assertEqual(expected, self.string_type(with_empty[0][1], with_shape=True))


if __name__ == "__main__":
unittest.main(verbosity=2)
33 changes: 33 additions & 0 deletions _unittests/ut_export/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,39 @@ def forward(self, x, y=None):
_a, _kw, ds = mi.move_to_kwargs(*mi.inputs[-1], ds)
self.assertEqual(ds, (tuple(), {"x": {0: DYN}, "y": {0: DYN}}))

def test_guess_dynamic_shapes_with_none(self):
    """Checks ``guess_dynamic_shapes`` when ``cache`` is missing from one input set.

    Three scenarios are expected to produce the same dynamic shapes:
    two input sets with a cache, the same preceded by an empty input set,
    and the latter again with ``None`` given instead of a model.
    """

    class Model(torch.nn.Module):
        def forward(self, cache=None):
            return cache

    cache = make_dynamic_cache(
        [(torch.randn(2, 3, 5, 6), torch.randn((2, 3, 5, 6))) for _ in range(2)]
    )
    cache2 = make_dynamic_cache(
        [(torch.randn(2, 3, 1, 6), torch.randn((2, 3, 6, 6))) for _ in range(2)]
    )

    DYN = torch.export.Dim.DYNAMIC
    expected = ((), {"cache": [{2: DYN} for _ in range(4)]})

    # (model factory or None, whether an empty input set comes first)
    scenarios = (
        (Model, False),
        (Model, True),
        (None, True),
    )
    for factory, starts_with_empty in scenarios:
        # The list of input sets is rebuilt for every scenario.
        inputs = [dict(cache=cache), dict(cache=cache2)]
        if starts_with_empty:
            inputs.insert(0, {})
        model = None if factory is None else factory()
        mi = ModelInputs(model, inputs)
        ds = mi.guess_dynamic_shapes()
        self.assertEqual(expected, ds)


if __name__ == "__main__":
unittest.main(verbosity=2)
4 changes: 3 additions & 1 deletion _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def add_test_methods(cls):
"plot_export_locate_issue.py",
"plot_export_with_auto.py",
"plot_export_tiny_llm.py",
"plot_export_tiny_llm_method_generate.py",
}
and not has_torch("2.8")
):
Expand Down Expand Up @@ -155,6 +154,9 @@ def add_test_methods(cls):
if not reason and torch.__version__.startswith("2.9.0"):
reason = "examples are failing for on CI for 2.9.0"

if not reason and name in {"plot_export_tiny_llm_method_generate.py"}:
reason = "does not work when called in a separate process"

if reason:

@unittest.skip(reason)
Expand Down
Loading
Loading