mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support zero-copy for low-latency combine
This commit is contained in:
@@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f'Error: diff={diff}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook,
|
||||
return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f'Error: diff={diff}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
|
||||
def create_test_cast_with_outliers(num_outliers):
|
||||
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hook()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func(return_recv_hook):
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
async_finish=False, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
return_recv_hook=return_recv_hook)
|
||||
zero_copy=zero_copy, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
|
||||
# Calculate bandwidth
|
||||
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
num_combine_comm_bytes += num_bf16_bytes * num_selections
|
||||
|
||||
# Dispatch + combine testing
|
||||
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
|
||||
avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
|
||||
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
|
||||
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
|
||||
|
||||
# Separate profiling
|
||||
for return_recv_hook in (False, True):
|
||||
group.barrier()
|
||||
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
|
||||
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
|
||||
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
|
||||
suppress_kineto_output=True)
|
||||
if not return_recv_hook:
|
||||
|
||||
Reference in New Issue
Block a user