Support UE8M0 data format. (#206)

This commit is contained in:
Shifang Xu
2025-06-12 09:38:19 +08:00
committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions

View File

@@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank