Description
🐛 Describe the bug
Exporting the module used in the shared mutable buffer test and executing its entry points with the runtime suggests that the set_state entry point does not actually share the state buffer with forward and get_state.
The example below is based on the test code, with the sample input shape of the set_state argument in export changed to match that of state. It executes the methods of both the exported module and the eager one and prints their outputs:
import executorch.runtime
import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.capture._capture import patch_forward
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from torch.export import export
from typing import Tuple

class MultiEntryPointStatefulModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("state", torch.zeros(2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.state.add_(x).view(-1) * 2

    def set_state(self, state: torch.Tensor) -> torch.Tensor:
        self.state.copy_(state)
        return self.state

    def get_state(self) -> torch.Tensor:
        return self.state

    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.ones(1),)

module = MultiEntryPointStatefulModel().eval()

forward = export(module, module.get_example_inputs())
with patch_forward(module, module.get_state):
    get_state = export(module, ())
with patch_forward(module, module.set_state):
    set_state = export(module, (torch.zeros(2, 2),))

edge = to_edge(
    {"forward": forward, "set_state": set_state, "get_state": get_state}
)
et = edge.to_executorch(
    ExecutorchBackendConfig(
        # passes=[InitializedMutableBufferPass(list(module._buffers.keys()))],
        memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
        emit_mutable_buffer_names=True
    )
)
rt_prog = executorch.runtime.Runtime.get().load_program(
    et.buffer, verification=executorch.runtime.Verification.Minimal
)
rt_forward = rt_prog.load_method('forward')
rt_set_state = rt_prog.load_method('set_state')
rt_get_state = rt_prog.load_method('get_state')
print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())
fwd_in = torch.ones(1)
print(f'\nforward({fwd_in})')
print(rt_forward.execute([fwd_in]))
print('expected')
print(module.forward(fwd_in))
print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())
set_state_in = torch.empty(2, 2).fill_(13)
print(f'\nset_state({set_state_in})')
print(rt_set_state.execute([set_state_in]))
print('expected')
print(module.set_state(set_state_in))
print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())
print(f'\nforward({fwd_in})')
print(rt_forward.execute([fwd_in]))
print('expected')
print(module.forward(fwd_in))
print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())

Output:

get_state()
[tensor([[0., 0.],
        [0., 0.]])]
expected
tensor([[0., 0.],
        [0., 0.]])

forward(tensor([1.]))
[tensor([2., 2., 2., 2.])]
expected
tensor([2., 2., 2., 2.])

get_state()
[tensor([[1., 1.],
        [1., 1.]])]
expected
tensor([[1., 1.],
        [1., 1.]])

set_state(tensor([[13., 13.],
        [13., 13.]]))
[tensor([[13., 13.],
        [13., 13.]])]
expected
tensor([[13., 13.],
        [13., 13.]])

get_state()
[tensor([[1., 1.],
        [1., 1.]])]
expected
tensor([[13., 13.],
        [13., 13.]])

forward(tensor([1.]))
[tensor([4., 4., 4., 4.])]
expected
tensor([28., 28., 28., 28.])

get_state()
[tensor([[2., 2.],
        [2., 2.]])]
expected
tensor([[14., 14.],
        [14., 14.]])

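The outputs of the exported module should match the "expected" outputs of the eager module, but the value written by set_state does not take effect in forward and get_state. After set_state with all 13s, the runtime's get_state still returns all 1s, and the next forward returns all 4s instead of all 28s.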
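
One way to read these numbers, sketched here as a hypothesis rather than a diagnosis: the runtime output is exactly what you would get if forward and get_state were bound to one buffer while set_state wrote to a separate one. The pure-Python model below is hypothetical (the `Buffer`, `Methods`, and `run` names are invented for illustration and do not correspond to any ExecuTorch API), and it uses a scalar in place of the 2x2 state tensor, since every element is equal in the repro.

```python
# Hypothetical model only: it says nothing about how ExecuTorch actually
# plans memory. A scalar stands in for the 2x2 `state` tensor.


class Buffer:
    """A mutable cell standing in for a planned memory region."""

    def __init__(self, value: float = 0.0) -> None:
        self.value = value


class Methods:
    """Three entry points, each bound to a (possibly shared) buffer."""

    def __init__(self, forward_buf: Buffer, set_state_buf: Buffer,
                 get_state_buf: Buffer) -> None:
        self.forward_buf = forward_buf
        self.set_state_buf = set_state_buf
        self.get_state_buf = get_state_buf

    def forward(self, x: float) -> float:
        # Mirrors `self.state.add_(x).view(-1) * 2` for a single element.
        self.forward_buf.value += x
        return self.forward_buf.value * 2

    def set_state(self, state: float) -> float:
        # Mirrors `self.state.copy_(state); return self.state`.
        self.set_state_buf.value = state
        return self.set_state_buf.value

    def get_state(self) -> float:
        return self.get_state_buf.value


def run(m: Methods) -> list:
    """Replay the call sequence from the repro script."""
    return [
        m.get_state(),
        m.forward(1.0),
        m.get_state(),
        m.set_state(13.0),
        m.get_state(),
        m.forward(1.0),
        m.get_state(),
    ]


# All three methods share one buffer: matches the eager "expected" values.
shared = Buffer()
print(run(Methods(shared, shared, shared)))
# [0.0, 2.0, 1.0, 13.0, 13.0, 28.0, 14.0]

# set_state gets its own buffer: matches the values printed by the runtime.
a = Buffer()
print(run(Methods(a, Buffer(), a)))
# [0.0, 2.0, 1.0, 13.0, 1.0, 4.0, 2.0]
```

The second sequence reproduces the reported runtime values element for element, which is consistent with (but does not prove) set_state operating on its own copy of state.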
Versions
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.12 (main, Oct 9 2025, 11:07:00) [Clang 17.0.0 (clang-1700.3.19.1)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Apple M4 Pro
Versions of relevant libraries:
[pip3] executorch==1.1.0
[pip3] numpy==2.4.1
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[conda] Could not collect