mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
|
||||
@@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
|
||||
|
||||
# Destroy the communication group
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_processes = 8
|
||||
|
||||
Reference in New Issue
Block a user