From 3ab41e1fdd6e2f51db32e7fedb496fcbf8f722ca Mon Sep 17 00:00:00 2001 From: Hong-Yi Lin Date: Fri, 23 Jan 2026 12:58:59 +0000 Subject: [PATCH 1/3] Add multiple devices HBM. . --- .../configs/hbm/hbm_multiple_devices.yaml | 8 ++ Ironwood/src/benchmark_hbm.py | 77 +++++++++++++++++++ Ironwood/src/run_benchmark.py | 1 + 3 files changed, 86 insertions(+) create mode 100644 Ironwood/configs/hbm/hbm_multiple_devices.yaml diff --git a/Ironwood/configs/hbm/hbm_multiple_devices.yaml b/Ironwood/configs/hbm/hbm_multiple_devices.yaml new file mode 100644 index 00000000..1d5ceae5 --- /dev/null +++ b/Ironwood/configs/hbm/hbm_multiple_devices.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "multiple_device_hbm_copy" + benchmark_sweep_params: + - {num_elements_range: {start: 1048576, end: 4294967296, multiplier: 2}, dtype: "bfloat16", num_runs: 1} + trace_dir: "../microbenchmarks/hbm" + csv_path: "../microbenchmarks/hbm" + xlml_metrics_dir: "../microbenchmarks/hbm" + xla_dump_dir: "../microbenchmarks/hbm/hlo_graphs" \ No newline at end of file diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index bb279f42..e49b536e 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -6,11 +6,15 @@ from benchmark_utils import ( MetricsStatistics, multiple_iteration_timeit_from_trace, + ShardingStrategy, + create_mesh, ) from common import MARKER import jax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P SEED = 0 os.environ["LIBTPU_INIT_ARGS"] = ( @@ -102,3 +106,76 @@ def single_device_hbm_copy_calculate_metrics( metrics.update(statistics.serialize_statistics()) metrics = {key: value for key, value in metrics.items() if value is not None} return metadata, metrics + +SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING + +def multiple_device_hbm_copy( + num_elements: int, + dtype: jnp.dtype, + num_runs: int = 1, + trace_dir: str = None, +) -> Dict[str, Any]: + """Benchmarks HBM with copy(read and write) on a single device.""" + + def f(a): + with jax.named_scope(MARKER): + return a.copy() + + mesh = create_mesh(SHARDING_STRATEGY) + sharding = NamedSharding(mesh, P(None,)) + + a = jax.random.normal(jax.random.key(0), (num_elements,), out_sharding=sharding).astype(dtype) + print(a.shape) + print(a.dtype) + jitted_f = jax.jit(f) + # Run once + output = jitted_f(a) + jax.block_until_ready(output) + + # Run the benchmark + time_ms_list = multiple_iteration_timeit_from_trace( + compute_func=jitted_f, + data_generator=lambda: (a,), + matrix_dim=f"{num_elements}", + tries=num_runs, + task="copy", + trace_dir=trace_dir, + ) + return {"time_ms_list": time_ms_list} + +def multiple_device_hbm_copy_calculate_metrics( + num_elements: int, dtype: jnp.dtype, time_ms_list: list +) -> Dict[str, Any]: + """Calculates the metrics for the single device hbm copy benchmark.""" + # Build dictionary of all the parameters in the function + params = locals().items() + metadata = get_metrics_helper(params) + metrics = {} + + # Calculate throughput. 
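+  # A copy reads the tensor from HBM once and writes it back once, so the
+  # bytes moved below are 2x the tensor size; 10**9 reports decimal GB.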
+ tensor_size_bytes = num_elements * dtype.dtype.itemsize + + tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9 + time_statistics = MetricsStatistics( + metrics_list=time_ms_list, metrics_name="time_ms" + ) + time_s_list = [time_ms / 10**3 for time_ms in time_ms_list] + bw_gbyte_sec_list = [tensor_size_gbytes / time_s for time_s in time_s_list] + statistics = MetricsStatistics( + metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec" + ) + print( + f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):" + f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s" + ) + print() + # Gather the metrics to report. + metadata.update( + { + "tensor_size_gbytes": tensor_size_gbytes, + } + ) + metrics.update(time_statistics.serialize_statistics()) + metrics.update(statistics.serialize_statistics()) + metrics = {key: value for key, value in metrics.items() if value is not None} + return metadata, metrics \ No newline at end of file diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index 2b703487..fd207317 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -54,6 +54,7 @@ } HBM_BENCHMARK_MAP = { "single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy", + "multiple_device_hbm_copy": "benchmark_hbm.multiple_device_hbm_copy", } COMPUTE_BENCHMARK_MAP = { "gemm_simple": "benchmark_gemm.gemm_simple", From cf2327c05cc6ea3e70c051de1ae18c65bf381ee3 Mon Sep 17 00:00:00 2001 From: Hong-Yi Lin Date: Fri, 23 Jan 2026 13:17:20 +0000 Subject: [PATCH 2/3] Add multiple devices GEMM --- .../training/gemm_multiple_devices.yaml | 15 +++ Ironwood/src/benchmark_gemm.py | 97 +++++++++++++++++++ Ironwood/src/run_benchmark.py | 1 + 3 files changed, 113 insertions(+) create mode 100644 Ironwood/configs/training/gemm_multiple_devices.yaml diff --git a/Ironwood/configs/training/gemm_multiple_devices.yaml b/Ironwood/configs/training/gemm_multiple_devices.yaml new file mode 100644 index 00000000..7016208d --- /dev/null +++ b/Ironwood/configs/training/gemm_multiple_devices.yaml @@ -0,0 +1,15 @@ +benchmarks: +- benchmark_name: "gemm_multiple_devices" + trace_dir: "../microbenchmarks/gemm_multiple_run_bf16" + csv_path: "../microbenchmarks/gemm_multiple_run_bf16" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_bf16" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_bf16/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'} +- benchmark_name: "gemm_multiple_devices" + trace_dir: "../microbenchmarks/gemm_multiple_run_fp8" + csv_path: "../microbenchmarks/gemm_multiple_run_fp8" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp8" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp8/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'} \ No newline at end of file diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c8c27bbe..4cc87050 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -540,3 +540,100 @@ def gemm_accum_calculate_metrics( total_flops_all_devices, PEAK_FLOPS_PER_DEVICE, ) + +def gemm_multiple_devices( + m: int, + k: int, + n: int, + dtype: jnp.dtype = jax.numpy.float8_e4m3fn, + num_runs: int = 1, + trace_dir: str = None, +) -> Dict[str, Any]: + """Benchmarks the OUT:BF16 = IN0 dtype x IN1:dtype. Accumulation is FP32. 
Current supported dtype: float8_e4m3fn, bfloat16.""" + + def f(x, y): + with jax.named_scope(MARKER): + acc = jax.numpy.einsum( + "ij,jk->ik", x, y, preferred_element_type=jnp.float32 + ) + return acc.astype(jnp.bfloat16) + SHARDING_STRATEGY = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M + + mesh = create_mesh(SHARDING_STRATEGY) + lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) + rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) + out_sharding = get_out_sharding(SHARDING_STRATEGY) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=(lhs_sharding.spec, rhs_sharding.spec), + out_specs=out_sharding, + check_rep=False, + ) + ) + + lhs_shape = (m, k) + rhs_shape = (k, n) + + lhs_dtype = dtype + rhs_dtype = dtype + + key = jax.random.key(SEED) + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + key, key_lhs, key_rhs = jax.random.split(key, 3) + + # Create random data on host + lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype) + rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype) + + # Put on device (HBM) + lhs_device = jax.device_put(lhs_host, lhs_sharding) + rhs_device = jax.device_put(rhs_host, rhs_sharding) + + return (lhs_device, rhs_device) + + # Run the benchmark + + print("Running gemm_multiple_run benchmark", num_runs) + dtype_str = "fp8" if dtype==jax.numpy.float8_e4m3fn else "bf16" + time_ms_list = multiple_iteration_timeit_from_trace( + jit_sharded_f, + data_generator, + matrix_dim=f"{dtype_str}_{m}x{n}x{k}", + tries=num_runs, + task="gemm_multiple_run", + trace_dir=trace_dir, + ) + return { + "time_ms_list": time_ms_list, + } + + +def gemm_multiple_devices_calculate_metrics( + m: int, + k: int, + n: int, + dtype: jnp.dtype, + time_ms_list: list[float], +) -> Dict[str, Any]: + # Calculate FLOPs + SHARDING_STRATEGY = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M + total_flops = 2 * m * k * n # Total floating-point operations + total_flops, total_flops_all_devices = handle_based_on_sharding( + total_flops, SHARDING_STRATEGY + ) + peak_flops = PEAK_FLOPS_PER_DEVICE if dtype==jax.numpy.float8_e4m3fn else PEAK_FLOPS_PER_DEVICE/2 + return unified_flops_metrics( + m, + n, + k, + time_ms_list, + total_flops, + total_flops_all_devices, + peak_flops, + ) diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index fd207317..e4c97571 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -63,6 +63,7 @@ "gemm_throttling": "benchmark_gemm_throttling.gemm_throttling", "gemm": "benchmark_gemm.gemm", "gemm_accum": "benchmark_gemm.gemm_accum", + "gemm_multiple_devices": "benchmark_gemm.gemm_multiple_devices", "quantization": "benchmark_compute.quantization", "transpose_quantization": "benchmark_compute.transpose_quantization", "quantization_static_scaling": ( From b11b01733cb65512295b4e174da2d6a657c2b0fe Mon Sep 17 00:00:00 2001 From: Hong-Yi Lin Date: Fri, 23 Jan 2026 14:38:39 +0000 Subject: [PATCH 3/3] Update code structures and configurations. 
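Rename the multi-device HBM entry points to multiple_devices_hbm_copy, move the
sharding-strategy selection inside the function, rename the GEMM sharding
constant to SHARDING_STRATEGY_MULTI_DEVICES, and point the GEMM output
directories at gemm_multiple_devices_*. For reference, the sharded matmul these
benchmarks wrap follows the shard_map pattern sketched below (illustrative
only: the mesh axis name, shapes, and partition specs stand in for the
create_mesh / get_*_named_shading helpers in benchmark_utils, and m is assumed
to divide evenly across the devices):

    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), axis_names=("x",))
    m, k, n = 128 * len(jax.devices()), 512, 256

    def f(x, y):
      # Each device multiplies its M-slice of the LHS by the replicated RHS.
      acc = jnp.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32)
      return acc.astype(jnp.bfloat16)

    sharded_f = jax.jit(
        shard_map(f, mesh,
                  in_specs=(P("x", None), P(None, None)),
                  out_specs=P("x", None),
                  check_rep=False))

    lhs = jax.device_put(jnp.ones((m, k), jnp.bfloat16),
                         NamedSharding(mesh, P("x", None)))
    rhs = jax.device_put(jnp.ones((k, n), jnp.bfloat16),
                         NamedSharding(mesh, P(None, None)))
    out = jax.block_until_ready(sharded_f(lhs, rhs))
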
--- .../configs/training/gemm_multiple_devices.yaml | 16 ++++++++-------- Ironwood/src/benchmark_gemm.py | 16 ++++++++-------- Ironwood/src/benchmark_hbm.py | 11 +++++------ Ironwood/src/run_benchmark.py | 2 +- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/Ironwood/configs/training/gemm_multiple_devices.yaml b/Ironwood/configs/training/gemm_multiple_devices.yaml index 7016208d..07590ad1 100644 --- a/Ironwood/configs/training/gemm_multiple_devices.yaml +++ b/Ironwood/configs/training/gemm_multiple_devices.yaml @@ -1,15 +1,15 @@ benchmarks: - benchmark_name: "gemm_multiple_devices" - trace_dir: "../microbenchmarks/gemm_multiple_run_bf16" - csv_path: "../microbenchmarks/gemm_multiple_run_bf16" - xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_bf16" - xla_dump_dir: "../microbenchmarks/gemm_multiple_run_bf16/hlo_graphs" + trace_dir: "../microbenchmarks/gemm_multiple_devices_bf16" + csv_path: "../microbenchmarks/gemm_multiple_devices_bf16" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_bf16" + xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_bf16/hlo_graphs" benchmark_sweep_params: - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'} - benchmark_name: "gemm_multiple_devices" - trace_dir: "../microbenchmarks/gemm_multiple_run_fp8" - csv_path: "../microbenchmarks/gemm_multiple_run_fp8" - xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp8" - xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp8/hlo_graphs" + trace_dir: "../microbenchmarks/gemm_multiple_devices_fp8" + csv_path: "../microbenchmarks/gemm_multiple_devices_fp8" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_fp8" + xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_fp8/hlo_graphs" benchmark_sweep_params: - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'} \ No newline at end of file diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index 4cc87050..0b92b83e 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -557,12 +557,12 @@ def f(x, y): "ij,jk->ik", x, y, preferred_element_type=jnp.float32 ) return acc.astype(jnp.bfloat16) - SHARDING_STRATEGY = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M + SHARDING_STRATEGY_MULTI_DEVICES = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M - mesh = create_mesh(SHARDING_STRATEGY) - lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) - rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) - out_sharding = get_out_sharding(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY_MULTI_DEVICES) + lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY_MULTI_DEVICES) + rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY_MULTI_DEVICES) + out_sharding = get_out_sharding(SHARDING_STRATEGY_MULTI_DEVICES) jit_sharded_f = jax.jit( shard_map( @@ -599,7 +599,7 @@ def data_generator(): # Run the benchmark - print("Running gemm_multiple_run benchmark", num_runs) + print("Running gemm_multiple_devices benchmark", num_runs) dtype_str = "fp8" if dtype==jax.numpy.float8_e4m3fn else "bf16" time_ms_list = multiple_iteration_timeit_from_trace( jit_sharded_f, @@ -622,10 +622,10 @@ def gemm_multiple_devices_calculate_metrics( time_ms_list: list[float], ) -> Dict[str, Any]: # Calculate FLOPs - SHARDING_STRATEGY = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M + SHARDING_STRATEGY_MULTI_DEVICES = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M total_flops = 2 * m * k * n # Total floating-point operations total_flops, 
total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY_MULTI_DEVICES ) peak_flops = PEAK_FLOPS_PER_DEVICE if dtype==jax.numpy.float8_e4m3fn else PEAK_FLOPS_PER_DEVICE/2 return unified_flops_metrics( diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index e49b536e..7b6801b0 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -107,16 +107,15 @@ def single_device_hbm_copy_calculate_metrics( metrics = {key: value for key, value in metrics.items() if value is not None} return metadata, metrics -SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING - -def multiple_device_hbm_copy( +def multiple_devices_hbm_copy( num_elements: int, dtype: jnp.dtype, num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks HBM with copy(read and write) on a single device.""" + """Benchmarks HBM with copy(read and write) on multiple devices.""" + SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING def f(a): with jax.named_scope(MARKER): return a.copy() @@ -143,10 +142,10 @@ def f(a): ) return {"time_ms_list": time_ms_list} -def multiple_device_hbm_copy_calculate_metrics( +def multiple_devices_hbm_copy_calculate_metrics( num_elements: int, dtype: jnp.dtype, time_ms_list: list ) -> Dict[str, Any]: - """Calculates the metrics for the single device hbm copy benchmark.""" + """Calculates the metrics for the multiple devices hbm copy benchmark.""" # Build dictionary of all the parameters in the function params = locals().items() metadata = get_metrics_helper(params) diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index e4c97571..3b867fee 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -54,7 +54,7 @@ } HBM_BENCHMARK_MAP = { "single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy", - "multiple_device_hbm_copy": "benchmark_hbm.multiple_device_hbm_copy", + "multiple_device_hbm_copy": "benchmark_hbm.multiple_devices_hbm_copy", } COMPUTE_BENCHMARK_MAP = { "gemm_simple": "benchmark_gemm.gemm_simple",
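
For reference, the core of the multi-device HBM copy measurement introduced in
this series can be reproduced standalone roughly as follows (a minimal sketch
only: it assumes a multi-device JAX runtime, uses wall-clock timing instead of
the trace-based multiple_iteration_timeit_from_trace helper, and hard-codes the
replicated NO_SHARDING layout and bfloat16 dtype):

    import time

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), axis_names=("x",))
    # P(None,) leaves the 1-D array unsharded, i.e. replicated on every device.
    sharding = NamedSharding(mesh, P(None,))

    num_elements = 1 << 20
    a = jax.device_put(
        jax.random.normal(jax.random.key(0), (num_elements,)).astype(jnp.bfloat16),
        sharding)

    @jax.jit
    def copy(x):
      return x.copy()

    jax.block_until_ready(copy(a))  # compile and warm up

    start = time.perf_counter()
    jax.block_until_ready(copy(a))
    elapsed_s = time.perf_counter() - start

    # A copy reads and writes the tensor once each, hence the factor of 2.
    gbytes = 2 * num_elements * a.dtype.itemsize / 1e9
    print(f"~{gbytes / elapsed_s:.1f} GB/s per device")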