diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py
index 69b4d212..08890385 100644
--- a/Ironwood/src/benchmark_collectives.py
+++ b/Ironwood/src/benchmark_collectives.py
@@ -11,6 +11,7 @@
 from benchmark_utils import MetricsStatistics
 from benchmark_utils import multiple_iteration_timeit_from_trace
 from benchmark_utils import ShardingStrategy
+from benchmark_utils import get_real_dtype_bytes
 from common import MARKER
 import jax
 from jax import core
@@ -72,7 +73,7 @@ def get_metrics_helper(
       for key, value in params
       if value is not None and key not in exclude_keys
   }
-  metadata["dtype"] = metadata["dtype"].dtype.itemsize
+  metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype)
   return metadata
 
 
@@ -98,7 +99,8 @@ def unified_ici_collectives_metrics(
   hlo_first_replica_group = []
 
   input_num_elements = matrix_shape[0] * matrix_shape[1] * matrix_shape[2]
-  dtype_bytes = dtype.dtype.itemsize
+  dtype_name = dtype.dtype.name
+  dtype_bytes = get_real_dtype_bytes(dtype.dtype)
   if xla_output:
     xla_output_json = json.loads(xla_output)
     hlo_input_shape = xla_output_json.get("hlo_input_shape")
diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py
index bb279f42..67f0429b 100644
--- a/Ironwood/src/benchmark_hbm.py
+++ b/Ironwood/src/benchmark_hbm.py
@@ -6,6 +6,7 @@
 from benchmark_utils import (
     MetricsStatistics,
     multiple_iteration_timeit_from_trace,
+    get_real_dtype_bytes,
 )
 from common import MARKER
 import jax
@@ -76,7 +77,7 @@ def single_device_hbm_copy_calculate_metrics(
   metrics = {}
 
   # Calculate throughput.
-  tensor_size_bytes = num_elements * dtype.dtype.itemsize
+  tensor_size_bytes = num_elements * get_real_dtype_bytes(dtype.dtype)
   tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9
 
   time_statistics = MetricsStatistics(
diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py
index 90950007..c7dd5db3 100644
--- a/Ironwood/src/benchmark_send_recv.py
+++ b/Ironwood/src/benchmark_send_recv.py
@@ -8,6 +8,7 @@
 import jax.sharding
 from benchmark_utils import (
     get_trace,
+    get_real_dtype_bytes,
 )
 from common import MARKER
 import tempfile
@@ -68,7 +69,7 @@ def get_metrics_helper(
       for key, value in params
       if value is not None and key not in exclude_keys
   }
-  metadata['dtype'] = metadata['dtype'].dtype.itemsize
+  metadata['dtype'] = get_real_dtype_bytes(metadata['dtype'].dtype)
   return metadata
 
 
@@ -84,7 +85,7 @@ def send_recv_benchmark(
   device_count = jax.local_device_count()
   devices = mesh_utils.create_device_mesh((device_count,))
   mesh = jax.sharding.Mesh(devices, 'x')
-  item_size = jnp.dtype(dtype).itemsize
+  item_size = get_real_dtype_bytes(jnp.dtype(dtype))
   tensor_size_bytes = num_elements * item_size
-  last_dim = tensor_size_bytes // (1 * 8 * item_size)
+  last_dim = int(tensor_size_bytes // (1 * 8 * item_size))
 
@@ -161,7 +162,7 @@ def send_recv_benchmark_calculate_metrics(
   metadata = get_metrics_helper(params)
   metrics = {}
 
-  tensor_size_bytes = num_elements * jnp.dtype(dtype).itemsize
+  tensor_size_bytes = num_elements * get_real_dtype_bytes(jnp.dtype(dtype))
   tensor_size_gbytes = tensor_size_bytes / 10**9
 
   metrics['runtime_ms (ms)'] = runtime_ms
diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py
index e28f39e4..ccd4f4c1 100644
--- a/Ironwood/src/benchmark_utils.py
+++ b/Ironwood/src/benchmark_utils.py
@@ -28,6 +28,30 @@
 import jax.extend
 from tensorflow.tsl.profiler.protobuf import xplane_pb2
 
+
+def get_real_dtype_bytes(dtype) -> float:
+  """Returns the real byte size of a dtype, handling sub-byte types.
+
+  `dtype.itemsize` is a whole number of bytes, so the bit width reported by
+  `jnp.finfo` / `jnp.iinfo` is used instead where one is available.
+
+  Args:
+    dtype: Dtype object to measure; must expose `itemsize` for the fallback.
+
+  Returns:
+    `bits / 8` (a float, fractional for sub-byte types) when `jnp.finfo` or
+    `jnp.iinfo` accepts `dtype`; otherwise `dtype.itemsize` unchanged.
+  """
+  # Catch only the rejection finfo/iinfo give for a dtype outside their
+  # family, so unrelated errors are not silently turned into a fallback.
+  for info_fn in (jnp.finfo, jnp.iinfo):
+    try:
+      return info_fn(dtype).bits / 8
+    except (ValueError, TypeError):
+      continue
+  return dtype.itemsize
+
+
 # The dictionary to map a JAX (collective) function to its main HLO.
 TARGET_TASK_NAME_COLLECTIVES_MAP = {
     "all_to_all_ici_op": r"all-to-all.[0-9]+",