diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 26c9fc34..ddf6b7fe 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.8.10 +++++++ + +* :pr:`384`: add ``weights_only=False`` when using :func:`torch.load` + 0.8.9 +++++ diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py index 9de38a26..7a2c80e9 100644 --- a/_doc/examples/plot_export_tiny_llm_method_generate.py +++ b/_doc/examples/plot_export_tiny_llm_method_generate.py @@ -48,10 +48,9 @@ def generate_text( generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return generated_text - # Define your prompt - -prompt = "Continue: it rains..." +# Define your prompt +prompt = "Continue: it rains, what should I do?" generated_text = generate_text(prompt, model, tokenizer) print("-----------------") print(generated_text) @@ -69,7 +68,7 @@ def generate_text( # If the default settings do not work, ``skip_kwargs_names`` and ``dynamic_shapes`` # can be changed to remove some undesired inputs or add more dynamic dimensions. -filename = "plot_export_tiny_llm_method_generate.onnx" +filename = "plot_export_tiny_llm_method_generate.custom.onnx" forward_replacement = method_to_onnx( model, method_name="forward", # default value @@ -87,8 +86,12 @@ def generate_text( # The input used in the example has a batch size equal to 1, all # inputs going through method forward will have the same batch size. # To force the dynamism of this dimension, we need to indicate - # which inputs has a batch size. + # which inputs have a batch size. dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"}, + # Earlier versions of pytorch did not accept a dynamic batch size equal to 1, + # this last parameter can be added to expand some inputs if the batch size is 1. + # The exporter should work without. + expand_batch_for={"input_ids", "attention_mask", "past_key_values"}, ) # %% @@ -139,6 +142,51 @@ def generate_text( df = pandas.DataFrame(data) print(df) +# %% +# Minimal script to export a LLM +# ++++++++++++++++++++++++++++++ +# +# The following lines are a condensed copy with less comments. + +# from HuggingFace +print("----------------") +MODEL_NAME = "arnir0/Tiny-LLM" +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) + +# to export into onnx +forward_replacement = method_to_onnx( + model, + method_name="forward", + exporter="onnx-dynamo", + filename="plot_export_tiny_llm_method_generate.dynamo.onnx", + patch_kwargs=dict(patch_transformers=True), + verbose=0, + convert_after_n_calls=3, + dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"}, +) +model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs) + +# from HuggingFace again +prompt = "Continue: it rains, what should I do?" +inputs = tokenizer(prompt, return_tensors="pt") +outputs = model.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_length=100, + temperature=1, + top_k=50, + top_p=0.95, + do_sample=True, +) +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print("prompt answer:", generated_text) + +# to check discrepancies +data = forward_replacement.check_discrepancies() +df = pandas.DataFrame(data) +print(df) + # %% doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/_doc/index.rst b/_doc/index.rst index 07963a4e..cefab167 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -240,8 +240,8 @@ The function replaces dynamic dimensions defined as strings by Older versions ============== +* `0.8.10 <../v0.8.10/index.html>`_ * `0.8.9 <../v0.8.9/index.html>`_ -* `0.8.8 <../v0.8.8/index.html>`_ * `0.7.16 <../v0.7.16/index.html>`_ * `0.6.3 <../v0.6.3/index.html>`_ * `0.5.0 <../v0.5.0/index.html>`_ diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index f35627ab..9daf0619 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.8.9" +__version__ = "0.8.10" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 3e89f13a..e19a30fd 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -445,10 +445,6 @@ def forward(self, *args, **kwargs): and not isinstance(v, (bool, int, float)) } ) - if self.expand_batch_for: - # extends the inputs to artificially create a batch dimension != 1. - inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for) - inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for) inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs)) # reorders the parameter following the method signature. inp_kwargs = self._reorder_kwargs(inp_kwargs) @@ -557,6 +553,10 @@ def __init__(self, parent): else: a, kw = self._inputs[-1] nds = [self.dynamic_shapes] + if self.expand_batch_for: + # extends the inputs to artificially create a batch dimension != 1. + a = self._expand_batch_dimension(a, self.expand_batch_for) + kw = self._expand_batch_dimension(kw, self.expand_batch_for) if self.verbose: print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}") print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}") @@ -738,7 +738,9 @@ def check_discrepancies( :param verbose: verbosity :return: results, a list of dictionaries, ready to be consumed by a dataframe """ - assert self._export_done, "The onnx export was not done." + assert ( + self._export_done + ), f"The onnx export was not done, only {len(self._inputs)} were stored." assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found" assert os.path.exists( self._output_file @@ -750,17 +752,29 @@ def check_discrepancies( classes = [ cls for cls in self._serialization_classes - if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device} + if cls + not in { + int, + float, + bool, + str, + torch.Tensor, + list, + set, + dict, + torch.device, + torch.dtype, + } ] if verbose: print(f"[method_to_onnx.check_discrepancies] register classes {classes}") print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}") with torch.serialization.safe_globals(classes): - inputs = torch.load(self._input_file) + inputs = torch.load(self._input_file, weights_only=False) if verbose: print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}") with torch.serialization.safe_globals(classes): - outputs = torch.load(self._output_file) + outputs = torch.load(self._output_file, weights_only=False) assert len(inputs) == len(outputs), ( f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, " f"inputs={string_type(inputs, with_shape=True)}, " diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 1b1bad0d..cbab26c8 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -305,7 +305,7 @@ def serialization_functions( def unregister_class_serialization(cls: type, verbose: int = 0): - """Undo the registration.""" + """Undo the registration for a class.""" # torch.utils._pytree._deregister_pytree_flatten_spec(cls) if cls in torch.fx._pytree.SUPPORTED_NODES: del torch.fx._pytree.SUPPORTED_NODES[cls] @@ -333,6 +333,10 @@ def unregister_class_serialization(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): + """ + Undo the registration made by + :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`. + """ cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo) for cls in cls_ensemble: if undo.get(cls.__name__, False): diff --git a/pyproject.toml b/pyproject.toml index 540f2244..c2ff498c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "onnx-diagnostic" -version = "0.8.9" +version = "0.8.10" description = "Tools to help converting pytorch models into ONNX." readme = "README.rst" authors = [