diff --git a/tests/test_internode.py b/tests/test_internode.py index 7c73faa..5884a16 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -143,7 +143,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 169668c..68f16f7 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -127,7 +127,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)