Support zero-copy for low-latency combine

Chenggang Zhao 2025-03-18 15:41:50 +08:00
parent 82dcf48fd3
commit dcaf73e5ff
7 changed files with 80 additions and 28 deletions

View File

@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
     void* combine_rdma_recv_data_buffer = nullptr;
     int* combine_rdma_recv_flag_buffer = nullptr;
+    void* combine_rdma_send_buffer_data_start = nullptr;
+    int num_bytes_per_combine_msg = 0;
+
 
     std::pair<int*, int> clean_meta() {
         EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
         return {dispatch_rdma_recv_count_buffer, num_clean_int};
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
-                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
+                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
+                advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
+                static_cast<int>(num_bytes_per_combine_msg)
             };
         }
     }
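
The two new fields let the host wrap the combine send region as a regular tensor: combine_rdma_send_buffer_data_start points sizeof(int4) past each send buffer's start, and num_bytes_per_combine_msg records the fixed size of one combine message. Below is a minimal Python sketch of the resulting strided view, assuming toy sizes and a 16-byte per-message header followed by a BF16 payload; the exact message layout is defined elsewhere in LowLatencyLayout and is not shown in this hunk.

    import torch

    # Toy stand-ins (assumptions, not DeepEP constants)
    hidden = 128
    num_ranks, tokens_per_rank, num_local_experts = 4, 8, 2
    header_bytes = 16                              # mirrors the sizeof(int4) offset above
    msg_elems = (header_bytes + hidden * 2) // 2   # message size in BF16 elements (assumed layout)

    # A flat "send buffer": one fixed-size message per (expert, token) slot
    flat = torch.zeros(num_local_experts * num_ranks * tokens_per_rank * msg_elems,
                       dtype=torch.bfloat16)

    # Strided view exposing only each message's payload, matching the shape/stride
    # pattern used by get_next_low_latency_combine_buffer in the next file:
    payload = flat.as_strided(
        (num_local_experts, num_ranks * tokens_per_rank, hidden),
        (num_ranks * tokens_per_rank * msg_elems, msg_elems, 1),
        storage_offset=header_bytes // 2)
    payload[0, 0, :] = 1.0   # lands in the first message's payload; headers stay untouched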

View File

@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
+                            bool zero_copy, bool async, bool return_recv_hook,
+                            const std::optional<torch::Tensor>& out) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                               next_clean_meta.first, next_clean_meta.second,
                               num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
                               num_topk, num_experts, rank, num_ranks,
-                              workspace, launch_stream, phases);
+                              workspace, launch_stream,
+                              phases, zero_copy);
     };
 
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     return {combined_x, event, recv_hook};
 }
 
+torch::Tensor
+Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
+    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    auto buffer = layout.buffers[low_latency_buffer_idx];
+    auto dtype = torch::kBFloat16;
+    auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
+    EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
+    return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
+                            {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                            {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
+                            torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
+}
+
 } // namespace deep_ep
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("internode_combine", &deep_ep::Buffer::internode_combine)
         .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
         .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
-        .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
+        .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
+        .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
 }
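
The new binding builds the returned tensor with torch::from_blob over the registered RDMA send buffer, so writes from Python land directly in that memory with no intermediate copy. A tiny host-side illustration of the aliasing behaviour (plain CPU memory here, an analogy rather than DeepEP code):

    import torch

    # A tensor created over pre-existing memory aliases it, in the spirit of the
    # torch::from_blob call above: no copy is made, writes go straight through.
    raw = bytearray(16)                   # 8 BF16 elements of "pre-registered" memory
    view = torch.frombuffer(raw, dtype=torch.bfloat16)
    view.fill_(3.0)                       # write through the tensor view...
    assert any(b != 0 for b in raw)       # ...and the underlying bytes change in place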

View File

@@ -143,7 +143,11 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
+                        bool zero_copy, bool async, bool return_recv_hook,
+                        const std::optional<torch::Tensor>& out = std::nullopt);
+
+    torch::Tensor
+    get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 };
 
 } // namespace deep_ep

View File

@@ -147,7 +147,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, cudaStream_t stream, int phases);
+             void* workspace, cudaStream_t stream,
+             int phases, bool zero_copy);
 
 } // namespace internode_ll

View File

@@ -353,7 +353,7 @@ combine(void* combined_x,
         int num_combined_tokens, int hidden, int num_topk,
         int num_max_dispatch_tokens_per_rank,
         int num_experts, int rank, int num_ranks,
-        int phases) {
+        int phases, bool zero_copy) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
@@ -420,6 +420,7 @@ combine(void* combined_x,
                 UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
             } else {
                 const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
+                if (not zero_copy)
                     UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
                 nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
             }
@@ -500,7 +501,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, cudaStream_t stream, int phases) {
+             void* workspace, cudaStream_t stream,
+             int phases, bool zero_copy) {
     constexpr int kNumWarpsPerGroup = 10;
     constexpr int kNumWarpGroups = 3;
     constexpr int kNumMaxTopk = 9;
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
                   num_combined_tokens, hidden, num_topk, \
                   num_max_dispatch_tokens_per_rank, \
                   num_experts, rank, num_ranks, \
-                  phases); } break
+                  phases, zero_copy); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
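
The kernel change is a single guard: in zero-copy mode the warp-level copy of x into the RDMA send buffer is skipped, and the IBGDA put transmits whatever the caller already wrote there. A hedged Python stand-in for that send-side branch (function and parameter names here are illustrative, not DeepEP symbols):

    def send_token(x_payload, send_buf_slot, rdma_put, zero_copy: bool):
        """Per-token send-side flow mirroring the guarded copy above."""
        if not zero_copy:
            # Default path: stage the BF16 payload into the RDMA send buffer first.
            send_buf_slot[:] = x_payload
        # Either way, the RDMA put reads from the send buffer, which in zero-copy
        # mode was pre-filled via get_next_low_latency_combine_buffer.
        rdma_put(send_buf_slot)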

View File

@@ -488,7 +488,7 @@ class Buffer:
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, async_finish, return_recv_hook)
-        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
+        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
                              packed_recv_src_info, packed_recv_layout_range)
@@ -497,8 +497,8 @@ class Buffer:
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
-                            out: Optional[torch.Tensor] = None) -> \
+                            handle: tuple, zero_copy: bool = False, async_finish: bool = False,
+                            return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -517,6 +517,8 @@ class Buffer:
             topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
                 tokens. The received tokens will be reduced with the weights in this tensor.
             handle: the communication handle given by the `dispatch` function.
+            zero_copy: whether the input tensor has already been copied into the RDMA buffer; must be used
+                together with `get_next_low_latency_combine_buffer`.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
@@ -528,9 +530,24 @@ class Buffer:
             event: the event after executing the kernel (valid only if `async_finish` is set).
             hook: the receiving hook function (valid only if `return_recv_hook` is set).
         """
-        src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
+        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook, out)
+                                                                   zero_copy, async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
 
+    def get_next_low_latency_combine_buffer(self, handle: object):
+        """
+        Get the raw, registered RDMA buffer tensor for the next low-latency combine, so that the combine kernel can skip the copy.
+
+        Arguments:
+            handle: the communication handle given by the `dispatch` function.
+
+        Returns:
+            buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
+                `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`; you should fill this
+                buffer yourself.
+        """
+        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
+        return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
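
Putting the Python pieces together, a hedged end-to-end sketch of the zero-copy path (assumes an already-initialized Buffer plus the usual dispatch inputs; run_experts is a hypothetical stand-in for the expert computation producing BF16 outputs):

    # Dispatch as usual; the handle now also carries the hidden size.
    recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
        async_finish=False, return_recv_hook=False)

    expert_out = run_experts(recv_x, recv_count)  # hypothetical expert computation (BF16)

    # Write results straight into the registered RDMA send buffer for the next combine...
    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = expert_out

    # ...then combine with zero_copy=True so the kernel skips the staging copy.
    combined_x, event, hook = buffer.low_latency_combine(
        expert_out, topk_idx, topk_weights, handle, zero_copy=True)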

View File

@@ -73,9 +73,13 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
     # Check combine correctness
+    for zero_copy in (False, True):
+        if zero_copy:
+            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
         out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
         combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                             async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
+                                                             async_finish=not return_recv_hook,
+                                                             return_recv_hook=return_recv_hook, out=out)
         hook() if return_recv_hook else event.current_stream_wait()
         if do_check:
             diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
             hook()
 
     # noinspection PyShadowingNames
-    def test_func(return_recv_hook):
+    def test_func(zero_copy: bool, return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                         async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
+        if zero_copy:
+            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
         combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                             return_recv_hook=return_recv_hook)
+                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
 
     # Calculate bandwidth
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         num_combine_comm_bytes += num_bf16_bytes * num_selections
 
     # Dispatch + combine testing
-    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
+    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
     print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
           f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
 
     # Separate profiling
     for return_recv_hook in (False, True):
         group.barrier()
-        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
+        dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
                                              kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                              suppress_kineto_output=True)
         if not return_recv_hook: