XNNPACK backend: integer tensor operations not supported? #16896

@IgorSwat

🐛 Describe the bug

While exporting one of my PyTorch models to the ExecuTorch XNNPACK format, I ran into an issue with operations that take at least one integer tensor as an argument.
It turns out these operations are not delegated to the XNNPACK backend at all - or at least I was not able to get them delegated. As I have already observed with other models, non-delegated operators can cause a significant performance regression.

Steps to reproduce

To simplify the example, I created a dummy model that takes a single integer tensor and performs an arbitrary operation on it (for example, a simple multiplication). Then I exported it and printed the delegation info. The full script is available below:

import argparse
import torch
import torch.nn as nn
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools.backend_debug import print_delegation_info
from executorch.exir import to_edge_transform_and_lower
from torch.export import Dim


class DummyModel(nn.Module):
  def forward(self, x: torch.Tensor):
    return x * 2  # Not delegated
    # return x * torch.full_like(x, 2, dtype=torch.long)  # Not delegated
    # return x.float() * 2  # Delegated


ENABLE_DYNAMIC_SHAPES = False
MAX_INPUT_SIZE = 1024

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Export dummy model")
  parser.add_argument(
      "--output",
      type=str,
      required=True,
      help="Output model filepath"
  )

  args = parser.parse_args()
  output = args.output

  model = DummyModel()
  model.eval()

  dynamic_shapes = None
  if ENABLE_DYNAMIC_SHAPES:
    t = Dim("t", min=1, max=MAX_INPUT_SIZE)
    dynamic_shapes = {"x": {1: t}}

  inputs = (
    torch.randint(0, 128, (1, 128), dtype=torch.int8),
  )

  exported_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
  executorch_program = to_edge_transform_and_lower(
      exported_program,
      partitioner = [XnnpackPartitioner()]
  ).to_executorch()

  print_delegation_info(executorch_program.exported_program().graph_module)

  with open(output, "wb") as file:
    executorch_program.write_to_file(file)

  print("Exporting finished!")

Actual behavior

The printed delegation summary looks like the one below, indicating that the multiplication operation ends up in the non-delegated graph, i.e. it is not delegated to the backend at all:

#  op_type                                    occurrences_in_delegated_graphs  occurrences_in_non_delegated_graphs
0  alloc                                      0                                2
1  aten_mul_tensor                            0                                1
2  dim_order_ops__to_dim_order_copy_default   0                                1
3  Total                                      0                                4

What I tried

  • Using integer types other than torch.long - no effect.
  • Replacing the multiplication expression with a more explicit tensor multiplication (the second return statement in the code example) - no effect.
  • Casting the integer tensor to a float tensor - this fixed the issue, although it is not always possible and partly defeats the purpose of using integer tensors.
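For now the float cast can at least be kept out of the model's interface by wrapping the original module; `FloatCastWrapper` below is a hypothetical helper I sketched for this report, not part of ExecuTorch:

```python
import torch
import torch.nn as nn


class DummyModel(nn.Module):
  def forward(self, x: torch.Tensor):
    return x * 2


class FloatCastWrapper(nn.Module):
  """Hypothetical workaround: cast integer inputs to float before running the
  wrapped module (so the ops become delegatable), then cast the result back
  to the original integer dtype."""

  def __init__(self, inner: nn.Module):
    super().__init__()
    self.inner = inner

  def forward(self, x: torch.Tensor):
    orig_dtype = x.dtype
    y = self.inner(x.float())
    return y.to(orig_dtype)


wrapped = FloatCastWrapper(DummyModel())
out = wrapped(torch.tensor([1, 2, 3], dtype=torch.int8))
# out keeps the integer dtype of the input
```

Note that the cast back truncates, so this only preserves exact results while the intermediate float values stay within the exactly representable integer range, which is part of why the workaround is unsatisfying.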

Versions

Collecting environment information...
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 3.31.10
Libc version: N/A

Python version: 3.12.11 (main, Jun 3 2025, 15:41:47) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-26.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] executorch==1.0.1a0+7b220c6
[pip3] numpy==2.4.1
[pip3] optimum-executorch==0.1.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.9.0
[pip3] torchao==0.14.0+git02941240f
[pip3] torchaudio==2.9.0
[pip3] torchdata==0.11.0+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.7.0+cpu
[pip3] torchvision==0.24.0
[conda] Could not collect

cc @GregoryComer @digantdesai @cbilgin

Labels

    module: xnnpack: Issues related to xnnpack delegation and the code under backends/xnnpack
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module