mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Use TMA instead of LD/ST for intra-node normal kernels (#191)
* Update CMake files * Use TMA instead of LD/ST for intranode dispatch * Use TMA instead of LD/ST for intranode combine * Adjust configs * Test default configs as well * More warps for combine * Add inter-thread fence * Enable more warps * Do not use TMA for senders * Update configs * Remove useless wait
This commit is contained in:
@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
|
||||
if nvl_chunk_size > 0:
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
else:
|
||||
# Test default config as well
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
config = deep_ep.Buffer.get_dispatch_config(num_ranks)
|
||||
tune_args = {'x': current_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
if t < best_time and nvl_chunk_size > 0:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
|
||||
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
|
||||
print('', flush=True)
|
||||
@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 7, 1):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
|
||||
if nvl_chunk_size > 0:
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
else:
|
||||
# Test default config as well
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
config = deep_ep.Buffer.get_combine_config(num_ranks)
|
||||
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if t < best_time:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
|
||||
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if t < best_time and nvl_chunk_size > 0:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
|
||||
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
torch.manual_seed(rank)
|
||||
|
||||
@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
|
||||
|
||||
# Destroy the communication group
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_processes = 8
|
||||
|
||||
Reference in New Issue
Block a user