Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions exir/passes/scalar_to_tensor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
# pyre-strict

import torch
from executorch.exir.pass_base import ExportPass, map_args
from executorch.exir.pass_base import ExportPass, map_args, PassResult


class ScalarToTensorPass(ExportPass):
def __init__(self) -> None:
super().__init__()
self._modified: bool = False

# pyre-ignore
def call_operator(self, op, args, kwargs, meta):
# pyre-ignore
Expand All @@ -21,12 +25,20 @@ def try_coerce(value, arg):
# get a constant tensor with torch.tensor() call but instead
# a fake tensor is created.
with torch.utils._python_dispatch._disable_current_modes():
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and isinstance(arg.type, torch.TensorType)
else value
)
if isinstance(value, (float, int, bool)) and isinstance(
arg.type, torch.TensorType
):
self._modified = True
return torch.tensor(value)

return value

args, kwargs = map_args(op, try_coerce, args, kwargs)
return super().call_operator(op, args, kwargs, meta)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self._modified = False
graph_module = super().call(graph_module).graph_module
modified = self._modified
self._modified = False
return PassResult(graph_module, modified)
51 changes: 33 additions & 18 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@

from torch import nn
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

# Import passes
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import export
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
from torch.fx import GraphModule, subgraph_rewriter
Expand Down Expand Up @@ -539,28 +542,40 @@ class NullPass(ExportPass):
self.assertEqual(new_node.target, old_node.target)

def test_export_scalar_to_tensor_pass(self) -> None:
class Mul(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * 3.14
# Build a graph with a scalar argument where schema expects tensor
graph = torch.fx.Graph()
test_input = torch.randn(
1,
)
with FakeTensorMode() as fake_mode:
fake_input = fake_mode.from_tensor(test_input)
x = graph.placeholder("x")
x.meta["val"] = fake_input

mul = Mul()
# Pass 3.14 as scalar - this should be converted to tensor by the pass
mul_node = graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x, 3.14),
)
graph.output(mul_node)

expo_prog = to_edge(export(mul, (torch.ones(1),), strict=True))
new_prog = expo_prog.transform([ScalarToTensorPass()])
self.assertIsNotNone(new_prog.exported_program().graph_module)
new_graph_module = new_prog.exported_program().graph_module
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
original_output = gm(test_input)

inp = torch.zeros(1)
self.assertTrue(
torch.allclose(
expo_prog.exported_program().module()(inp),
new_prog.exported_program().module()(inp),
)
)
for node in new_graph_module.graph.nodes:
result = ScalarToTensorPass()(gm)
new_gm = result.graph_module
self.assertTrue(result.modified)
# All scalars should be tensors by this point, so running a second time should not modify
self.assertFalse(ScalarToTensorPass()(new_gm).modified)

# All scalars should be converted into nodes
for node in new_gm.graph.nodes:
if node.op == "call_function":
for arg in node.args + tuple(node.kwargs.values()):
self.assertFalse(isinstance(arg, float))
for arg in node.args:
self.assertTrue(isinstance(arg, torch.fx.Node))

new_output = new_gm(test_input)
self.assertTrue(torch.equal(original_output, new_output))

def test_remove_mixed_types_symfloats(self) -> None:
class Foo(torch.nn.Module):
Expand Down
Loading