Use TMA instead of LD/ST for intra-node normal kernels (#191)

* Update CMake files

* Use TMA instead of LD/ST for intranode dispatch

* Use TMA instead of LD/ST for intranode combine

* Adjust configs

* Test default configs as well

* More warps for combine

* Add inter-thread fence

* Enable more warps

* Do not use TMA for senders

* Update configs

* Remove useless wait
This commit is contained in:
Chenggang Zhao
2025-06-06 15:40:17 +08:00
committed by GitHub
parent df4debe30c
commit c8dceba110
6 changed files with 230 additions and 87 deletions

View File

@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_dispatch_config(num_ranks)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True)
@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 7, 1):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_combine_config(num_ranks)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if t < best_time:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank)
@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
num_processes = 8