mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix test combine args
Signed-off-by: Hao Lin <linhaomails@gmail.com>
This commit is contained in:
parent
8a0ca8e2ec
commit
23c54150ba
@ -143,7 +143,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
|
||||
if with_topk:
|
||||
combine_args.update({'topk_weights': recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
combine_args.update({'previous_event': buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
|
@ -127,7 +127,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
if with_topk:
|
||||
combine_args.update({'topk_weights': recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
combine_args.update({'previous_event': buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
|
Loading…
Reference in New Issue
Block a user