diff --git a/README.rst b/README.rst
index 87b14efc..1f0af03f 100644
--- a/README.rst
+++ b/README.rst
@@ -73,7 +73,7 @@ Enlightening Examples
 
 * `Export microsoft/phi-2
   `_
-* `Export a model through method generate (with Tiny-LLM)
+* `Export an LLM through method generate (with Tiny-LLM)
   `_
 
 **Torch Export**
diff --git a/_doc/examples/plot_dump_intermediate_results.py b/_doc/examples/plot_dump_intermediate_results.py
index bc3090ef..6ce8f8fb 100644
--- a/_doc/examples/plot_dump_intermediate_results.py
+++ b/_doc/examples/plot_dump_intermediate_results.py
@@ -26,7 +26,6 @@
 import onnx
 import torch
 import onnxruntime
-from onnx_array_api.plotting.graphviz_helper import plot_dot
 from onnx_diagnostic import doc
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.torch_helper import dummy_llm, steal_forward
@@ -203,7 +202,7 @@
 # ++++++++++++++++++++
 
 onx = onnx.load("plot_dump_intermediate_results.onnx")
-plot_dot(onx)
+doc.plot_dot(onx)
 
 # %%
 doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")
diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py
index 6568ffe4..a4e3d841 100644
--- a/_doc/examples/plot_export_tiny_llm_method_generate.py
+++ b/_doc/examples/plot_export_tiny_llm_method_generate.py
@@ -1,8 +1,8 @@
 """
 .. _l-plot-tiny-llm-export-method-generate:
 
-Export a model through method generate (with Tiny-LLM)
-======================================================
+Export an LLM through method generate (with Tiny-LLM)
+=====================================================
 
 The main issue when exporting a LLM is the example on HuggingFace
 is based on method generate but we only need to export the forward method.
@@ -31,10 +31,13 @@
 def generate_text(
     prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
 ):
-    inputs = tokenizer.encode(prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
     outputs = model.generate(
-        inputs,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
         max_length=max_length,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
@@ -97,6 +100,7 @@
             {0: "batch_size", 2: "past_sequence_length"},
         ],
         "input_ids": {0: "batch_size", 1: "sequence_length"},
+        "attention_mask": {0: "batch_size", 1: "sequence_length"},
     },
 )
 
@@ -129,5 +133,4 @@
 
 
 # %%
-
-doc.plot_legend("Tiny-LLM\nforward inputs\nthrough generate", "onnx export", "tomato")
+doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
diff --git a/_doc/technical/plot_layer_norm_discrepancies.py b/_doc/technical/plot_layer_norm_discrepancies.py
index de7a6e71..4a40576d 100644
--- a/_doc/technical/plot_layer_norm_discrepancies.py
+++ b/_doc/technical/plot_layer_norm_discrepancies.py
@@ -24,8 +24,7 @@
 import onnx.helper as oh
 import onnxruntime
 import torch
-from onnx_array_api.plotting.graphviz_helper import plot_dot
-from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
+from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title, plot_dot
 from onnx_diagnostic.ext_test_case import unit_test_going
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py
index 4a47f362..547ac589 100644
--- a/_unittests/ut_export/test_api.py
+++ b/_unittests/ut_export/test_api.py
@@ -15,7 +15,7 @@
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
-from onnx_diagnostic.export.api import to_onnx, method_to_onnx
+from onnx_diagnostic.export.api import to_onnx, method_to_onnx, WrapperToExportMethodToOnnx
 
 
 class TestValidate(ExtTestCase):
@@ -250,6 +250,62 @@ def forward(self, x, y=None):
         method_to_call.check_discrepancies(verbose=1)
         self.clean_dump()
 
+    def test_add_empty_cache_if_needed_dict(self):
+        inputs = [
+            dict(x=torch.rand((1, 3))),
+            dict(
+                x=torch.rand((1, 3)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                ),
+            ),
+            dict(
+                x=torch.rand((1, 3)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)) for i in range(2)]
+                ),
+            ),
+        ]
+        with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)
+        self.assertEqual(
+            (
+                "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
+                "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
+            ),
+            self.string_type(with_empty[0], with_shape=True),
+        )
+
+    def test_add_empty_cache_if_needed_args_kwargs(self):
+        inputs = [
+            (tuple(), dict(x=torch.rand((1, 3)))),
+            (
+                tuple(),
+                dict(
+                    x=torch.rand((1, 3)),
+                    cache=make_dynamic_cache(
+                        [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                    ),
+                ),
+            ),
+            (
+                tuple(),
+                dict(
+                    x=torch.rand((1, 3)),
+                    cache=make_dynamic_cache(
+                        [(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)) for i in range(2)]
+                    ),
+                ),
+            ),
+        ]
+        with_empty = WrapperToExportMethodToOnnx.add_empty_cache_if_needed(inputs)
+        self.assertEqual(
+            (
+                "dict(x:T1s1x3,cache:DynamicCache(key_cache=#2[T1s1x2x0x4,T1s1x2x0x4], "
+                "value_cache=#2[T1s1x2x0x4,T1s1x2x0x4]))"
+            ),
+            self.string_type(with_empty[0][1], with_shape=True),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py
index f59d46c0..9381da0c 100644
--- a/_unittests/ut_export/test_dynamic_shapes.py
+++ b/_unittests/ut_export/test_dynamic_shapes.py
@@ -962,6 +962,39 @@ def forward(self, x, y=None):
         _a, _kw, ds = mi.move_to_kwargs(*mi.inputs[-1], ds)
         self.assertEqual(ds, (tuple(), {"x": {0: DYN}, "y": {0: DYN}}))
 
+    def test_guess_dynamic_shapes_with_none(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache=None):
+                return cache
+
+        cache = make_dynamic_cache(
+            [(torch.randn(2, 3, 5, 6), torch.randn((2, 3, 5, 6))) for i in range(2)]
+        )
+        cache2 = make_dynamic_cache(
+            [(torch.randn(2, 3, 1, 6), torch.randn((2, 3, 6, 6))) for i in range(2)]
+        )
+
+        inputs = [dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(Model(), inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
+        inputs = [{}, dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(Model(), inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
+        inputs = [{}, dict(cache=cache), dict(cache=cache2)]
+        mi = ModelInputs(None, inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        expected = ((), {"cache": [{2: DYN}, {2: DYN}, {2: DYN}, {2: DYN}]})
+        self.assertEqual(expected, ds)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index beef33cc..556e6297 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -125,7 +125,6 @@ def add_test_methods(cls):
                 "plot_export_locate_issue.py",
                 "plot_export_with_auto.py",
                 "plot_export_tiny_llm.py",
-                "plot_export_tiny_llm_method_generate.py",
             }
             and not has_torch("2.8")
         ):
@@ -155,6 +154,9 @@ def add_test_methods(cls):
         if not reason and torch.__version__.startswith("2.9.0"):
             reason = "examples are failing for on CI for 2.9.0"
 
+        if not reason and name in {"plot_export_tiny_llm_method_generate.py"}:
+            reason = "does not work when called in a separate process"
+
         if reason:
 
             @unittest.skip(reason)
diff --git a/onnx_diagnostic/doc.py b/onnx_diagnostic/doc.py
index 3391543c..cf0e22ad 100644
--- a/onnx_diagnostic/doc.py
+++ b/onnx_diagnostic/doc.py
@@ -1,5 +1,11 @@
-from typing import Optional
+import os
+import subprocess
+import sys
+import tempfile
+from typing import List, Optional, Tuple, Union
 import numpy as np
+import onnx
+from .helpers.dot_helper import to_dot
 
 
 def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
@@ -36,6 +42,15 @@ def reset_torch_transformers(gallery_conf, fname):
 def plot_legend(
     text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
 ) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Plots a graph with only text (for :epkg:`sphinx-gallery`).
+
+    :param text: legend
+    :param text_bottom: text at the bottom
+    :param color: color
+    :param fontsize: font size
+    :return: axis
+    """
     import matplotlib.pyplot as plt
 
     fig = plt.figure(figsize=(2, 2))
@@ -66,17 +81,14 @@ def rotate_align(ax, angle=15, align="right"):
     return ax
 
 
-def save_fig(ax, name: str):
+def save_fig(ax, name: str, **kwargs) -> "matplotlib.axis.Axis":  # noqa: F821
     """Applies ``tight_layout`` and saves the figures. Returns ax."""
-    import matplotlib.pyplot as plt
-
-    plt.tight_layout()
     fig = ax.get_figure()
-    fig.savefig(name)
+    fig.savefig(name, **kwargs)
     return ax
 
 
-def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+def title(ax: "plt.axes", title: str) -> "matplotlib.axis.Axis":  # noqa: F821
     "Adds a title to axes and returns them."
     ax.set_title(title)
     return ax
@@ -88,7 +100,7 @@ def plot_histogram(
     bins: int = 30,
     color: str = "orange",
     alpha: float = 0.7,
-) -> "plt.axes":  # noqa: F821
+) -> "matplotlib.axis.Axis":  # noqa: F821
     "Computes the distribution for a tensor."
     if ax is None:
         import matplotlib.pyplot as plt
@@ -98,3 +110,240 @@
     ax.hist(tensor, bins=30, color="orange", alpha=0.7)
     ax.set_yscale("log")
     return ax
+
+
+def _find_in_PATH(prog: str) -> Optional[str]:
+    """
+    Looks for a specific file in every folder listed in ``%PATH%``.
+
+    :param prog: program to look for
+    :return: the folder containing the program or None if not found
+    """
+    sep = ";" if sys.platform.startswith("win") else ":"
+    path = os.environ["PATH"]
+    for p in path.split(sep):
+        f = os.path.join(p, prog)
+        if os.path.exists(f):
+            return p
+    return None
+
+
+def _find_graphviz_dot(exc: bool = True) -> Optional[str]:
+    """
+    Determines the path to graphviz (on Windows),
+    the function tests the existence of versions 34 to 59
+    assuming it was installed in a standard folder:
+    ``C:\\Program Files (x86)\\Graphviz2.{v}\\bin``.
+
+    :param exc: raise an exception or be silent
+    :return: path to dot, or None if not found and *exc* is False
+    :raises FileNotFoundError: if graphviz is not found and *exc* is True
+    """
+    if sys.platform.startswith("win"):
+        version = list(range(34, 60))
+        version.extend([f"{v}.1" for v in version])
+        for v in version:
+            graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
+            if os.path.exists(graphviz_dot):
+                return graphviz_dot
+        extra = ["build/update_modules/Graphviz/bin"]
+        for ext in extra:
+            graphviz_dot = os.path.join(ext, "dot.exe")
+            if os.path.exists(graphviz_dot):
+                return graphviz_dot
+        p = _find_in_PATH("dot.exe")
+        if p is None:
+            if exc:
+                raise FileNotFoundError(
+                    f"Unable to find graphviz, look into paths such as {graphviz_dot}."
+                )
+            return None
+        return os.path.join(p, "dot.exe")
+    # linux
+    return "dot"
+
+
+def _run_subprocess(args: List[str], cwd: Optional[str] = None):
+    assert not isinstance(args, str), "args should be a sequence of strings, not a string."
+
+    p = subprocess.Popen(
+        args,
+        cwd=cwd,
+        shell=False,
+        env=os.environ,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    raise_exception = False
+    output = ""
+    while True:
+        output = p.stdout.readline().decode(errors="ignore")
+        if output == "" and p.poll() is not None:
+            break
+        if output:
+            if (
+                "fatal error" in output
+                or "CMake Error" in output
+                or "gmake: ***" in output
+                or "): error C" in output
+                or ": error: " in output
+            ):
+                raise_exception = True
+    p.poll()
+    error = p.stderr.readline().decode(errors="ignore")
+    p.stdout.close()
+    if error and raise_exception:
+        raise RuntimeError(
+            "An error was found in the output. The command failed."
+ f"\n{output}\n---\n{error}" + ) + return output + "\n" + error + + +def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: + """ + Run :epkg:`Graphviz`. + + :param filename: filename which contains the graph definition + :param image: output image + :param engine: *dot* or *neato* + :return: output of graphviz + """ + ext = os.path.splitext(image)[-1] + assert ext in { + ".png", + ".bmp", + ".fig", + ".gif", + ".ico", + ".jpg", + ".jpeg", + ".pdf", + ".ps", + ".svg", + ".vrml", + ".tif", + ".tiff", + ".wbmp", + }, f"Unexpected extension {ext!r} for {image!r}." + if sys.platform.startswith("win"): + bin_ = os.path.dirname(_find_graphviz_dot()) + # if bin not in os.environ["PATH"]: + # os.environ["PATH"] = os.environ["PATH"] + ";" + bin + exe = os.path.join(bin_, engine) + else: + exe = engine + if os.path.exists(image): + os.remove(image) + cmd = [exe, f"-T{ext[1:]}", filename, "-o", image] + output = _run_subprocess(cmd) + assert os.path.exists(image), ( + f"Unable to find {image!r}, command line is " + f"{' '.join(cmd)!r}, Graphviz failed due to\n{output}" + ) + return output + + +def draw_graph_graphviz( + dot: Union[str, onnx.ModelProto], image: str, engine: str = "dot" +) -> str: + """ + Draws a graph using :epkg:`Graphviz`. + + :param dot: dot graph or ModelProto + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :return: :epkg:`Graphviz` output or + the dot text if *image* is None + + The function creates a temporary file to store the dot file if *image* is not None. + """ + if isinstance(dot, onnx.ModelProto): + sdot = to_dot(dot) + else: + if "{" not in dot: + assert dot.endswith(".onnx"), f"Unexpected file extension for {dot!r}" + proto = onnx.load(dot) + sdot = to_dot(proto) + else: + sdot = dot + assert "{" in sdot, f"This string is not a dot string\n{sdot}" + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(sdot.encode("utf-8")) + fp.close() + + filename = fp.name + assert os.path.exists( + filename + ), f"File {filename!r} cannot be created to store the graph." + out = _run_graphviz(filename, image, engine=engine) + assert os.path.exists( + image + ), f"Graphviz failed with no reason, {image!r} not found, output is {out}." + os.remove(filename) + return out + + +def plot_dot( + dot: Union[str, onnx.ModelProto], + ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 + engine: str = "dot", + figsize: Optional[Tuple[int, int]] = None, +) -> "matplotlib.axis.Axis": # noqa: F821 + """ + Draws a dot graph into a matplotlib graph. + + :param dot: dot graph or ModelProto + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :param figsize: figsize of ax is None + :return: :epkg:`Graphviz` output or, the dot text if *image* is None + + .. 
+    .. plot::
+
+        import matplotlib.pyplot as plt
+        import onnx.parser
+        from onnx_diagnostic.doc import plot_dot
+
+        model = onnx.parser.parse_model(
+            '''
+            <ir_version: 8, opset_import: [ "": 18 ]>
+            agraph (float[N] x) => (float[N] z) {
+                two = Constant <value_float = 2.0> ()
+                four = Add(two, two)
+                z = Mul(four, four)
+            }
+            ''')
+
+        ax = plot_dot(model)
+        ax.set_title("Dummy graph")
+        plt.show()
+    """
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        _, ax = plt.subplots(1, 1, figsize=figsize)
+        clean = True
+    else:
+        clean = False
+
+    from PIL import Image
+
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
+        fp.close()
+
+    draw_graph_graphviz(dot, fp.name, engine=engine)
+    img = np.asarray(Image.open(fp.name))
+    os.remove(fp.name)
+
+    ax.imshow(img)
+
+    if clean:
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        ax.set_axis_off()
+    ax.get_figure().tight_layout()
+    return ax
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 5d506bb8..73fdc895 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -8,6 +8,7 @@
 from .dynamic_shapes import ModelInputs
 from .onnx_plug import EagerDirectReplacementWithOnnx
 from ..helpers import flatten_object, max_diff, string_diff, string_type
+from ..helpers.cache_helper import CacheKeyValue
 from ..helpers.torch_helper import torch_deepcopy
 from ..helpers.rt_helper import make_feeds
 from ..reference import OnnxruntimeEvaluator
@@ -541,6 +542,83 @@ def __init__(self, parent):
             **self._to_onnx_kwargs,
         )
 
+    @classmethod
+    def make_empty_cache_from_others(cls, examples: List["Cache"]) -> "Cache":  # noqa: F821
+        """Builds an empty cache based on existing ones."""
+        unique_types = {type(t) for t in examples}
+        assert (
+            len(unique_types) == 1
+        ), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
+        unique_type = unique_types.pop()
+        if unique_type == torch.Tensor:
+            shapes = [t.shape for t in examples]
+            assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
+            ranks = {len(s) for s in shapes}
+            assert len(ranks) == 1, f"Ranks are different in {shapes}"
+            rank = ranks.pop()
+            new_shape = []
+            for i in range(rank):
+                dims = [t.shape[i] for t in examples]
+                if len(set(dims)) == 1:
+                    new_shape.append(dims[0])
+                else:
+                    # this dimension varies across examples: it becomes the empty one
+                    new_shape.append(0)
+            example = examples[0]
+            return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
+        assert (
+            unique_type.__name__ == "DynamicCache"
+        ), f"This is not implemented for class {unique_type}"
+        caches = [CacheKeyValue(dc) for dc in examples]
+        caches_list = [dc.aslist() for dc in caches]
+        empty = [
+            cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
+            for k in range(len(caches_list[0]))
+        ]
+        empty_cache = CacheKeyValue(
+            empty, cls_layers=caches[0].cls_layers
+        ).make_dynamic_cache()
+        return empty_cache
+
+    @classmethod
+    def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
+        """
+        Adds an empty cache when needed: onnxruntime requires an empty cache,
+        not a missing one. It only works if the inputs are defined as dictionaries.
+ """ + if all(isinstance(t, tuple) for t in inputs) and all( + len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0] + for t in inputs + ): + dict_part = [t[1] for t in inputs] + res = cls.add_empty_cache_if_needed(dict_part) + return [(tuple(), d) for d in res] + if any(not isinstance(t, dict) for t in inputs): + return inputs + all_keys = set() + for input_set in inputs: + all_keys |= set(input_set) + # even though the inputs are defined as a dictionary, it is better + # to keep the same order + ordered = None + for input_set in inputs: + if set(input_set) == all_keys: + ordered = list(input_set) + break + new_inputs = [] + for input_set in inputs: + if set(input_set) == all_keys: + new_inputs.append(input_set) + continue + missing = {k for k in all_keys if k not in input_set} + input_set_copy = input_set.copy() + for miss in missing: + input_set_copy[miss] = cls.make_empty_cache_from_others( + [sub[miss] for sub in inputs if miss in sub] + ) + new_inputs.append({k: input_set_copy[k] for k in ordered}) + return new_inputs + def check_discrepancies( self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0 ) -> List[Dict[str, Union[str, int, float]]]: @@ -555,6 +633,7 @@ def check_discrepancies( :param verbose: verbosity :return: results, a list of dictionaries, ready to be consumed by a dataframe """ + assert self._export_done, "The onnx export was not done." assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found" assert os.path.exists( @@ -590,7 +669,9 @@ def check_discrepancies( if verbose: print(f"[method_to_onnx.check_discrepancies] input_names={input_names}") data = [] - for i, (input, (output, latency)) in enumerate(zip(inputs, outputs)): + for i, (input, (output, latency)) in enumerate( + zip(self.add_empty_cache_if_needed(inputs), outputs) + ): if verbose: if verbose > 1: print( @@ -602,7 +683,10 @@ def check_discrepancies( f"{string_type(output, with_shape=True)}" ) else: - print(f"[method_to_onnx.check_discrepancies] process input {i}") + print( + f"[method_to_onnx.check_discrepancies] process input {i} " + f"#args={len(input[0])} #kwargs={len(input[1])}" + ) flat_inputs = flatten_object(input, drop_keys=True) if len(flat_inputs) < len(input_names): diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 767c4f1f..0e705661 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -19,7 +19,7 @@ class CacheKeyValue: capi.value_cache """ - def __init__(self, cache=None): + def __init__(self, cache=None, cls_layers=None): if hasattr(cache, "layers"): layers = [ layer @@ -28,15 +28,26 @@ def __init__(self, cache=None): ] self.key_cache = [layer.keys for layer in layers] self.value_cache = [layer.values for layer in layers] + assert ( + cls_layers is None + ), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}" self.cls_layers = [type(lay) for lay in cache.layers] elif cache is not None and hasattr(cache, "key_cache"): self.key_cache = cache.key_cache self.value_cache = cache.value_cache - self.cls_layers = None + self.cls_layers = cls_layers + elif ( + cache is not None + and isinstance(cache, list) + and all(isinstance(t, torch.Tensor) for t in cache) + ): + self.key_cache = cache[::2] + self.value_cache = cache[1::2] + self.cls_layers = cls_layers elif cache is None: self.key_cache = None self.value_cache = None - self.cls_layers = None + self.cls_layers = cls_layers else: raise 
@@ -51,6 +62,18 @@
     def n_layers(self) -> int:
         """Returns the number of layers."""
         return len(self.key_cache) if self.key_cache else 0
+
+    def __len__(self) -> int:
+        "Returns the number of tensors."
+        return len(self.key_cache) + len(self.value_cache)
+
+    def aslist(self) -> List[torch.Tensor]:
+        "Returns the tensors in a list, interleaving keys and values."
+        res = []
+        for i in range(self.n_layers):
+            res.append(self.key_cache[i])
+            res.append(self.value_cache[i])
+        return res
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
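
Below is a minimal usage sketch (not part of the patch) for the new list-based
``CacheKeyValue`` constructor and the ``aslist``/``__len__`` helpers added above.
It assumes only the APIs visible in this diff; ``make_dynamic_cache`` comes from
the same ``cache_helper`` module, exactly as the tests import it:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

# A DynamicCache with 2 layers, tensors of shape (batch=1, heads=2, seq=3, dim=4).
cache = make_dynamic_cache(
    [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for _ in range(2)]
)
ckv = CacheKeyValue(cache)

# aslist interleaves keys and values: [key_0, value_0, key_1, value_1].
flat = ckv.aslist()
assert len(flat) == len(ckv) == 4

# The new list branch of __init__ splits even/odd positions back into
# key_cache and value_cache, so a cache round-trips through a flat list.
rebuilt = CacheKeyValue(flat, cls_layers=ckv.cls_layers)
assert len(rebuilt.key_cache) == len(rebuilt.value_cache) == 2
```

This interleaved layout is the one ``make_empty_cache_from_others`` relies on when
it rebuilds an empty ``DynamicCache`` tensor by tensor.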