Shared Mutable Buffer Issue #17005

@gtyukasz

Description

🐛 Describe the bug

Exporting the module used in the shared mutable buffer test and executing its entry points with the runtime suggests that the set_state entry point does not actually write to the shared state buffer: updates made through set_state are never visible to forward or get_state.
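
In eager mode all three entry points alias the same self.state tensor, so a write through set_state is immediately visible to the other methods (a quick sketch using the module defined in the repro below):

m = MultiEntryPointStatefulModel().eval()
m.set_state(torch.full((2, 2), 13.0))
m.get_state()             # tensor([[13., 13.], [13., 13.]])
m.forward(torch.ones(1))  # state becomes 14 everywhere, so this returns tensor([28., 28., 28., 28.])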

Repro based on the test code (with the sample input shape of the set_state argument in export changed to match the shape of state); it calls each method of both the exported module and the eager one and prints the outputs side by side:

import executorch.runtime
import torch

from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.capture._capture import patch_forward
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from torch.export import export
from typing import Tuple


class MultiEntryPointStatefulModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("state", torch.zeros(2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.state.add_(x).view(-1) * 2

    def set_state(self, state: torch.Tensor) -> torch.Tensor:
        self.state.copy_(state)
        return self.state

    def get_state(self) -> torch.Tensor:
        return self.state

    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.ones(1),)


module = MultiEntryPointStatefulModel().eval()
forward = export(module, module.get_example_inputs())
with patch_forward(module, module.get_state):
    get_state = export(module, ())
with patch_forward(module, module.set_state):
    set_state = export(module, (torch.zeros(2, 2),))
edge = to_edge(
    {"forward": forward, "set_state": set_state, "get_state": get_state}
)
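# share_mutable_buffers=True is expected to plan the mutable buffer into a
# memory region shared across entry points, so that forward, set_state, and
# get_state all alias the same state storage (assumption based on the flag
# name and the shared mutable buffer test this repro is derived from).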
et = edge.to_executorch(
    ExecutorchBackendConfig(
        # passes=[InitializedMutableBufferPass(list(module._buffers.keys()))],
        memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
        emit_mutable_buffer_names=True
    )
)

rt_prog = executorch.runtime.Runtime.get().load_program(
    et.buffer, verification=executorch.runtime.Verification.Minimal
)
rt_forward = rt_prog.load_method('forward')
rt_set_state = rt_prog.load_method('set_state')
rt_get_state = rt_prog.load_method('get_state')

print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())

fwd_in = torch.ones(1)
print(f'\nforward({fwd_in})')
print(rt_forward.execute([fwd_in]))
print('expected')
print(module.forward(fwd_in))

print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())

set_state_in = torch.empty(2, 2).fill_(13)
print(f'\nset_state({set_state_in})')
print(rt_set_state.execute([set_state_in]))
print('expected')
print(module.set_state(set_state_in))

print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())

print(f'\nforward({fwd_in})')
print(rt_forward.execute([fwd_in]))
print('expected')
print(module.forward(fwd_in))

print('\nget_state()')
print(rt_get_state.execute([]))
print('expected')
print(module.get_state())

Output:

get_state()
[tensor([[0., 0.],
        [0., 0.]])]
expected
tensor([[0., 0.],
        [0., 0.]])

forward(tensor([1.]))
[tensor([2., 2., 2., 2.])]
expected
tensor([2., 2., 2., 2.])

get_state()
[tensor([[1., 1.],
        [1., 1.]])]
expected
tensor([[1., 1.],
        [1., 1.]])

set_state(tensor([[13., 13.],
        [13., 13.]]))
[tensor([[13., 13.],
        [13., 13.]])]
expected
tensor([[13., 13.],
        [13., 13.]])

get_state()
[tensor([[1., 1.],
        [1., 1.]])]
expected
tensor([[13., 13.],
        [13., 13.]])

forward(tensor([1.]))
[tensor([4., 4., 4., 4.])]
expected
tensor([28., 28., 28., 28.])

get_state()
[tensor([[2., 2.],
        [2., 2.]])]
expected
tensor([[14., 14.],
        [14., 14.]])

The "expected" outputs of the eager module should match those of the exported, but set_state does not take effect in forward and get_state.

Versions

PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.12 (main, Oct 9 2025, 11:07:00) [Clang 17.0.0 (clang-1700.3.19.1)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] executorch==1.1.0
[pip3] numpy==2.4.1
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[conda] Could not collect
