From 1e69d921783639f1fc20b5e3bbb232953b5db261 Mon Sep 17 00:00:00 2001 From: "Yu-Hsuan (Amy) Lin" Date: Wed, 4 Feb 2026 10:40:27 +0000 Subject: [PATCH] Correct fp4 tensor size calculation The new utility will use jnp.finfo and jnp.iinfo to determine the accurate bit width of any dtype, ensuring correct bandwidth metrics for current and future sub-byte types (like int4 or float4). --- Ironwood/src/benchmark_collectives.py | 6 ++++-- Ironwood/src/benchmark_hbm.py | 3 ++- Ironwood/src/benchmark_send_recv.py | 7 ++++--- Ironwood/src/benchmark_utils.py | 12 ++++++++++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 69b4d212..08890385 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -11,6 +11,7 @@ from benchmark_utils import MetricsStatistics from benchmark_utils import multiple_iteration_timeit_from_trace from benchmark_utils import ShardingStrategy +from benchmark_utils import get_real_dtype_bytes from common import MARKER import jax from jax import core @@ -72,7 +73,7 @@ def get_metrics_helper( for key, value in params if value is not None and key not in exclude_keys } - metadata["dtype"] = metadata["dtype"].dtype.itemsize + metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype) return metadata @@ -98,7 +99,8 @@ def unified_ici_collectives_metrics( hlo_first_replica_group = [] input_num_elements = matrix_shape[0] * matrix_shape[1] * matrix_shape[2] - dtype_bytes = dtype.dtype.itemsize + dtype_name = dtype.dtype.name + dtype_bytes = get_real_dtype_bytes(dtype.dtype) if xla_output: xla_output_json = json.loads(xla_output) hlo_input_shape = xla_output_json.get("hlo_input_shape") diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index bb279f42..67f0429b 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -6,6 +6,7 @@ from benchmark_utils import ( 
MetricsStatistics, multiple_iteration_timeit_from_trace, + get_real_dtype_bytes, ) from common import MARKER import jax @@ -76,7 +77,7 @@ def single_device_hbm_copy_calculate_metrics( metrics = {} # Calculate throughput. - tensor_size_bytes = num_elements * dtype.dtype.itemsize + tensor_size_bytes = num_elements * get_real_dtype_bytes(dtype.dtype) tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9 time_statistics = MetricsStatistics( diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index 90950007..c7dd5db3 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -8,6 +8,7 @@ import jax.sharding from benchmark_utils import ( get_trace, + get_real_dtype_bytes, ) from common import MARKER import tempfile @@ -68,7 +69,7 @@ def get_metrics_helper( for key, value in params if value is not None and key not in exclude_keys } - metadata['dtype'] = metadata['dtype'].dtype.itemsize + metadata['dtype'] = get_real_dtype_bytes(metadata['dtype'].dtype) return metadata @@ -84,7 +85,7 @@ def send_recv_benchmark( device_count = jax.local_device_count() devices = mesh_utils.create_device_mesh((device_count,)) mesh = jax.sharding.Mesh(devices, 'x') - item_size = jnp.dtype(dtype).itemsize + item_size = get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_bytes = num_elements * item_size last_dim = tensor_size_bytes // (1 * 8 * item_size) @@ -161,7 +162,7 @@ def send_recv_benchmark_calculate_metrics( metadata = get_metrics_helper(params) metrics = {} - tensor_size_bytes = num_elements * jnp.dtype(dtype).itemsize + tensor_size_bytes = num_elements * get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_gbytes = tensor_size_bytes / 10**9 metrics['runtime_ms (ms)'] = runtime_ms diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index e28f39e4..ccd4f4c1 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -28,6 +28,18 @@ import jax.extend from 
tensorflow.tsl.profiler.protobuf import xplane_pb2
 
+
+def get_real_dtype_bytes(dtype) -> float:
+    """Returns the byte size of a dtype as bit-width / 8 (may be fractional for sub-byte types)."""
+    try:
+        return jnp.finfo(dtype).bits / 8  # true division: result is always a float
+    except Exception:  # finfo did not accept this dtype; try it as an integer type
+        try:
+            return jnp.iinfo(dtype).bits / 8
+        except Exception:  # neither finfo nor iinfo accepted it
+            return dtype.itemsize  # last resort: the dtype's own itemsize, returned unchanged
+
+
 # The dictionary to map a JAX (collective) function to its main HLO.
 TARGET_TASK_NAME_COLLECTIVES_MAP = {
     "all_to_all_ici_op": r"all-to-all.[0-9]+",