mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-05-04 04:03:56 +00:00
Merge pull request #79 from deepseek-ai/zero-copy-combine
Support zero-copy for low-latency combine
This commit is contained in:
commit
c4b8ffc37c
@ -102,6 +102,9 @@ struct LowLatencyBuffer {
|
||||
void* combine_rdma_recv_data_buffer = nullptr;
|
||||
int* combine_rdma_recv_flag_buffer = nullptr;
|
||||
|
||||
void* combine_rdma_send_buffer_data_start = nullptr;
|
||||
size_t num_bytes_per_combine_msg = 0;
|
||||
|
||||
std::pair<int*, int> clean_meta() {
|
||||
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
|
||||
return {dispatch_rdma_recv_count_buffer, num_clean_int};
|
||||
@ -163,7 +166,9 @@ struct LowLatencyLayout {
|
||||
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
|
||||
advance(rdma_buffer, send_buffer_bytes * i),
|
||||
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
|
||||
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
|
||||
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
|
||||
advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
|
||||
num_bytes_per_combine_msg
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
|
||||
bool zero_copy, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
next_clean_meta.first, next_clean_meta.second,
|
||||
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
|
||||
num_topk, num_experts, rank, num_ranks,
|
||||
workspace, launch_stream, phases);
|
||||
workspace, launch_stream,
|
||||
phases, zero_copy);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
return {combined_x, event, recv_hook};
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
|
||||
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
auto buffer = layout.buffers[low_latency_buffer_idx];
|
||||
auto dtype = torch::kBFloat16;
|
||||
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
|
||||
|
||||
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
|
||||
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
|
||||
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
|
||||
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("internode_combine", &deep_ep::Buffer::internode_combine)
|
||||
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
|
||||
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
|
||||
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
|
||||
}
|
||||
|
@ -143,7 +143,11 @@ public:
|
||||
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
|
||||
bool zero_copy, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out = std::nullopt);
|
||||
|
||||
torch::Tensor
|
||||
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
|
||||
};
|
||||
|
||||
} // namespace deep_ep
|
||||
|
@ -147,7 +147,8 @@ void combine(void* combined_x,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
void* workspace, cudaStream_t stream,
|
||||
int phases, bool zero_copy);
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
|
@ -353,7 +353,7 @@ combine(void* combined_x,
|
||||
int num_combined_tokens, int hidden, int num_topk,
|
||||
int num_max_dispatch_tokens_per_rank,
|
||||
int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
int phases, bool zero_copy) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
@ -420,7 +420,8 @@ combine(void* combined_x,
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
}
|
||||
}
|
||||
@ -500,7 +501,8 @@ void combine(void* combined_x,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
void* workspace, cudaStream_t stream,
|
||||
int phases, bool zero_copy) {
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
constexpr int kNumMaxTopk = 9;
|
||||
@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases); } break
|
||||
phases, zero_copy); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
|
||||
|
@ -488,7 +488,7 @@ class Buffer:
|
||||
self.runtime.low_latency_dispatch(x, topk_idx,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
use_fp8, async_finish, return_recv_hook)
|
||||
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
|
||||
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
|
||||
tensors_to_record = (x, topk_idx,
|
||||
packed_recv_x, packed_recv_x_scales, packed_recv_count,
|
||||
packed_recv_src_info, packed_recv_layout_range)
|
||||
@ -497,8 +497,8 @@ class Buffer:
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
|
||||
out: Optional[torch.Tensor] = None) -> \
|
||||
handle: tuple, zero_copy: bool = False, async_finish: bool = False,
|
||||
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
|
||||
Tuple[torch.Tensor, EventOverlap, Callable]:
|
||||
"""
|
||||
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
|
||||
@ -517,6 +517,8 @@ class Buffer:
|
||||
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
|
||||
tokens. The received tokens will be reduced with the weights in this tensor.
|
||||
handle: the communication handle given by the `dispatch` function.
|
||||
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
|
||||
with `get_next_low_latency_combine_buffer`.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
|
||||
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
|
||||
@ -528,9 +530,24 @@ class Buffer:
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
|
||||
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
async_finish, return_recv_hook, out)
|
||||
zero_copy, async_finish, return_recv_hook, out)
|
||||
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
|
||||
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
|
||||
|
||||
def get_next_low_latency_combine_buffer(self, handle: object):
|
||||
"""
|
||||
Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying.
|
||||
|
||||
Arguments:
|
||||
handle: the communication handle given by the `dispatch` function.
|
||||
|
||||
Returns:
|
||||
buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
|
||||
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, you should fill this buffer
|
||||
by yourself.
|
||||
"""
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
|
||||
return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
|
||||
|
@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f'Error: diff={diff}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook,
|
||||
return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f'Error: diff={diff}'
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
|
||||
def create_test_cast_with_outliers(num_outliers):
|
||||
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hook()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func(return_recv_hook):
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
async_finish=False, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
return_recv_hook=return_recv_hook)
|
||||
zero_copy=zero_copy, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
|
||||
# Calculate bandwidth
|
||||
@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
num_combine_comm_bytes += num_bf16_bytes * num_selections
|
||||
|
||||
# Dispatch + combine testing
|
||||
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
|
||||
avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
|
||||
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
|
||||
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
|
||||
|
||||
# Separate profiling
|
||||
for return_recv_hook in (False, True):
|
||||
group.barrier()
|
||||
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
|
||||
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
|
||||
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
|
||||
suppress_kineto_output=True)
|
||||
if not return_recv_hook:
|
||||
|
Loading…
Reference in New Issue
Block a user