mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -34,61 +34,68 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hash_value, num_times = 0, 0
|
||||
for return_recv_hook in (False, True):
|
||||
for dispatch_use_fp8 in (False, True):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
|
||||
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
|
||||
if dispatch_use_fp8 else packed_recv_x.clone()
|
||||
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
|
||||
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
|
||||
for i in range(num_local_experts if do_check else 0):
|
||||
expert_id = rank * num_local_experts + i
|
||||
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
|
||||
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
|
||||
for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
|
||||
for use_ue8m0 in (False, True) if round_scale else (False, ):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
|
||||
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
|
||||
if dispatch_use_fp8 else packed_recv_x.clone()
|
||||
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
|
||||
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
|
||||
for i in range(num_local_experts if do_check else 0):
|
||||
expert_id = rank * num_local_experts + i
|
||||
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
|
||||
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
|
||||
|
||||
# Check expert indices
|
||||
int_mask = (2 ** 32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
|
||||
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
|
||||
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
|
||||
# Check expert indices
|
||||
int_mask = (2 ** 32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
|
||||
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
|
||||
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
|
||||
|
||||
# Check received data
|
||||
recv_x = recv_x[:num_valid_tokens]
|
||||
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
|
||||
recv_src_info = recv_src_info[:num_valid_tokens]
|
||||
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
|
||||
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
|
||||
for j in range(num_ranks):
|
||||
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
|
||||
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
|
||||
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
|
||||
if dispatch_use_fp8:
|
||||
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
|
||||
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
|
||||
else:
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
# Check received data
|
||||
recv_x = recv_x[:num_valid_tokens]
|
||||
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
|
||||
recv_src_info = recv_src_info[:num_valid_tokens]
|
||||
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
|
||||
if round_scale:
|
||||
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
|
||||
else:
|
||||
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
|
||||
for j in range(num_ranks):
|
||||
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
|
||||
if not round_scale:
|
||||
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
|
||||
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
|
||||
if dispatch_use_fp8:
|
||||
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
|
||||
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
|
||||
else:
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook, zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
# Check combine correctness
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook, zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
|
||||
def create_test_cast_with_outliers(num_outliers):
|
||||
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
recv_x, recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=False, return_recv_hook=return_recv_hook)
|
||||
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
@@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
for i in range(20):
|
||||
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
|
||||
|
||||
# Destroy the communication group
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# TODO: you may modify NUMA binding for less CPU overhead
|
||||
|
||||
Reference in New Issue
Block a user