mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support statistics tensor for low-latency kernels (#196)
This commit is contained in:
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
for dispatch_use_fp8 in (False, True):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
|
||||
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
# Check expert indices
|
||||
int_mask = (2 ** 32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
|
||||
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
|
||||
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
|
||||
|
||||
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=False, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
|
||||
Reference in New Issue
Block a user