Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-05-05 20:44:48 +00:00
In our optimized Internode Normal Kernel, we use multiple QPs for data transmission between each pair of GPUs, assigning a dedicated QP to each channel. We also switched the transport from IBRC to IBGDA. With these optimizations, the Internode Normal Kernel performs well on both H800 and H20, with RDMA bandwidth approaching the physical limit of the network. Using the current default accounting method, RDMA bandwidth exceeds 60 GB/s in 4-node H800 and H20 environments.
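The QP-per-channel mapping is configured when the communication buffer is created. Below is a minimal sketch of that setup, mirroring the values used in `test_loop` in the test script that follows (24 SMs, two SMs per channel, 1 GB NVLink and RDMA buffers); the helper name `make_normal_buffer` and the per-deployment numbers are illustrative, not part of the library.

# Sketch only: assumes the same settings as the test below (24 SMs, 1 GB buffers).
import torch.distributed as dist
import deep_ep

def make_normal_buffer(group: dist.ProcessGroup, num_sms: int = 24) -> deep_ep.Buffer:
    # The normal kernels pair two SMs per channel, so one QP per channel
    # means num_sms // 2 QPs per rank.
    num_channels = num_sms // 2
    return deep_ep.Buffer(group, int(1e9), int(1e9),
                          low_latency_mode=False,
                          num_qps_per_rank=num_channels)

The full test script follows.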
247 lines · 15 KiB · Python
import os
import time
import torch
import torch.distributed as dist

# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back

# Test compatibility with low latency functions
import test_low_latency


def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
    # Settings
    num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)

    # Random data
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
    topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    rdma_rank_idx = rank_idx // num_local_ranks
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)

    # RDMA dispatch counts
    rdma_idx = topk_idx // (num_experts // num_nodes)
    rdma_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rdma_idx, num_nodes)
    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
    num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
    for i in range(num_nodes):
        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
        buffer.get_dispatch_layout(topk_idx, num_experts)
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
        print('', flush=True)
    group.barrier()
    time.sleep(1)
    # Config
    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, recv_gbl_rank_prefix_sum):
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = recv_gbl_rank_prefix_sum[i].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end
    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
                    dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
                                     'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
                    if with_topk:
                        dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
                    if previous_mode:
                        dispatch_args.update({'previous_event': buffer.capture()})
                    recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x

                    # Checks
                    recv_gbl_rank_prefix_sum = handle[-4]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
                    assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
                    if current_x is not x_pure_rand:
                        check_data(recv_x, recv_gbl_rank_prefix_sum)
                    if with_topk:
                        # Check `topk_idx`
                        assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
                    # Test cached dispatch (must be without top-k stuff)
                    if not with_topk:
                        dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
                        if previous_mode:
                            dispatch_args.update({'previous_event': buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
                        if current_x is not x_pure_rand:
                            check_data(recv_x, recv_gbl_rank_prefix_sum)
                    # Test combine
                    combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
                    if with_topk:
                        combine_args.update({'topk_weights': recv_topk_weights})
                    if previous_mode:
                        combine_args.update({'previous_event': buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
                    event.current_stream_wait() if async_mode else ()
                    check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
                        ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning
                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

                    if local_rank == 0:
                        print(' passed', flush=True)
    if local_rank == 0:
        print('', flush=True)
    # Tune dispatch performance
    best_dispatch_results = None
    # FP8 dispatch sends 1 byte per element plus a float32 scale per 128 elements, vs. 2 bytes per BF16 element
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
        for nvl_chunk_size in range(4, 33, 4):
            for rdma_chunk_size in range(4, 33, 4):
                config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
                tune_args = {'x': current_x, 'handle': handle, 'config': config}
                t = bench(lambda: buffer.dispatch(**tune_args))[0]
                if t < best_time:
                    best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
                if local_rank == 0:
                    print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
        if local_rank == 0:
            print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
            print('', flush=True)

        if isinstance(current_x, tuple):
            # Gather the best FP8 config from rank 0
            best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
            all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)

    dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
                     'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
                     'config': dispatch_config if dispatch_config is not None else config}
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 5, 1):
        for rdma_chunk_size in range(8, 33, 4):
            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
            tune_args = {'x': recv_x, 'handle': handle, 'config': config}
            t = bench(lambda: buffer.combine(**tune_args))[0]
            if local_rank == 0:
                print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
            if t < best_time:
                best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)

    if local_rank == 0:
        print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
        print('', flush=True)


# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    num_sms = 24
    # The normal kernels use two SMs per channel; one QP per channel gives num_sms // 2 QPs per rank
    qp_num = num_sms // 2
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility = False
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9

    buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num))
    assert num_local_ranks == 8 and num_ranks > 8
    torch.manual_seed(rank)

    for i in (num_sms, ):
        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
        if local_rank == 0:
            print('', flush=True)

    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)


if __name__ == '__main__':
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)