Refactor some code.

This commit is contained in:
Shangyan Zhou
2025-04-22 10:22:30 +08:00
parent c07fdd197c
commit 20b2aaaf9e
4 changed files with 90 additions and 61 deletions

View File

@@ -218,15 +218,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
num_nodes = int(os.getenv('WORLD_SIZE', 1))
num_sms = 24
qp_num = num_sms // 2
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = False
test_ll_compatibility = True
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_sms = 24
num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num))
num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)