Support statistics tensor for low-latency kernels (#196)

This commit is contained in:
Chenggang Zhao
2025-06-09 15:50:56 +08:00
committed by GitHub
parent 0d1a855d81
commit 5a2e37fa28
6 changed files with 27 additions and 3 deletions

View File

@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
for dispatch_use_fp8 in (False, True):
num_times += 1
for i in range((num_times % 2) + 1):
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
def test_func(zero_copy: bool, return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy: