Support bias. (#257)

* Support bias.

* Fix.

* Fix style.
This commit is contained in:
Shangyan Zhou
2025-06-25 13:04:20 +08:00
committed by GitHub
parent b80e55e21f
commit bd429ffefc
7 changed files with 101 additions and 16 deletions

View File

@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
- combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
+ bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+ bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+ combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
- check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
+ check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk: