Optimize intranode combine. (#247)

* Increase the test round.

* Add warp synchronization.

* Shuffle the send warps.

* Add time elapsed into bench result.
This commit is contained in:
Shangyan Zhou
2025-06-24 09:10:23 +08:00
committed by GitHub
parent fbcf430006
commit 9eb2f84b3e
3 changed files with 8 additions and 8 deletions

View File

@@ -184,9 +184,9 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True)
# Gather the best config from rank 0 and the first test setting
@@ -215,12 +215,12 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True)