diff --git a/Ironwood/src/benchmark_host_device.py b/Ironwood/src/benchmark_host_device.py index 4e4e0a6..67eb980 100644 --- a/Ironwood/src/benchmark_host_device.py +++ b/Ironwood/src/benchmark_host_device.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Tuple, List import jax -from jax import sharding +from jax import numpy as jnp import numpy as np from benchmark_utils import MetricsStatistics @@ -25,7 +25,7 @@ def benchmark_host_device( num_runs: int = 100, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks H2D/D2H transfer using simple device_put/device_get.""" + """Benchmarks H2D/D2H transfer using device_put/device_get.""" num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize @@ -33,8 +33,13 @@ def benchmark_host_device( column = 128 host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32) + # Used in pipelined flow + # TODO: turn into a param + num_devices_to_perform_h2d = 1 + target_devices = jax.devices()[:num_devices_to_perform_h2d] + print( - f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations", + f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}", flush=True )