🐛 Describe the bug
I'm seeing that PyTorch's PReLU is not supported for the Ethos-U backend. The forward call is decomposed into `torch.where(x > 0, x, self.weight * x)`, and the resulting `aten.where.self` op isn't supported by the partitioner.
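For context, a minimal sketch showing that this `torch.where` form is numerically identical to `nn.PReLU` (a standalone check, not part of the repro below):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
prelu = nn.PReLU(4)  # one learnable slope per channel
x = torch.randn(2, 4, 8, 8)

# The decomposition the Ethos-U partitioner rejects:
w = prelu.weight.view(1, -1, 1, 1)
decomposed = torch.where(x > 0, x, w * x)

assert torch.allclose(prelu(x), decomposed)
```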
Reproduce:
```python
import logging

import torch
import torch.nn as nn
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import EthosUQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import \
    get_symmetric_quantization_config as get_arm_symmetric_qconfig
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e


class PreluModel(torch.nn.Module):
    def __init__(self, in_c, out_c, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.prelu = nn.PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


log_file = './log/partitioner.log'
logging.basicConfig(
    handlers=[logging.FileHandler(log_file, 'w'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)

device = 'cpu'
example_input = (torch.randn(1, 3, 224, 224, device=device),)
float_model = PreluModel(in_c=3, out_c=64).eval().to(device)

# Quantize & lower
exported_program = torch.export.export(float_model, example_input, strict=True)
graph_module = exported_program.module(check_guards=False)

compile_spec = EthosUCompileSpec(target="ethos-u85-2048")
quantizer = EthosUQuantizer(compile_spec)
operator_config = get_arm_symmetric_qconfig(is_per_channel=True)
quantizer.set_global(operator_config)

prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_input)
quantized_graph_module = convert_pt2e(prepared, fold_quantize=True)
quantized_exported_program = torch.export.export(quantized_graph_module, example_input)

partitioner = EthosUPartitioner(compile_spec)
edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
    ),
)
executorch_program_manager = edge_program_manager.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)
```
The partitioner.log output:

```
2026-01-27 13:04:35,119 - INFO - TOSAPartitioner::partition
2026-01-27 13:04:35,119 - INFO - Partitioning for EthosUBackend: TOSA-1.0+INT+int16
2026-01-27 13:04:35,119 - INFO - The following nodes were rejected for TOSA-1.0+INT+int16:
2026-01-27 13:04:35,120 - INFO -
╒════════════════════════╤════════════════════════╤═════════════════════════════════════╤═════════════════════════════════════╕
│ Node name              │ Target                 │ Torch func                          │ Reason                              │
╞════════════════════════╪════════════════════════╪═════════════════════════════════════╪═════════════════════════════════════╡
│ aten_where_self        │ aten.where.self        │ ('prelu_1',                         │ Tensor x dtype torch.float32 and/or │
│                        │                        │ 'builtin_function_or_method.prelu') │ tensor y dtype torch.float32 is not │
│                        │                        │                                     │ supported in <EdgeOpOverload:       │
│                        │                        │                                     │ aten.where.self>: schema =          │
│                        │                        │                                     │ aten::where.self(Tensor condition,  │
│                        │                        │                                     │ Tensor self, Tensor other) ->       │
│                        │                        │                                     │ Tensor for tosa specification       │
│                        │                        │                                     │ TOSA-1.0+INT+int16                  │
├────────────────────────┼────────────────────────┼─────────────────────────────────────┼─────────────────────────────────────┤
│ aten_mul_tensor        │ aten.mul.Tensor        │ ('prelu_1',                         │ One or more inputs were not         │
│                        │                        │ 'builtin_function_or_method.prelu') │ quantized.                          │
├────────────────────────┼────────────────────────┼─────────────────────────────────────┼─────────────────────────────────────┤
│ aten_view_copy_default │ aten.view_copy.default │ ('prelu_1',                         │ Was first node in partition and     │
│                        │                        │ 'builtin_function_or_method.prelu') │ input p__param_constant2 had fp     │
│                        │                        │                                     │ dtype.                              │
├────────────────────────┼────────────────────────┼─────────────────────────────────────┼─────────────────────────────────────┤
│ aten_gt_scalar         │ aten.gt.Scalar         │ ('prelu_1',                         │ Was first node in partition and     │
│                        │                        │ 'builtin_function_or_method.prelu') │ input quantized_decomposed_dequanti │
│                        │                        │                                     │ ze_per_tensor_default_1 had fp      │
│                        │                        │                                     │ dtype.                              │
╘════════════════════════╧════════════════════════╧═════════════════════════════════════╧═════════════════════════════════════╛
2026-01-27 13:04:35,120 - INFO - (Placeholders and outputs are not included in this list)
2026-01-27 13:04:35,122 - INFO - EthosUBackend preprocess
2026-01-27 13:04:35,122 - INFO - Converting ExportedProgram to TOSA: TOSA-1.0+INT+int16
```
Fix:
Substitute the PReLU implementation with a clamp-based equivalent:

```python
class CustomPReLU(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        weight = self.weight.view(1, -1, 1, 1)
        positive = torch.clamp(x, min=0.0)
        negative = torch.clamp(x, max=0.0)
        return positive + weight * negative
```

and it is then supported on the Ethos-U backend.
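To apply the workaround without editing model code, a small helper can swap every `nn.PReLU` for the clamp-based module and verify numerical equivalence. This is a sketch, not part of ExecuTorch; `replace_prelu` is a hypothetical name, and the 4-D `view` assumes NCHW inputs:

```python
import torch
import torch.nn as nn


class CustomPReLU(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        # Assumes 4-D (N, C, H, W) input; broadcast the per-channel slope.
        weight = self.weight.view(1, -1, 1, 1)
        return torch.clamp(x, min=0.0) + weight * torch.clamp(x, max=0.0)


def replace_prelu(module: nn.Module) -> None:
    # Recursively swap nn.PReLU children for the clamp-based equivalent.
    for name, child in module.named_children():
        if isinstance(child, nn.PReLU):
            setattr(module, name, CustomPReLU(child.weight))
        else:
            replace_prelu(child)


# Sanity check on a small conv stack
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.PReLU(8))
x = torch.randn(1, 3, 16, 16)
ref = model(x)
replace_prelu(model)
assert torch.allclose(model(x), ref, atol=1e-6)
```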
Versions
```
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 4.2.0
Libc version: N/A

Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.0.1
[pip3] numpy==2.4.0
[pip3] onnx==1.20.0
[pip3] onnx-ir==0.1.14
[pip3] onnxscript==0.5.7
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.9.1
[pip3] torchao==0.14.0
[pip3] torchaudio==2.9.1
[pip3] torchsampler==0.1.2
[pip3] torchvision==0.24.0
[conda] Could not collect
```