Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions Ironwood/src/benchmark_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, Tuple, List

import jax
from jax import sharding
from jax import numpy as jnp
import numpy as np
from benchmark_utils import MetricsStatistics

Expand All @@ -25,16 +25,21 @@ def benchmark_host_device(
num_runs: int = 100,
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
"""Benchmarks H2D/D2H transfer using device_put/device_get."""

num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize

# Allocate Host Source Buffer
column = 128
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)

# Used in pipelined flow
# TODO: turn into a param
num_devices_to_perform_h2d = 1
target_devices = jax.devices()[:num_devices_to_perform_h2d]

print(
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations",
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}",
flush=True
)

Expand Down