mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
|
||||
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
|
||||
|
||||
Reference in New Issue
Block a user