Extract latency from TPU trace for collective microbenchmarks #26

hylin2002 wants to merge 3 commits into AI-Hypercomputer:main
Conversation
chishuen left a comment:
I'll continue the review later today.
src/benchmark_utils.py (Outdated)
```python
import subprocess
import shutil


# The dictionary to map a CPU collective function to its corresponding operation on TPU
```
nit: drop the term "CPU". E.g. "map a JAX (collective) operation to its main HLO"
src/benchmark_utils.py (Outdated)
```python
import shutil


# The dictionary to map a CPU collective function to its corresponding operation on TPU
# "psum_scatter_ici_op" has a different implementation depending on its `matrix_dim`
# and the number of TPUs, so it's not considered in this mapping dictionary.
```
Let's remove this comment
src/benchmark_utils.py (Outdated)
| "all_to_all_ici_op": r"all-to-all.[0-9]+", | ||
| "all_gather_ici_op": r"all-gather.[0-9]+", | ||
| "psum_ici_op": r"all-reduce.[0-9]+", | ||
| "ppermute_ici_op": r"collective-permute-done", |
LMK when you have the data with `xla_enable_async_collective_permute=false`.
I already have the data for this flag.
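For context, a minimal sketch of how these patterns could be applied to HLO event names from a trace. The mapping dictionary is taken from the diff above; `matches_collective` and the notion of a flat list of event names are illustrative assumptions, not the PR's actual code:

```python
import re

# Mapping from the diff above: collective op name -> regex for its HLO op in the TPU trace.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
    "all_to_all_ici_op": r"all-to-all.[0-9]+",
    "all_gather_ici_op": r"all-gather.[0-9]+",
    "psum_ici_op": r"all-reduce.[0-9]+",
    "ppermute_ici_op": r"collective-permute-done",
}


def matches_collective(event_name: str, task: str) -> bool:
    """Return True if a trace event name matches the task's HLO op pattern (hypothetical helper)."""
    pattern = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
    return re.search(pattern, event_name) is not None
```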
src/benchmark_utils.py (Outdated)
```python
# Check if the given task name is a collective with a corresponding TPU operation.
# This is a workaround and should be reverted or refactored in the future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
```
nit: `if task in TARGET_TASK_NAME_COLLECTIVES_MAP` is sufficient.
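A minimal before/after sketch of that nit, with the surrounding context assumed from the diff above:

```python
# Before: the .keys() call is redundant.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
    task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]

# After: `in` on a dict already tests membership against its keys.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
    task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
```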
src/benchmark_utils.py (Outdated)
```python
# This is a workaround and should be reverted or refactored in the future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
    task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
    return get_metrics_from_trace_tpu(trace, task)
```
I think you don't need to call the function again?
```python
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
    # Check if the given task name is a collective with a corresponding TPU operation.
    # This is a workaround and should be reverted or refactored in the future.
```
add comment: if the task is not present in the map, fall back to the default behavior and measure the timing from the CPU end.
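Taken together, the two review comments above suggest a shape along these lines. This is a hedged sketch, not the PR's final code; `get_metrics_from_trace_cpu` is a hypothetical name for the existing CPU-side fallback:

```python
from typing import Any


def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
    # Check if the given task name is a collective with a corresponding TPU operation.
    # This is a workaround and should be reverted or refactored in the future.
    if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
        # Dispatch once to the TPU-side extraction; no second call back through
        # get_metrics_from_trace is needed.
        return get_metrics_from_trace_tpu(trace, TARGET_TASK_NAME_COLLECTIVES_MAP[task])
    # If the task is not present in the map, fall back to the default behavior
    # and measure the timing from the CPU end.
    return get_metrics_from_trace_cpu(trace, task)  # hypothetical fallback helper
```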
Force-pushed from f2ddc4c to 4e09b70.
Force-pushed from 0edd623 to 6ef9e45.
Change the write-to-CSV filename extension from csv to tsv.
Force-pushed from 8a477b4 to acd30c9.
Force-pushed from 2c55acd to afd28a0.
This PR is a workaround to extract the latency of collectives from the TPU trace, and should be reverted or refactored in the future.
The detailed changes include:
- `TARGET_TASK_NAME_COLLECTIVES_MAP` in `src/benchmark_utils.py` maps a collective to its corresponding operation on TPU devices.
- `get_metrics_from_trace_tpu` extracts the execution time of a collective operation on TPU (see the sketch below).
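A minimal sketch of what such a TPU-side extraction could look like, assuming a Chrome-trace-format profile (a top-level `traceEvents` list whose events carry a `name` and a duration `dur` in microseconds). The event layout and the unit conversion are assumptions, not the PR's exact implementation:

```python
import re
from typing import Any


def get_metrics_from_trace_tpu(trace: dict[str, Any], tpu_op_pattern: str) -> list[float]:
    """Collect durations of TPU trace events whose names match the HLO op pattern."""
    pattern = re.compile(tpu_op_pattern)
    durations = []
    for event in trace.get("traceEvents", []):
        # Complete events in the Chrome trace format carry a "name" and a "dur" field.
        if "dur" in event and pattern.search(event.get("name", "")):
            durations.append(event["dur"] / 1e3)  # microseconds -> milliseconds (assumed unit)
    return durations
```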