9 changes: 9 additions & 0 deletions Ironwood/configs/collectives/reduce_scatter_3d.yaml
@@ -0,0 +1,9 @@
benchmarks:
- benchmark_name: psum_scatter
benchmark_sweep_params:
- {matrix_dim_range: {start: 2, end: 8192, multiplier: 2}, dtype: "float32", mesh_shape: "64x2", ici_size_range: 128, sharding_strategy: "64x1", op_dimension: 3, num_runs: 5} # Parallel Replica Groups
- {matrix_dim_range: {start: 2, end: 8192, multiplier: 2}, dtype: "float32", mesh_shape: "4x4x8", ici_size_range: 128, sharding_strategy: "4x4x8", op_dimension: 3, num_runs: 5} # Non Parallel Replica Groups
trace_dir: "../microbenchmarks/reduce_scatter_3d"
csv_path: "../microbenchmarks/reduce_scatter_3d"
xlml_metrics_dir: "../microbenchmarks/reduce_scatter_3d"
xla_dump_dir: "../microbenchmarks/reduce_scatter_3d/hlo_graphs"
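A minimal sketch of how a matrix_dim_range entry such as {start: 2, end: 8192, multiplier: 2} could expand into the swept matrix dims, assuming the range is geometric (start, start*multiplier, ..., capped at end); the helper expand_matrix_dim_range is hypothetical and not part of the benchmark code.

def expand_matrix_dim_range(start: int, end: int, multiplier: int) -> list[int]:
  """Expands a matrix_dim_range config entry into concrete matrix dims (assumed geometric sweep)."""
  dims = []
  dim = start
  while dim <= end:
    dims.append(dim)
    dim *= multiplier
  return dims

# {start: 2, end: 8192, multiplier: 2} -> [2, 4, 8, ..., 8192]
print(expand_matrix_dim_range(start=2, end=8192, multiplier=2))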
107 changes: 84 additions & 23 deletions Ironwood/src/benchmark_collectives.py
@@ -23,6 +23,7 @@
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P


BASE_SHAPE = [1, 8, 128]
SEED = 0
GLOBAL_SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
@@ -76,6 +77,67 @@ def get_metrics_helper(
return metadata


def _run_under_xprof(
function: jax.stages.Compiled,
inputs: list[jax.Array],
n_repeats: int,
task: str,
trace_dir: str | None = None,
matrix_dim: str | None = None,
):
"""Runs a function under xprof."""
# warmup

trace_name = f"{task}_dim_{matrix_dim}"
trace_full_dir = f"{trace_dir}/{trace_name}"

jax.block_until_ready(function(*inputs))
with jax.profiler.trace(trace_full_dir, create_perfetto_link=False):
for i in range(n_repeats):
with jax.profiler.StepTraceAnnotation(task, step_num=i):
with jax.named_scope(f"{MARKER}_{i}"):
result = function(*inputs)
jax.block_until_ready(result)

# Read the metrics from the xplane protos because trace.json truncates the last
# operations.
_, space = find_sparsecore_usage_from_xplane(trace_full_dir)

rel_events = []
for plane in space.planes:
event_metadata_map = plane.event_metadata
for line in plane.lines:
for event in line.events:
metadata = event_metadata_map.get(event.metadata_id)
if not metadata:
continue
event_name = metadata.display_name
if "reduce-scatter" in event_name or "all-reduce" in event_name:
obj = {}
obj["name"] = event_name
obj["device_duration_ps"] = event.duration_ps
obj["pid"] = plane.id
obj["process_name"] = plane.name
rel_events.append(obj)

min_pid = min([e["pid"] for e in rel_events])
events_from_min_pid = [e for e in rel_events if e["pid"] == min_pid]
print("events_from_min_pid: ", events_from_min_pid)

marker_call_done_events = [
e for e in events_from_min_pid if e.get("name", "").endswith("call-done")
]
if marker_call_done_events:
events_from_min_pid = marker_call_done_events
print("events_from_min_pid: ", events_from_min_pid)
durations_ms = [
float(e["device_duration_ps"]) / 1e9 for e in events_from_min_pid
]
print("durations_ms: ", durations_ms)
print(trace_full_dir)
return [sum(durations_ms)]
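
For context, a minimal usage sketch of _run_under_xprof with a jitted psum-scatter, mirroring the call in psum_scatter_benchmark below; the mesh construction, array shape, trace_dir, and matrix_dim label are illustrative assumptions, and numpy as np is assumed to be imported alongside the module's existing jax/jnp/Mesh/P/shard_map imports.

# Illustrative only: trace a tiny reduce-scatter once and read back its device time.
mesh = Mesh(np.array(jax.devices()), ("x",))  # assumed 1-D mesh over all local devices
scatter_fn = jax.jit(
    shard_map(
        lambda x: jax.lax.psum_scatter(x, "x", tiled=True),
        mesh=mesh,
        in_specs=P(None, None, None),
        out_specs=P(None, None, None),
        check_rep=False,
    )
)
# Leading dim matches the device count so the tiled scatter divides evenly.
x = jnp.ones((len(jax.devices()), 1024, 256), dtype=jnp.float32)
total_ms = _run_under_xprof(
    scatter_fn, [x], n_repeats=1, task="psum_scatter_ici_op",
    trace_dir="/tmp/xprof_demo", matrix_dim="demo",
)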


def unified_ici_collectives_metrics(
xla_output: str,
matrix_shape: tuple[int, int, int],
@@ -149,15 +211,14 @@ def unified_ici_collectives_metrics(
/ rank
)


sparsecore_used = "NA"
if LOG_SPARSECORE_USAGE:
print("trace_dir: ", trace_dir)
if trace_dir:
sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir)
sparsecore_used, _ = find_sparsecore_usage_from_xplane(trace_dir)
print("sparsecore_used: ", sparsecore_used)
print("hlo first replica group: ", hlo_first_replica_group)

metadata = {
"iteration": iteration,
"op_type": op_type,
@@ -381,38 +442,38 @@ def psum_scatter_benchmark(

sharding_axis = get_sharding_axis(sharding_strategy, mesh)

sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x")))
op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple)
multiplier = 6
m = op_dimension_tuple_multiplier * multiplier
n = matrix_dim
k = 256

def f(x):
with jax.named_scope(MARKER):
return jax.lax.psum_scatter(x, sharding_axis, tiled=True)

jit_sharded_f = jax.jit(
compiled_function = jax.jit(
shard_map(
f,
mesh=mesh,
in_specs=P(None, None, None),
out_specs=P(sharding_axis, None, None),
out_specs=P(None, None, None),
check_rep=False,
)
)
sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x")))
op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple)
m = op_dimension_tuple_multiplier
n = matrix_dim
k = 256

def data_generator():
"""Creates new random data on host and puts it on device."""
matrix = jnp.ones((m, n, k), dtype=dtype)
return (matrix,)

time_ms_list = multiple_iteration_timeit_from_trace(
jit_sharded_f,
data_generator,
matrix_dim=f"{m}x{n}x{k}",
tries=num_runs,
task="psum_scatter_ici_op",
trace_dir=trace_dir,
matrix = jnp.ones((m, n, k), dtype=dtype)

# Total device time in milliseconds, summed over the traced collective events.
time_ms_list = _run_under_xprof(
compiled_function,
[matrix],
1,
"psum_scatter_ici_op",
trace_dir,
matrix_dim,
)

print("Running psum_scatter benchmark", num_runs, matrix_dim)
print("Matrix shape: ", m, n, k)
return {
15 changes: 10 additions & 5 deletions Ironwood/src/benchmark_utils.py
@@ -570,7 +570,7 @@ def find_sparsecore_usage_from_xplane(log_dir: str) -> xplane_pb2.XSpace:
if "SparseCore" in plane.name:
sparsecore_found = True
break
return sparsecore_found
return sparsecore_found, space


def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
@@ -846,7 +846,7 @@ def rename_xla_dump(
return

new_base_name = f"{benchmark_name}_{serialized_benchmark_param}"
after_optimizations_path = input_shape = output_shape = replica_groups = first_replica_group = None
after_optimizations_path = before_optimizations_path = input_shape = output_shape = replica_groups = first_replica_group = None

for original_filepath in all_related_files:
original_filename = os.path.basename(original_filepath)
@@ -872,6 +872,10 @@
if "after_optimizations.txt" in original_suffix_with_extension:
after_optimizations_path = new_filepath

if "before_optimizations.txt" in original_suffix_with_extension:
before_optimizations_path = new_filepath


if original_filepath == new_filepath:
print(
f"Skipping: '{original_filename}' already has the desired name or path."
Expand All @@ -890,17 +894,18 @@ def rename_xla_dump(
else:
upload_to_storage(trace_dir=new_filepath, local_file=original_filepath)
print(f"The XLA dump is stored in {dest_xla_dump_dir}")
if after_optimizations_path:
if before_optimizations_path:
input_shape, output_shape, replica_groups, first_replica_group = (
extract_hlo_features_from_file(after_optimizations_path)
extract_hlo_features_from_file(before_optimizations_path)
)
else:
print(
"No files found with 'after_optimizations.txt' suffix. "
"No files found with 'before_optimizations.txt' suffix. "
"Please check the XLA dump directory."
)
return json.dumps({
"after_optimizations_path": after_optimizations_path,
"before_optimizations_path": before_optimizations_path,
"hlo_input_shape": input_shape,
"hlo_output_shape": output_shape,
"hlo_replica_groups": replica_groups,