diff --git a/Ironwood/configs/collectives/reduce_scatter_3d.yaml b/Ironwood/configs/collectives/reduce_scatter_3d.yaml
new file mode 100644
index 00000000..137246f1
--- /dev/null
+++ b/Ironwood/configs/collectives/reduce_scatter_3d.yaml
@@ -0,0 +1,9 @@
+benchmarks:
+- benchmark_name: psum_scatter
+  benchmark_sweep_params:
+  - {matrix_dim_range: {start: 2, end: 8192, multiplier: 2}, dtype: "float32", mesh_shape: "64x2", ici_size_range: 128, sharding_strategy: "64x1", op_dimension: 3, num_runs: 5} # Parallel replica groups
+  - {matrix_dim_range: {start: 2, end: 8192, multiplier: 2}, dtype: "float32", mesh_shape: "4x4x8", ici_size_range: 128, sharding_strategy: "4x4x8", op_dimension: 3, num_runs: 5} # Non-parallel replica groups
+  trace_dir: "../microbenchmarks/reduce_scatter_3d"
+  csv_path: "../microbenchmarks/reduce_scatter_3d"
+  xlml_metrics_dir: "../microbenchmarks/reduce_scatter_3d"
+  xla_dump_dir: "../microbenchmarks/reduce_scatter_3d/hlo_graphs"
\ No newline at end of file
diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py
index 69b4d212..2f1313e3 100644
--- a/Ironwood/src/benchmark_collectives.py
+++ b/Ironwood/src/benchmark_collectives.py
@@ -23,6 +23,7 @@
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
+
 BASE_SHAPE = [1, 8, 128]
 SEED = 0
 GLOBAL_SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
@@ -76,6 +77,67 @@ def get_metrics_helper(
   return metadata
 
 
+def _run_under_xprof(
+    function: jax.stages.Wrapped,
+    inputs: list[jax.Array],
+    n_repeats: int,
+    task: str,
+    trace_dir: str | None = None,
+    matrix_dim: str | None = None,
+) -> list[float]:
+  """Runs a function under xprof and returns the total device time in ms."""
+
+  trace_name = f"{task}_dim_{matrix_dim}"
+  trace_full_dir = f"{trace_dir}/{trace_name}"
+
+  # Warm up once so compilation time is kept out of the trace.
+  jax.block_until_ready(function(*inputs))
+  with jax.profiler.trace(trace_full_dir, create_perfetto_link=False):
+    for i in range(n_repeats):
+      with jax.profiler.StepTraceAnnotation(task, step_num=i):
+        with jax.named_scope(f"{MARKER}_{i}"):
+          result = function(*inputs)
+          jax.block_until_ready(result)
+
+  # Rely on the xplane proto for the metrics, since trace.json truncates the
+  # last operations.
+  _, space = find_sparsecore_usage_from_xplane(trace_full_dir)
+
+  rel_events = []
+  for plane in space.planes:
+    event_metadata_map = plane.event_metadata
+    for line in plane.lines:
+      for event in line.events:
+        metadata = event_metadata_map.get(event.metadata_id)
+        if not metadata:
+          continue
+        event_name = metadata.display_name
+        if "reduce-scatter" in event_name or "all-reduce" in event_name:
+          obj = {}
+          obj["name"] = event_name
+          obj["device_duration_ps"] = event.duration_ps
+          obj["pid"] = plane.id
+          obj["process_name"] = plane.name
+          rel_events.append(obj)
+
+  min_pid = min(e["pid"] for e in rel_events)
+  events_from_min_pid = [e for e in rel_events if e["pid"] == min_pid]
+  print("events_from_min_pid: ", events_from_min_pid)
+
+  marker_call_done_events = [
+      e for e in events_from_min_pid if e.get("name", "").endswith("call-done")
+  ]
+  if marker_call_done_events:
+    events_from_min_pid = marker_call_done_events
+  print("events_from_min_pid: ", events_from_min_pid)
+  durations_ms = [
+      float(e["device_duration_ps"]) / 1e9 for e in events_from_min_pid
+  ]
+  print("durations_ms: ", durations_ms)
+  print(trace_full_dir)
+  return [sum(durations_ms)]
+
+
 def unified_ici_collectives_metrics(
     xla_output: str,
     matrix_shape: tuple[int, int, int],
@@ -149,15 +211,15 @@ def unified_ici_collectives_metrics(
       / rank
   )
 
   sparsecore_used = "NA"
   if LOG_SPARSECORE_USAGE:
print("trace_dir: ", trace_dir) if trace_dir: - sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir) + sparsecore_used, _ = find_sparsecore_usage_from_xplane(trace_dir) print("sparsecore_used: ", sparsecore_used) print("hlo first replica group: ", hlo_first_replica_group) - + metadata = { "iteration": iteration, "op_type": op_type, @@ -381,38 +442,38 @@ def psum_scatter_benchmark( sharding_axis = get_sharding_axis(sharding_strategy, mesh) + sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x"))) + op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple) + multiplier = 6 + m = op_dimension_tuple_multiplier * multiplier + n = matrix_dim + k = 256 + def f(x): with jax.named_scope(MARKER): return jax.lax.psum_scatter(x, sharding_axis, tiled=True) - jit_sharded_f = jax.jit( + compiled_function = jax.jit( shard_map( f, mesh=mesh, in_specs=P(None, None, None), - out_specs=P(sharding_axis, None, None), + out_specs=P(None, None, None), check_rep=False, ) ) - sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x"))) - op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple) - m = op_dimension_tuple_multiplier - n = matrix_dim - k = 256 - - def data_generator(): - """Creates new random data on host and puts it on device.""" - matrix = jnp.ones((m, n, k), dtype=dtype) - return (matrix,) - - time_ms_list = multiple_iteration_timeit_from_trace( - jit_sharded_f, - data_generator, - matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="psum_scatter_ici_op", - trace_dir=trace_dir, + matrix = jnp.ones((m, n, k), dtype=dtype) + + # Measures the longest wait time in milliseconds, across all the runs. + time_ms_list = _run_under_xprof( + compiled_function, + [matrix], + 1, + "psum_scatter_ici_op", + trace_dir, + matrix_dim, ) + print("Running psum_scatter benchmark", num_runs, matrix_dim) print("Matrix shape: ", m, n, k) return { diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 19a6642d..358233a5 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -570,7 +570,7 @@ def find_sparsecore_usage_from_xplane(log_dir: str) -> xplane_pb2.XSpace: if "SparseCore" in plane.name: sparsecore_found = True break - return sparsecore_found + return sparsecore_found, space def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]: @@ -846,7 +846,7 @@ def rename_xla_dump( return new_base_name = f"{benchmark_name}_{serialized_benchmark_param}" - after_optimizations_path = input_shape = output_shape = replica_groups = first_replica_group = None + after_optimizations_path = before_optimizations_path = input_shape = output_shape = replica_groups = first_replica_group = None for original_filepath in all_related_files: original_filename = os.path.basename(original_filepath) @@ -872,6 +872,10 @@ def rename_xla_dump( if "after_optimizations.txt" in original_suffix_with_extension: after_optimizations_path = new_filepath + if "before_optimizations.txt" in original_suffix_with_extension: + before_optimizations_path = new_filepath + + if original_filepath == new_filepath: print( f"Skipping: '{original_filename}' already has the desired name or path." 
@@ -890,17 +894,18 @@ def rename_xla_dump(
       else:
         upload_to_storage(trace_dir=new_filepath, local_file=original_filepath)
   print(f"The XLA dump is stored in {dest_xla_dump_dir}")
-  if after_optimizations_path:
+  if before_optimizations_path:
     input_shape, output_shape, replica_groups, first_replica_group = (
-        extract_hlo_features_from_file(after_optimizations_path)
+        extract_hlo_features_from_file(before_optimizations_path)
     )
   else:
     print(
-        "No files found with 'after_optimizations.txt' suffix. "
+        "No files found with 'before_optimizations.txt' suffix. "
         "Please check the XLA dump directory."
     )
   return json.dumps({
       "after_optimizations_path": after_optimizations_path,
+      "before_optimizations_path": before_optimizations_path,
      "hlo_input_shape": input_shape,
      "hlo_output_shape": output_shape,
      "hlo_replica_groups": replica_groups,
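
Reviewer note: the sketch below mirrors what the new `_run_under_xprof` path does for `psum_scatter_benchmark`: warm up once, run the jitted collective under `jax.profiler.trace`, then recover device durations from the profiler output. It uses only public JAX APIs; the 1-D mesh, shapes, and trace directory are illustrative assumptions, and the repo-internal XPlane parsing (`find_sparsecore_usage_from_xplane`) is elided.

```python
# Minimal sketch of the new timing path, assuming a 1-D mesh over all local
# devices. Shapes and the trace directory are illustrative only.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

n_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("x",))

def f(x):
  # tiled=True keeps the scattered shards concatenated along dimension 0, so
  # each device ends up with an (m / n_devices, n, k) block.
  return jax.lax.psum_scatter(x, "x", tiled=True)

g = jax.jit(
    shard_map(
        f,
        mesh=mesh,
        in_specs=P(None, None, None),   # input replicated on every device
        out_specs=P(None, None, None),
        check_rep=False,  # outputs differ per device; skip the replication check
    )
)

# m must be divisible by the number of devices taking part in the scatter.
x = jnp.ones((n_devices * 6, 1024, 256), jnp.float32)

jax.block_until_ready(g(x))  # warmup, so compilation stays out of the trace
with jax.profiler.trace("/tmp/psum_scatter_trace", create_perfetto_link=False):
  jax.block_until_ready(g(x))

# The benchmark then loads the XPlane proto written under the trace directory,
# keeps events whose names contain "reduce-scatter" or "all-reduce", and sums
# their durations (duration_ps / 1e9 converts picoseconds to milliseconds).
```

One traced call per trace directory (the hard-coded `1` for `n_repeats`) matches how `psum_scatter_benchmark` now invokes `_run_under_xprof`.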