mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
if with_topk:
|
||||
combine_args.update({'topk_weights': recv_topk_weights})
|
||||
if previous_mode:
|
||||
combine_args.update({'previous_event': buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
|
||||
Reference in New Issue
Block a user