From bd429ffefc50cad10b5b17d63eed47c0ab8db72a Mon Sep 17 00:00:00 2001
From: Shangyan Zhou
Date: Wed, 25 Jun 2025 13:04:20 +0800
Subject: [PATCH] Support bias. (#257)

* Support bias.

* Fix.

* Fix style.
---
 csrc/deep_ep.cpp          | 32 ++++++++++++++++++++++++++++----
 csrc/deep_ep.hpp          |  2 ++
 csrc/kernels/api.cuh      |  2 ++
 csrc/kernels/internode.cu | 34 +++++++++++++++++++++++++++++-----
 csrc/kernels/intranode.cu | 19 ++++++++++++++++++-
 deep_ep/buffer.py         | 22 ++++++++++++++++++----
 tests/test_internode.py   |  6 ++++--
 7 files changed, 101 insertions(+), 16 deletions(-)

diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index d09e52e..e918adc 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::T
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
 Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                           const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
                           const torch::Tensor& send_head, const Config& config,
                           std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
     EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Te
+
+    // Check bias
+    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
+    void* bias_ptrs[2] = {nullptr, nullptr};
+    for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
+        auto bias = bias_opts[i].value();
+        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
+        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
+        EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
+        bias_ptrs[i] = bias.data_ptr();
+    }
 
     // Combine data
     auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Te
     intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
                        recv_x.data_ptr(), recv_topk_weights_ptr,
-                       x.data_ptr(), topk_weights_ptr,
+                       x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
                        src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
                        send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
                        buffer_ptrs_gpu, rank, num_ranks,
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Te
-        for (auto& to: {topk_weights, recv_topk_weights}) {
+        for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) {
             to.has_value() ? to->record_stream(comm_stream) : void();
             if (allocate_on_comm_stream)
                 to.has_value() ? to->record_stream(compute_stream) : void();
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::T
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
 Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                           const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
                           const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
                           const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Te
+
+    // Check bias
+    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
+    void* bias_ptrs[2] = {nullptr, nullptr};
+    for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
+        auto bias = bias_opts[i].value();
+        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
+        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
+        EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
+        bias_ptrs[i] = bias.data_ptr();
+    }
 
     // Launch data combine
     auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
     internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
                        combined_x.data_ptr(), combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr<bool>(),
-                       x.data_ptr(), topk_weights_ptr,
+                       x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
                        combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
                        src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
                        num_tokens, num_combined_tokens, hidden, num_topk,
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Te
-        for (auto& to: {topk_weights, combined_topk_weights}) {
+        for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) {
             to.has_value() ? to->record_stream(comm_stream) : void();
             if (allocate_on_comm_stream)
                 to.has_value() ? to->record_stream(compute_stream) : void();
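Taken together, the two host wrappers now accept up to two optional bias tensors, validate that each matches the combine output in shape and dtype, and pass raw device pointers down to the kernels, which add each bias exactly once per combined token. A torch-level sketch of the resulting semantics, illustrative only, where `partials` is a hypothetical stand-in for the per-rank receive buffers (one [num_tokens, hidden] tensor per rank, zero-filled where that rank holds no copy of a token):

    import torch

    def combine_reference(partials, bias_0=None, bias_1=None):
        # Accumulate in fp32, as the kernels do.
        out = torch.zeros_like(partials[0], dtype=torch.float32)
        # The biases seed the accumulator (a null pointer contributes nothing) ...
        if bias_0 is not None:
            out += bias_0.float()
        if bias_1 is not None:
            out += bias_1.float()
        # ... and the all-to-all partial results are reduced on top.
        for p in partials:
            out += p.float()
        return out.to(partials[0].dtype)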
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index d96d726..00f8d0c 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -112,6 +112,7 @@ public:
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                      const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                       const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
                       const torch::Tensor& send_head, const Config& config,
                       std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
@@ -127,6 +128,7 @@ public:
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                      const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                       const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
                       const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
                       const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 0895456..84703c9 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
 void combine(cudaDataType_t type,
              void* recv_x, float* recv_topk_weights,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
              int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
              void** buffer_ptrs, int rank, int num_ranks,
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
              void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* combined_rdma_head, const int* combined_nvl_head,
              const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
              int num_tokens, int num_combined_tokens, int hidden, int num_topk,
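Both kernel-facing `combine` overloads take the new biases as nullable `const void*` arguments, so a null pointer simply disables that bias term. Note the shape contract enforced by the host asserts above: the bias is indexed by output token, so it must match combine's output ([num_recv_tokens, hidden] intranode, [num_combined_tokens, hidden] internode) rather than the input `x`. A hypothetical torch mirror of those checks, not part of the diff:

    import torch

    def check_bias(bias: torch.Tensor, x: torch.Tensor, num_out_tokens: int, hidden: int) -> None:
        # Same conditions as the EP_HOST_ASSERT lines in deep_ep.cpp.
        assert bias.dim() == 2 and bias.is_contiguous()
        assert bias.dtype == x.dtype
        assert bias.size(0) == num_out_tokens and bias.size(1) == hidden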
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index da1d203..47f62ac 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
                   is_cached_dispatch, cpu_rdma_team);
 }
 
-template <typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
+template <typename dtype_t, int kMaxNumRanks, bool kMaybeWithBias, typename ReceiveFn, typename ReceiveTWFn>
 __device__ int combine_token(bool is_token_in_rank, int head_idx,
                              int lane_id, int hidden_int4, int num_topk,
                              int4* combined_row, float* combined_topk_weights,
+                             const int4* bias_0_int4, const int4* bias_1_int4,
                              int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
     constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
     // Reduce data
     #pragma unroll
     for (int i = lane_id; i < hidden_int4; i += 32) {
+        // Read bias
+        // TODO: make it as a finer-grained template
+        int4 bias_0_value_int4, bias_1_value_int4;
+        if (kMaybeWithBias) {
+            bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
+            bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
+        }
+
         // Read buffers
         // TODO: maybe too many registers here
         int4 recv_value_int4[kMaxNumRanks];
         #pragma unroll
         for (int j = 0; j < num_topk_ranks; ++ j)
             recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
 
+        // Clean
+        // Reduce bias
+        float values[kDtypePerInt4] = {0};
+        if (kMaybeWithBias) {
+            auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
+            auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
+            #pragma unroll
+            for (int j = 0; j < kDtypePerInt4; ++ j)
+                values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
+        }
+
         // Reduce all-to-all results
-        float values[kDtypePerInt4] = {0};
         #pragma unroll
         for (int j = 0; j < num_topk_ranks; ++ j) {
             auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
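The device-side core of the change: when `kMaybeWithBias` is set, the fp32 accumulator `values` is seeded from the two bias vectors (a null pointer degrades to `make_int4(0, 0, 0, 0)`) before the all-to-all partials are reduced, so the bias costs no extra pass over the hidden dimension. Per output element the update is simply the following, sketched in plain Python rather than kernel code:

    def reduce_element(partials, bias_0=0.0, bias_1=0.0):
        # The accumulator starts at bias_0 + bias_1 (0.0 stands in for a null
        # bias pointer), then each contributing rank's value is added in fp32.
        value = bias_0 + bias_1
        for v in partials:
            value += v
        return value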
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
 combine(int4* combined_x, float* combined_topk_weights,
         const bool* is_combined_token_in_rank,
         const int4* x, const float* topk_weights,
+        const int4* bias_0, const int4* bias_1,
         const int* combined_rdma_head, const int* combined_nvl_head,
         const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
         int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
                     void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                     auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                     auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
-                    combine_token<dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
+                    combine_token<dtype_t, NUM_MAX_NVL_PEERS, false>(expected_head >= 0,
                                   expected_head, lane_id,
                                   hidden_int4, num_topk,
                                   static_cast<int4*>(shifted),
                                   reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
-                                  num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
+                                  nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
 
                     // Update head
                     if (lane_id < NUM_MAX_NVL_PEERS)
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
                 // Combine current token
                 auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                 auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
-                combine_token<dtype_t, kNumRDMARanks>(expected_head >= 0,
+                combine_token<dtype_t, kNumRDMARanks, true>(expected_head >= 0,
                               expected_head, lane_id,
                               hidden_int4, num_topk,
                               combined_x + token_idx * hidden_int4,
                               combined_topk_weights + token_idx * num_topk,
+                              bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
+                              bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
                               num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
             }
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
 void combine(cudaDataType_t type,
              void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* combined_rdma_head, const int* combined_nvl_head,
              const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
              int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
     LAUNCH_KERNEL(&cfg, combine_func, \
                   reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
                   reinterpret_cast<const int4*>(x), topk_weights, \
+                  reinterpret_cast<const int4*>(bias_0), reinterpret_cast<const int4*>(bias_1), \
                   combined_rdma_head, combined_nvl_head, \
                   reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
                   num_tokens, num_combined_tokens, hidden, num_topk, \
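Note the asymmetry between the two `combine_token` call sites above: the NVL forwarder instantiates the template with `false` and passes null bias pointers, because it only relays partial sums toward the destination rank, while the final RDMA receiver, which writes `combined_x`, passes the per-token bias rows. As a result each bias row is applied exactly once per combined token, even when the token traverses an intermediate rank.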
diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu
index 0f3cb7e..ba8d005 100644
--- a/csrc/kernels/intranode.cu
+++ b/csrc/kernels/intranode.cu
@@ -587,6 +587,7 @@ template <typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPer
 combine(dtype_t* recv_x, float* recv_topk_weights,
         const dtype_t* x, const float* topk_weights,
+        const dtype_t* bias_0, const dtype_t* bias_1,
        const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
         int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
         void** buffer_ptrs, int rank,
@@ ... @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     auto x_int4 = reinterpret_cast<const int4*>(x);
+    auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
+    auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);
     auto recv_int4 = reinterpret_cast<int4*>(recv_x);
 
     // TMA stuffs
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
             EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
             #pragma unroll
             for (int i = lane_id; i < hidden_int4; i += 32) {
+                // Read bias
+                // TODO: make it as a template
+                int4 bias_0_value_int4 = bias_0_int4 != nullptr ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
+                int4 bias_1_value_int4 = bias_1_int4 != nullptr ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
+
                 // Read buffers
                 int4 recv_value_int4[kNumRanks];
                 #pragma unroll
                 for (int j = 0; j < num_topk_ranks; ++ j)
                     recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
 
+                // Reduce bias
+                float values[kDtypePerInt4];
+                auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
+                auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
+                #pragma unroll
+                for (int j = 0; j < kDtypePerInt4; ++ j)
+                    values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
+
                 // Reduce all-to-all results
-                float values[kDtypePerInt4] = {0};
                 #pragma unroll
                 for (int j = 0; j < num_topk_ranks; ++ j) {
                     auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 void combine(cudaDataType_t type,
              void* recv_x, float* recv_topk_weights,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
              int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
              void** buffer_ptrs, int rank, int num_ranks,
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
     LAUNCH_KERNEL(&cfg, kernel, \
                   reinterpret_cast<dtype_t*>(recv_x), recv_topk_weights, \
                   reinterpret_cast<const dtype_t*>(x), topk_weights, \
+                  reinterpret_cast<const dtype_t*>(bias_0), reinterpret_cast<const dtype_t*>(bias_1), \
                   src_idx, rank_prefix_matrix, channel_prefix_matrix, \
                   send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
                   buffer_ptrs, rank, \
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 3738aba..755f88a 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -176,6 +176,16 @@ class Buffer:
             assert tensor.numel() >= size.numel()
         return tensor[:size.numel()].view(size)
 
+    @staticmethod
+    def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
+        bias_0, bias_1 = None, None
+        if isinstance(bias, torch.Tensor):
+            bias_0 = bias
+        elif isinstance(bias, tuple):
+            assert len(bias) == 2
+            bias_0, bias_1 = bias
+        return bias_0, bias_1
+
     @staticmethod
     def get_dispatch_config(num_ranks: int) -> Config:
         """
@@ -346,6 +356,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
                 topk_weights: Optional[torch.Tensor] = None,
+                bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
                 config: Optional[Config] = None,
                 previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                 allocate_on_comm_stream: bool = False) -> \
@@ -376,14 +387,15 @@ class Buffer:
 
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
+            return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream)
 
         # NOTES: the second `_` is for the sending side, so we should use the third one
         rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
+        bias_0, bias_1 = Buffer._unpack_bias(bias)
 
         # Launch the kernel
         recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
-            x, topk_weights,
+            x, topk_weights, bias_0, bias_1,
             src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head,
             config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
         return recv_x, recv_topk_weights, EventOverlap(event)
@@ -442,6 +454,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                           topk_weights: Optional[torch.Tensor] = None,
+                          bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
                           config: Optional[Config] = None,
                           previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                           allocate_on_comm_stream: bool = False) -> \
@@ -452,15 +465,16 @@ class Buffer:
         """
         assert config is not None
 
-        # Unpack handle
+        # Unpack handle and bias
        is_combined_token_in_rank, \
             _, _, \
             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
             src_meta, send_rdma_head, send_nvl_head = handle
+        bias_0, bias_1 = Buffer._unpack_bias(bias)
 
         # Launch the kernel
         combined_x, combined_topk_weights, event = self.runtime.internode_combine(
-            x, topk_weights,
+            x, topk_weights, bias_0, bias_1,
             src_meta, is_combined_token_in_rank,
             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
             send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
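On the Python side, `bias` may be a single tensor (unpacked as `bias_0`, leaving `bias_1` unset) or a 2-tuple supplying both terms. A usage sketch follows; the names `shared_out` and `residual` are illustrative stand-ins, e.g. a shared-expert output and a residual branch, each shaped [num_combined_tokens, hidden] in `x`'s dtype:

    import torch
    from deep_ep import Buffer

    def combine_with_bias(buffer: Buffer, x: torch.Tensor, handle: tuple,
                          shared_out: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # `bias` may also be a single tensor, e.g. bias=shared_out, which
        # _unpack_bias turns into bias_0 with bias_1 left as None.
        combined_x, _, _ = buffer.combine(x, handle, bias=(shared_out, residual))
        return combined_x

Folding the biases into the combine kernel saves a separate elementwise add over the combined output.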
diff --git a/tests/test_internode.py b/tests/test_internode.py
index e84f4eb..bb5cd58 100644
--- a/tests/test_internode.py
+++ b/tests/test_internode.py
@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
                 check_data(recv_x, recv_gbl_rank_prefix_sum)
 
             # Test combine
-            combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
+            bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+            bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+            combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
             if with_topk:
                 combine_args.update({'topk_weights': recv_topk_weights})
             if previous_mode:
                 combine_args.update({'previous_event': buffer.capture()})
             combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
             event.current_stream_wait() if async_mode else ()
-            check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
+            check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
             ref_x = x_pure_rand if current_x is x_pure_rand else x
             assert calc_diff(check_x, ref_x) < 5e-6
             if with_topk:
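The updated check encodes the expected algebra: combine sums each token over the `n` ranks it was dispatched to and adds both biases once, so `combined_x = bias_0 + bias_1 + n * x` with `n = is_token_in_rank.sum(dim=1)`, and the test recovers `x` by subtracting both biases before dividing by `n`. Using an all-ones `bias_0` and a Gaussian `bias_1` means any missed or double-counted bias addition would push `calc_diff` well past the 5e-6 tolerance.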