Fix zero-copy mode tests

This commit is contained in:
Chenggang Zhao 2025-03-28 16:49:33 +08:00
parent c4d12b4f8f
commit 26fa72d80f

View File

@ -78,7 +78,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, async_finish=not return_recv_hook, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, out=out) return_recv_hook=return_recv_hook, out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
if do_check: if do_check: