Merge pull request #118 from andylin-hao/main

Fix test combine args
This commit is contained in:
Chenggang Zhao 2025-04-14 15:51:30 +08:00 committed by GitHub
commit a0c69317ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -143,7 +143,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)

View File

@ -127,7 +127,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)