diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index d3671fb5..7c166329 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.9
 +++++
 
+* :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
 * :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
 * :pr:`379`: update the handling of cache after the removal of
   HybridCache, SlidingWindowCache in ``transformers>=5``,
diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py
index 547ac589..b1395cf3 100644
--- a/_unittests/ut_export/test_api.py
+++ b/_unittests/ut_export/test_api.py
@@ -306,6 +306,84 @@ def test_add_empty_cache_if_needed_args_kwargs(self):
             self.string_type(with_empty[0][1], with_shape=True),
         )
 
+    @requires_experimental_experiment("0.1")
+    def test_method_to_onnx_expand_batch_dimension(self):
+        class Model(torch.nn.Module):
+            def forward(self, x=None, y=None):
+                return x + y
+
+        filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
+        inputs = [
+            dict(x=torch.randn((1, 5, 6)), y=torch.randn((1, 1, 6))),
+            dict(x=torch.randn((1, 7, 7)), y=torch.randn((1, 1, 7))),
+        ]
+        model = Model()
+        method_to_call = method_to_onnx(
+            model, exporter="custom", filename=filename, expand_batch_for={"x", "y"}
+        )
+        expecteds = []
+        for kwargs in inputs:
+            expecteds.append(method_to_call(**kwargs))
+        self.assertExists(filename)
+        sess = self.check_ort(filename)
+        input_names = [i.name for i in sess.get_inputs()]
+        self.assertEqual(["x", "y"], input_names)
+        input_shapes = [i.shape for i in sess.get_inputs()]
+        self.assertEqual([[2, "channel", "D0"], [2, 1, "D0_1"]], input_shapes)
+        self.clean_dump()
+
+    @requires_experimental_experiment("0.1")
+    @requires_transformers("4.57")
+    def test_method_to_onnx_expand_batch_dimension_dynamic_cache(self):
+        class Model(torch.nn.Module):
+            def forward(self, x=None, cache=None):
+                return x + cache.layers[0].keys
+
+        filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
+        inputs = [
+            dict(
+                x=torch.randn((1, 1, 3, 4)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)) for i in range(2)]
+                ),
+            ),
+            dict(
+                x=torch.randn((1, 1, 5, 4)),
+                cache=make_dynamic_cache(
+                    [(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4)) for i in range(2)]
+                ),
+            ),
+        ]
+        model = Model()
+        method_to_call = method_to_onnx(
+            model,
+            exporter="custom",
+            filename=filename,
+            expand_batch_for={"x", "cache"},
+            patch_kwargs=dict(patch_transformers=True),
+        )
+        expecteds = []
+        for kwargs in inputs:
+            expecteds.append(method_to_call(**kwargs))
+        self.assertExists(filename)
+        sess = self.check_ort(filename)
+        input_names = [i.name for i in sess.get_inputs()]
+        self.assertEqual(
+            ["x", "cache_key_0", "cache_value_0", "cache_key_1", "cache_value_1"], input_names
+        )
+        input_shapes = [i.shape for i in sess.get_inputs()]
+        self.assertEqual(
+            [
+                [2, 1, "D0", 4],
+                [2, 2, "D0_1", 4],
+                [2, 2, "D0_2", 4],
+                [2, 2, "D0_3", 4],
+                [2, 2, "D0_4", 4],
+            ],
+            input_shapes,
+        )
+        self.clean_dump()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/doc.py b/onnx_diagnostic/doc.py
index cf0e22ad..4a19ca9a 100644
--- a/onnx_diagnostic/doc.py
+++ b/onnx_diagnostic/doc.py
@@ -178,7 +178,7 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
     raise_exception = False
     output = ""
     while True:
-        output = p.stdout.readline().decode(errors="ignore")
+        output = p.stdout.readline().decode(errors="ignore")  # type: ignore[union-attr]
         if output == "" and p.poll() is not None:
             break
         if output:
@@ -191,8 +191,9 @@ def _run_subprocess(args: List[str], cwd: Optional[str] = None):
         ):
             raise_exception = True
     p.poll()
-    error = p.stderr.readline().decode(errors="ignore")
-    p.stdout.close()
+    error = p.stderr.readline().decode(errors="ignore")  # type: ignore[union-attr]
+    p.stdout.close()  # type: ignore[union-attr]
+    p.stderr.close()  # type: ignore[union-attr]
     if error and raise_exception:
         raise RuntimeError(
             f"An error was found in the output. The build is stopped."
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 73fdc895..2c642dfe 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -349,6 +349,7 @@ def __init__(
         patch_kwargs: Optional[Dict[str, Any]] = None,
         skip_kwargs_names: Optional[Set[str]] = None,
         dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
     ):
         super().__init__()
         self._model_to_call = mod
@@ -366,6 +367,7 @@ def __init__(
         self.verbose = verbose
         self.skip_kwargs_names = skip_kwargs_names
         self.dynamic_shapes = dynamic_shapes
+        self.expand_batch_for = expand_batch_for
         self._to_onnx_kwargs = dict(
             input_names=input_names,
             target_opset=target_opset,
@@ -386,9 +388,11 @@ def __init__(
         self._serialization_classes: Set[type] = set()
 
     def __str__(self) -> str:
+        "usual"
         return self.__repr__()
 
     def __repr__(self) -> str:
+        "usual"
         return (
             f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
             f"{self._method_name})"
@@ -415,22 +419,17 @@ def _collect_classes(self, obj):
 
     def forward(self, *args, **kwargs):
         if not self._export_done:
-            self._inputs.append(
-                torch_deepcopy(
-                    (
-                        args,
-                        (
-                            kwargs
-                            if not kwargs or not self.skip_kwargs_names
-                            else {
-                                k: v
-                                for k, v in kwargs.items()
-                                if k not in self.skip_kwargs_names
-                            }
-                        ),
-                    )
-                )
+            inp_args = args
+            inp_kwargs = (
+                kwargs
+                if not kwargs or not self.skip_kwargs_names
+                else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
             )
+            if self.expand_batch_for:
+                inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
+                inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
+            inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
+            self._inputs.append((inp_args, inp_kwargs))
             if self.verbose:
                 print(
                     f"[method_to_onnx] input[{len(self._inputs)-1}]: "
@@ -543,7 +542,7 @@ def __init__(self, parent):
         )
 
     @classmethod
-    def make_empty_cache_from_others(cls, examples: List["Cache"]) -> "Cache":  # noqa: F821
+    def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
         """Builds an empty cache based on existing one."""
         unique_types = {type(t) for t in examples}
         assert (
@@ -616,9 +615,50 @@ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
                 input_set_copy[miss] = cls.make_empty_cache_from_others(
                     [sub[miss] for sub in inputs if miss in sub]
                 )
-            new_inputs.append({k: input_set_copy[k] for k in ordered})
+            new_inputs.append({k: input_set_copy[k] for k in ordered})  # type: ignore[union-attr]
         return new_inputs
 
+    @classmethod
+    def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
+        expand_for_args = {i for i in expand_for if isinstance(i, int)}
+        expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
+        if isinstance(obj, tuple):
+            return tuple(
+                o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
+                for i, o in enumerate(obj)
+            )
+        assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
+        return {
+            k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
+            for k, v in obj.items()
+        }
+
+    @classmethod
+    def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
+        if isinstance(obj, torch.Tensor):
+            assert obj.shape[0] == 1, (
+                f"Are you sure you want to expand input {msg!r}? "
+                f"Its batch size is not 1: shape={obj.shape}."
+            )
+            sizes = [2, *obj.shape[1:]]
+            return obj.expand(*sizes)
+        if isinstance(obj, list):
+            return [
+                cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
+            ]
+        if obj.__class__.__name__ == "DynamicCache":
+            dc = CacheKeyValue(obj)
+            flat = dc.aslist()
+            flat = cls._expand_batch_dimension_input(flat, msg)
+            return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
+        # This might end up in infinite recursion if the class is not registered.
+        flat, _spec = torch.utils._pytree.tree_flatten(obj)
+        assert (
+            not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
+        ), f"class {type(obj)} is not registered for serialization."
+        flat = cls._expand_batch_dimension_input(flat, msg)
+        return torch.utils._pytree.tree_unflatten(flat, _spec)
+
     def check_discrepancies(
         self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
     ) -> List[Dict[str, Union[str, int, float]]]:
@@ -633,7 +673,6 @@ def check_discrepancies(
         :param verbose: verbosity
         :return: results, a list of dictionaries, ready to be consumed by a dataframe
         """
-
         assert self._export_done, "The onnx export was not done."
         assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
         assert os.path.exists(
@@ -753,6 +792,7 @@ def method_to_onnx(
     patch_kwargs: Optional[Dict[str, Any]] = None,
     skip_kwargs_names: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
 ) -> Callable:
     """
     Exports one method into ONNX for a module into ONNX.
@@ -782,6 +822,10 @@ def method_to_onnx(
     :param skip_kwargs_names: use default values for these parameters part of
         the signature of the method to export
     :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+    :param expand_batch_for: LLMs are usually called with a batch size equal to 1,
+        but the export may benefit from having another value for the batch size;
+        this parameter forces the inputs listed in this set to have their batch
+        size expanded from 1 to 2
     :return: the output of the selected exporter, usually a structure
         including an onnx model
 
@@ -808,5 +852,6 @@ def method_to_onnx(
         patch_kwargs=patch_kwargs,
         skip_kwargs_names=skip_kwargs_names,
        dynamic_shapes=dynamic_shapes,
+        expand_batch_for=expand_batch_for,
     )
     return wrapped_model
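
For context, here is a minimal usage sketch of the new *expand_batch_for* parameter, distilled from the unit tests above. The module, the file name ``tiny_model.onnx``, and the input shapes are illustrative only; the ``custom`` exporter assumes ``experimental-experiment`` is installed, as the ``@requires_experimental_experiment`` decorator in the tests suggests.

```python
import torch
from onnx_diagnostic.export.api import method_to_onnx


class TinyModel(torch.nn.Module):
    def forward(self, x=None, y=None):
        return x + y


# Wrap the module: every call is recorded as an example input and the ONNX
# export is then run from the recorded examples. expand_batch_for={"x", "y"}
# asks the wrapper to expand the first (batch) dimension of both inputs
# from 1 to 2 before export, so the exporter does not specialize on batch=1.
call = method_to_onnx(
    TinyModel(),
    exporter="custom",  # assumes experimental-experiment is available
    filename="tiny_model.onnx",  # illustrative file name
    expand_batch_for={"x", "y"},
)

# The caller keeps using a batch size of 1, as an LLM serving a single
# prompt would; only the inputs recorded for the export are expanded.
out1 = call(x=torch.randn((1, 5, 6)), y=torch.randn((1, 1, 6)))
out2 = call(x=torch.randn((1, 7, 7)), y=torch.randn((1, 1, 7)))

# The expansion rule itself, as _expand_batch_dimension_input applies it to
# a plain tensor: dimension 0 must be 1 and is expanded to 2 (a view, no copy).
t = torch.randn((1, 5, 6))
expanded = t.expand(2, *t.shape[1:])
assert expanded.shape == (2, 5, 6)
```

Entries of ``expand_batch_for`` may be strings (keyword argument names) or integers (positional indices), and nested structures such as ``DynamicCache`` are expanded tensor by tensor, as ``_expand_batch_dimension_input`` shows.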