From 23c54150baf229f23407986cbe994f7929cf2fb7 Mon Sep 17 00:00:00 2001 From: Hao Lin Date: Fri, 11 Apr 2025 18:21:09 +0800 Subject: [PATCH] Fix test combine args Signed-off-by: Hao Lin --- tests/test_internode.py | 2 +- tests/test_intranode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_internode.py b/tests/test_internode.py index 7c73faa..5884a16 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -143,7 +143,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 169668c..68f16f7 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -127,7 +127,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)