Support Ampere architecture (#204)

* Update README

* Update `setup.py`

* Fix headers

* Add `DISABLE_NVSHMEM` for APIs

* Fix launch

* Fix TMA settings

* Fix TMA usages

* Fix dlink

* Separate layout kernels

* Update version

* Add `is_sm90_compiled`

* Fix tests

* Add NVLink connection checks

* Update README

* Fix tests

* Add some comments

* Minor fix

* Minor fix

* Fix bugs
This commit is contained in:
Chenggang Zhao
2025-06-11 15:48:18 +08:00
committed by GitHub
parent dd13c7145c
commit b8d90fb753
16 changed files with 413 additions and 174 deletions

View File

@@ -1,4 +1,3 @@
import os
import time
import torch
import torch.distributed as dist
@@ -21,7 +20,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
@@ -80,7 +79,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
@@ -168,7 +167,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
@@ -189,8 +188,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True)
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
# Gather the best config from rank 0 and the first test setting
if best_dispatch_results is None:
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)