Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Change Logs
===========

0.8.10
++++++

* :pr:`384`: add ``weights_only=False`` when using :func:`torch.load`

0.8.9
+++++

Expand Down
58 changes: 53 additions & 5 deletions _doc/examples/plot_export_tiny_llm_method_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def generate_text(
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text

# Define your prompt


prompt = "Continue: it rains..."
# Define your prompt
prompt = "Continue: it rains, what should I do?"
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
Expand All @@ -69,7 +68,7 @@ def generate_text(
# If the default settings do not work, ``skip_kwargs_names`` and ``dynamic_shapes``
# can be changed to remove some undesired inputs or add more dynamic dimensions.

filename = "plot_export_tiny_llm_method_generate.onnx"
filename = "plot_export_tiny_llm_method_generate.custom.onnx"
forward_replacement = method_to_onnx(
model,
method_name="forward", # default value
Expand All @@ -87,8 +86,12 @@ def generate_text(
# The input used in the example has a batch size equal to 1, all
# inputs going through method forward will have the same batch size.
# To force the dynamism of this dimension, we need to indicate
# which inputs has a batch size.
# which inputs have a batch size.
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
# Earlier versions of pytorch did not accept a dynamic batch size equal to 1,
# this last parameter can be added to expand some inputs if the batch size is 1.
# The exporter should work without.
expand_batch_for={"input_ids", "attention_mask", "past_key_values"},
)

# %%
Expand Down Expand Up @@ -139,6 +142,51 @@ def generate_text(
df = pandas.DataFrame(data)
print(df)

# %%
# Minimal script to export a LLM
# ++++++++++++++++++++++++++++++
#
# The following lines are a condensed copy with less comments.

# from HuggingFace
print("----------------")
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# to export into onnx
forward_replacement = method_to_onnx(
model,
method_name="forward",
exporter="onnx-dynamo",
filename="plot_export_tiny_llm_method_generate.dynamo.onnx",
patch_kwargs=dict(patch_transformers=True),
verbose=0,
convert_after_n_calls=3,
dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)

# from HuggingFace again
prompt = "Continue: it rains, what should I do?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=100,
temperature=1,
top_k=50,
top_p=0.95,
do_sample=True,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("prompt answer:", generated_text)

# to check discrepancies
data = forward_replacement.check_discrepancies()
df = pandas.DataFrame(data)
print(df)


# %%
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
2 changes: 1 addition & 1 deletion _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ The function replaces dynamic dimensions defined as strings by
Older versions
==============

* `0.8.10 <../v0.8.10/index.html>`_
* `0.8.9 <../v0.8.9/index.html>`_
* `0.8.8 <../v0.8.8/index.html>`_
* `0.7.16 <../v0.7.16/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
Functions, classes to dig into a model when this one is right, slow, wrong...
"""

__version__ = "0.8.9"
__version__ = "0.8.10"
__author__ = "Xavier Dupré"
30 changes: 22 additions & 8 deletions onnx_diagnostic/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,6 @@ def forward(self, *args, **kwargs):
and not isinstance(v, (bool, int, float))
}
)
if self.expand_batch_for:
# extends the inputs to artificially create a batch dimension != 1.
inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
# reorders the parameter following the method signature.
inp_kwargs = self._reorder_kwargs(inp_kwargs)
Expand Down Expand Up @@ -557,6 +553,10 @@ def __init__(self, parent):
else:
a, kw = self._inputs[-1]
nds = [self.dynamic_shapes]
if self.expand_batch_for:
# extends the inputs to artificially create a batch dimension != 1.
a = self._expand_batch_dimension(a, self.expand_batch_for)
kw = self._expand_batch_dimension(kw, self.expand_batch_for)
if self.verbose:
print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
Expand Down Expand Up @@ -738,7 +738,9 @@ def check_discrepancies(
:param verbose: verbosity
:return: results, a list of dictionaries, ready to be consumed by a dataframe
"""
assert self._export_done, "The onnx export was not done."
assert (
self._export_done
), f"The onnx export was not done, only {len(self._inputs)} were stored."
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
assert os.path.exists(
self._output_file
Expand All @@ -750,17 +752,29 @@ def check_discrepancies(
classes = [
cls
for cls in self._serialization_classes
if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
if cls
not in {
int,
float,
bool,
str,
torch.Tensor,
list,
set,
dict,
torch.device,
torch.dtype,
}
]
if verbose:
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
with torch.serialization.safe_globals(classes):
inputs = torch.load(self._input_file)
inputs = torch.load(self._input_file, weights_only=False)
if verbose:
print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
with torch.serialization.safe_globals(classes):
outputs = torch.load(self._output_file)
outputs = torch.load(self._output_file, weights_only=False)
assert len(inputs) == len(outputs), (
f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
f"inputs={string_type(inputs, with_shape=True)}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def serialization_functions(


def unregister_class_serialization(cls: type, verbose: int = 0):
"""Undo the registration."""
"""Undo the registration for a class."""
# torch.utils._pytree._deregister_pytree_flatten_spec(cls)
if cls in torch.fx._pytree.SUPPORTED_NODES:
del torch.fx._pytree.SUPPORTED_NODES[cls]
Expand Down Expand Up @@ -333,6 +333,10 @@ def unregister_class_serialization(cls: type, verbose: int = 0):


def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
"""
Undo the registration made by
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`.
"""
cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
for cls in cls_ensemble:
if undo.get(cls.__name__, False):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "onnx-diagnostic"
version = "0.8.9"
version = "0.8.10"
description = "Tools to help converting pytorch models into ONNX."
readme = "README.rst"
authors = [
Expand Down
Loading