diff --git a/exir/passes/scalar_to_tensor_pass.py b/exir/passes/scalar_to_tensor_pass.py index 6dd80cd577d..a31d628662c 100644 --- a/exir/passes/scalar_to_tensor_pass.py +++ b/exir/passes/scalar_to_tensor_pass.py @@ -7,10 +7,14 @@ # pyre-strict import torch -from executorch.exir.pass_base import ExportPass, map_args +from executorch.exir.pass_base import ExportPass, map_args, PassResult class ScalarToTensorPass(ExportPass): + def __init__(self) -> None: + super().__init__() + self._modified: bool = False + # pyre-ignore def call_operator(self, op, args, kwargs, meta): # pyre-ignore @@ -21,12 +25,20 @@ def try_coerce(value, arg): # get a constant tensor with torch.tensor() call but instead # a fake tensor is created. with torch.utils._python_dispatch._disable_current_modes(): - return ( - torch.tensor(value) - if isinstance(value, (float, int, bool)) - and isinstance(arg.type, torch.TensorType) - else value - ) + if isinstance(value, (float, int, bool)) and isinstance( + arg.type, torch.TensorType + ): + self._modified = True + return torch.tensor(value) + + return value args, kwargs = map_args(op, try_coerce, args, kwargs) return super().call_operator(op, args, kwargs, meta) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self._modified = False + graph_module = super().call(graph_module).graph_module + modified = self._modified + self._modified = False + return PassResult(graph_module, modified) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 6fecaeaa310..452f9694a8d 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -85,6 +85,9 @@ from torch import nn from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND + +# Import passes +from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter @@ -539,28 +542,40 @@ class NullPass(ExportPass): self.assertEqual(new_node.target, old_node.target) def test_export_scalar_to_tensor_pass(self) -> None: - class Mul(torch.nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * 3.14 + # Build a graph with a scalar argument where schema expects tensor + graph = torch.fx.Graph() + test_input = torch.randn( + 1, + ) + with FakeTensorMode() as fake_mode: + fake_input = fake_mode.from_tensor(test_input) + x = graph.placeholder("x") + x.meta["val"] = fake_input - mul = Mul() + # Pass 3.14 as scalar - this should be converted to tensor by the pass + mul_node = graph.call_function( + torch.ops.aten.mul.Tensor, + args=(x, 3.14), + ) + graph.output(mul_node) - expo_prog = to_edge(export(mul, (torch.ones(1),), strict=True)) - new_prog = expo_prog.transform([ScalarToTensorPass()]) - self.assertIsNotNone(new_prog.exported_program().graph_module) - new_graph_module = new_prog.exported_program().graph_module + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + original_output = gm(test_input) - inp = torch.zeros(1) - self.assertTrue( - torch.allclose( - expo_prog.exported_program().module()(inp), - new_prog.exported_program().module()(inp), - ) - ) - for node in new_graph_module.graph.nodes: + result = ScalarToTensorPass()(gm) + new_gm = result.graph_module + self.assertTrue(result.modified) + # All scalars should be tensors by this point, so running a second time should not modify + self.assertFalse(ScalarToTensorPass()(new_gm).modified) + + # All scalars should be converted into nodes + for node in new_gm.graph.nodes: if node.op == "call_function": - for arg in node.args + tuple(node.kwargs.values()): - self.assertFalse(isinstance(arg, float)) + for arg in node.args: + self.assertTrue(isinstance(arg, torch.fx.Node)) + + new_output = new_gm(test_input) + self.assertTrue(torch.equal(original_output, new_output)) def test_remove_mixed_types_symfloats(self) -> None: class Foo(torch.nn.Module):