Initial commit

This commit is contained in:
Chenggang Zhao
2025-02-24 21:56:14 +08:00
commit ebfe47e46f
32 changed files with 8461 additions and 0 deletions

246
tests/test_internode.py Normal file
View File

@@ -0,0 +1,246 @@
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes)
# RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes)
rdma_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rdma_idx, num_nodes)
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
for i in range(num_nodes):
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, recv_gbl_rank_prefix_sum):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = recv_gbl_rank_prefix_sum[i].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
for rdma_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for rdma_chunk_size in range(8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
# Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels,
num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = False
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)

223
tests/test_intranode.py Normal file
View File

@@ -0,0 +1,223 @@
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
nvl_buffer_size = 256
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, rank_prefix_matrix):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = rank_prefix_matrix[i][rank].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
rank_prefix_matrix = handle[0]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, rank_prefix_matrix)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)

160
tests/test_low_latency.py Normal file
View File

@@ -0,0 +1,160 @@
import random
import torch
import torch.distributed as dist
from functools import partial
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for i in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for return_recv_hook in (False, True):
num_times += 1
for i in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
# Check combine correctness
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < 1e-5, f'Error: diff={diff}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
assert tmp.abs().amax().item() <= 1
# Create some amax outliers
for i in range(num_outliers):
tmp[random.randint(0, num_tokens - 1)] *= 1e3
return tmp
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')
return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)

192
tests/utils.py Normal file
View File

@@ -0,0 +1,192 @@
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
def init_dist(local_rank: int, num_local_ranks: int):
# NOTES: you may rewrite this function with your own cluster settings
ip = os.getenv('MASTER_ADDR', '127.0.0.1')
port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend='nccl',
init_method=f'tcp://{ip}:{port}',
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank
)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')
torch.cuda.set_device(local_rank)
return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double() + 1, y.double() + 1
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return (1 - sim).item()
def per_token_cast_to_fp8(x: torch.Tensor):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
assert x.dim() == 2
mask = x < 0
x_padded = x.masked_fill(mask, num_slots)
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
bin_count = bin_count[:, :num_slots]
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
x[:, :].fill_(-1)
valid_len = min(num_slots, x.size(1))
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
num_tokens, num_experts = scores.shape
scores = scores.view(num_tokens, num_groups, -1)
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
return (scores * mask).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
# Warmup
for _ in range(num_warmups):
fn()
# Flush L2
cache.zero_()
# Testing
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
# Record
start_events[i].record()
fn()
end_events[i].record()
if post_fn is not None:
post_fn()
torch.cuda.synchronize()
times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
return np.average(times), np.min(times), np.max(times)
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
fn()
prof.step()
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)
# Return average kernel times
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()