diff --git a/csrc/config.hpp b/csrc/config.hpp
index 37be06b..1c7a681 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
     void* combine_rdma_recv_data_buffer = nullptr;
     int* combine_rdma_recv_flag_buffer = nullptr;
 
+    void* combine_rdma_send_buffer_data_start = nullptr;
+    size_t num_bytes_per_combine_msg = 0;
+
     std::pair<int*, int> clean_meta() {
         EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
         return {dispatch_rdma_recv_count_buffer, num_clean_int};
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
-                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
+                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
+                advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
+                num_bytes_per_combine_msg
             };
         }
     }
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 404c2b4..e0c290d 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
+                            bool zero_copy, bool async, bool return_recv_hook,
+                            const std::optional<torch::Tensor>& out) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                               next_clean_meta.first, next_clean_meta.second,
                               num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
                               num_topk, num_experts, rank, num_ranks,
-                              workspace, launch_stream, phases);
+                              workspace, launch_stream,
+                              phases, zero_copy);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     return {combined_x, event, recv_hook};
 }
 
+torch::Tensor
+Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
+    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+    auto buffer = layout.buffers[low_latency_buffer_idx];
+    auto dtype = torch::kBFloat16;
+    auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
+
+    EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
+    return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
+                            {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                            {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
+                            torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
+}
+
 } // namespace deep_ep
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("internode_combine", &deep_ep::Buffer::internode_combine)
         .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
         .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
-        .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
+        .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
+        .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
 }
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 1a5bb79..e0ad4d6 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -143,7 +143,11 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
+                        bool zero_copy, bool async, bool return_recv_hook,
+                        const std::optional<torch::Tensor>& out = std::nullopt);
+
+    torch::Tensor
+    get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 };
 
 } // namespace deep_ep
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 74c962b..89937a8 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -147,7 +147,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, cudaStream_t stream, int phases);
+             void* workspace, cudaStream_t stream,
+             int phases, bool zero_copy);
 
 } // namespace internode_ll
 
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 76ae2e2..6d0c871 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -353,7 +353,7 @@ combine(void* combined_x,
         int num_combined_tokens, int hidden, int num_topk,
         int num_max_dispatch_tokens_per_rank, int num_experts,
         int rank, int num_ranks,
-        int phases) {
+        int phases, bool zero_copy) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
@@ -420,7 +420,8 @@ combine(void* combined_x,
                 UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
             } else {
                 const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
-                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
+                if (not zero_copy)
+                    UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
                 nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
             }
         }
@@ -500,7 +501,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, cudaStream_t stream, int phases) {
+             void* workspace, cudaStream_t stream,
+             int phases, bool zero_copy) {
     constexpr int kNumWarpsPerGroup = 10;
     constexpr int kNumWarpGroups = 3;
     constexpr int kNumMaxTopk = 9;
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
               num_combined_tokens, hidden, num_topk, \
               num_max_dispatch_tokens_per_rank, \
               num_experts, rank, num_ranks, \
-              phases); } break
+              phases, zero_copy); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 098b346..11a2ce3 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -488,7 +488,7 @@ class Buffer:
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, async_finish, return_recv_hook)
-        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
+        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx, packed_recv_x, packed_recv_x_scales, packed_recv_count,
                              packed_recv_src_info, packed_recv_layout_range)
         return packed_recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
@@ -497,8 +497,8 @@ class Buffer:
 
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
-                            out: Optional[torch.Tensor] = None) -> \
+                            handle: tuple, zero_copy: bool = False, async_finish: bool = False,
+                            return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -517,6 +517,8 @@ class Buffer:
             topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
                 tokens. The received tokens will be reduced with the weights in this tensor.
             handle: the communication handle given by the `dispatch` function.
+            zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
+                with `get_next_low_latency_combine_buffer`.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
@@ -528,9 +530,24 @@ class Buffer:
             event: the event after executing the kernel (valid only if `async_finish` is set).
             hook: the receiving hook function (valid only if `return_recv_hook` is set).
         """
-        src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
+        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook, out)
+                                                                   zero_copy, async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
+
+    def get_next_low_latency_combine_buffer(self, handle: object):
+        """
+        Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying.
+
+        Arguments:
+            handle: the communication handle given by the `dispatch` function.
+
+        Returns:
+            buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
+                `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, you should fill this buffer
+                by yourself.
+        """
+        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
+        return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py
index 3ed681b..c033c72 100644
--- a/tests/test_low_latency.py
+++ b/tests/test_low_latency.py
@@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
             hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
     # Check combine correctness
-    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
-    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
-    hook() if return_recv_hook else event.current_stream_wait()
-    if do_check:
-        diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
-        assert torch.isnan(combined_x).sum().item() == 0
-        assert diff < 1e-5, f'Error: diff={diff}'
-        hash_value ^= hash_tensor(combined_x)
+    for zero_copy in (False, True):
+        if zero_copy:
+            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
+        out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+                                                             zero_copy=zero_copy, async_finish=not return_recv_hook,
+                                                             return_recv_hook=return_recv_hook, out=out)
+        hook() if return_recv_hook else event.current_stream_wait()
+        if do_check:
+            diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
+            assert torch.isnan(combined_x).sum().item() == 0
+            assert diff < 1e-5, f'Error: diff={diff}'
+            hash_value ^= hash_tensor(combined_x)
 
     def create_test_cast_with_outliers(num_outliers):
         tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         hook()
 
     # noinspection PyShadowingNames
-    def test_func(return_recv_hook):
+    def test_func(zero_copy: bool, return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                         async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
+        if zero_copy:
+            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
         combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                             return_recv_hook=return_recv_hook)
+                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
 
     # Calculate bandwidth
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         num_combine_comm_bytes += num_bf16_bytes * num_selections
 
     # Dispatch + combine testing
-    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
+    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
     print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
           f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
 
     # Separate profiling
     for return_recv_hook in (False, True):
         group.barrier()
-        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
+        dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
                                              kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                              suppress_kineto_output=True)
         if not return_recv_hook:
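Usage note (not part of the patch): the sketch below shows how the zero-copy path introduced by this diff is meant to be driven from user code, mirroring `tests/test_low_latency.py` above. Only the `Buffer` methods touched by this diff are assumed to exist; `run_experts`, `x`, `topk_idx` and `topk_weights` are hypothetical placeholders.

```python
import torch
from deep_ep import Buffer


def moe_forward_zero_copy(buffer: Buffer, x: torch.Tensor, topk_idx: torch.Tensor,
                          topk_weights: torch.Tensor, num_max_dispatch_tokens_per_rank: int,
                          num_experts: int, run_experts) -> torch.Tensor:
    # Dispatch tokens to the local experts (synchronous sketch: async_finish and
    # return_recv_hook are left at their defaults, so the event/hook results are ignored).
    # The returned handle now also carries the hidden size, which the buffer getter needs.
    packed_recv_x, packed_recv_count, handle, _, _ = \
        buffer.low_latency_dispatch(x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)

    # Hypothetical expert compute: must produce a BF16 tensor shaped
    # [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden].
    expert_out = run_experts(packed_recv_x, packed_recv_count)

    # Zero-copy combine: stage the expert outputs in the registered RDMA send buffer and
    # let the kernel skip its internal copy. `expert_out` is still passed as `x` because
    # the kernel reads it directly for tokens whose destination is the local rank.
    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = expert_out
    combined_x, _, _ = buffer.low_latency_combine(expert_out, topk_idx, topk_weights, handle,
                                                  zero_copy=True)
    return combined_x
```

In a real pipeline the expert GEMM would write its output directly into the tensor returned by `get_next_low_latency_combine_buffer` (for example by using it as the GEMM output buffer), which is what actually removes the extra copy; the explicit assignment above only mirrors the test.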