diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index c033c72..6cf852d 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -78,7 +78,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - async_finish=not return_recv_hook, + async_finish=not return_recv_hook, zero_copy=zero_copy, return_recv_hook=return_recv_hook, out=out) hook() if return_recv_hook else event.current_stream_wait() if do_check: