mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support CUDA graph for intranode normal kernels (#203)
This commit is contained in:
@@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
recv_topk_weights_clone = None
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
|
||||
@@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
recv_topk_weights_clone = recv_topk_weights.clone()
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
|
||||
check_data(recv_topk_weights, rank_prefix_matrix)
|
||||
|
||||
# Test `num_worst_tokens != 0`
|
||||
if with_topk:
|
||||
num_worst_tokens = num_tokens * num_ranks
|
||||
dispatch_args.update({'num_worst_tokens': num_worst_tokens})
|
||||
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
|
||||
assert num_worst_tokens == recv_worst_x.size(0)
|
||||
assert num_worst_tokens == recv_worst_topk_idx.size(0)
|
||||
assert num_worst_tokens == recv_worst_topk_weights.size(0)
|
||||
assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
|
||||
assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
|
||||
assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
|
||||
assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
|
||||
|
||||
# Test cached dispatch (must be done without top-k stuff)
|
||||
if not with_topk:
|
||||
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
|
||||
Reference in New Issue
Block a user