Support bias. (#257)

* Support bias.

* Fix.

* Fix style.
Shangyan Zhou 2025-06-25 13:04:20 +08:00 committed by GitHub
parent b80e55e21f
commit bd429ffefc
7 changed files with 101 additions and 16 deletions
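This change threads an optional bias through the combine path: `Buffer.combine` (and `internode_combine`) gains a `bias` argument that is either a single tensor or a pair of tensors, the host code validates the tensors and forwards their pointers, and the CUDA kernels add the bias into the combined output. A minimal usage sketch modeled on the updated test below; tensor names and shapes are illustrative:

import torch

# Each bias must match the combined output: 2-D, contiguous,
# the same dtype as `x`, and with the combined token count as its first dimension.
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')

# `bias` may be a single tensor or a 2-tuple; omitting it keeps the old behavior.
combined_x, combined_topk_weights, event = buffer.combine(
    recv_x, handle, topk_weights=recv_topk_weights,
    bias=(bias_0, bias_1), config=config, async_finish=async_mode)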

View File

@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
// Assign bias pointers
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
<= num_nvl_bytes);
intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
recv_x.data_ptr(), recv_topk_weights_ptr,
x.data_ptr(), topk_weights_ptr,
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
buffer_ptrs_gpu, rank, num_ranks,
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, recv_topk_weights}) {
for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
// Assign bias pointers
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
combined_x.data_ptr(), combined_topk_weights_ptr,
is_combined_token_in_rank.data_ptr<bool>(),
x.data_ptr(), topk_weights_ptr,
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_combined_tokens, hidden, num_topk,
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, combined_topk_weights}) {
for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();

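Host-side, both `intranode_combine` and `internode_combine` now take `bias_0`/`bias_1` as optional tensors, assert that each present bias is 2-D, contiguous, of `x`'s dtype, and shaped `[num_recv_tokens, hidden]` (internode: `[num_combined_tokens, hidden]`), and otherwise pass a null pointer to the kernel; the biases are also recorded on the communication stream alongside `topk_weights`. A rough Python restatement of that contract (helper name hypothetical):

def check_bias(bias, num_out_tokens, hidden, x):
    # Mirrors the EP_HOST_ASSERT checks; a missing bias maps to a null pointer.
    if bias is None:
        return
    assert bias.dim() == 2 and bias.is_contiguous()
    assert bias.dtype == x.dtype
    assert bias.shape == (num_out_tokens, hidden)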
View File

@@ -112,6 +112,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
@@ -127,6 +128,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,

View File

@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,

View File

@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
is_cached_dispatch, cpu_rdma_team);
}
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
const int4* bias_0_int4, const int4* bias_1_int4,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
// Reduce data
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += 32) {
// Read bias
// TODO: make this a finer-grained template
int4 bias_0_value_int4, bias_1_value_int4;
if (kMaybeWithBias) {
bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
}
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Clean
// Reduce bias
float values[kDtypePerInt4] = {0};
if (kMaybeWithBias) {
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
}
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
combine(int4* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x, const float* topk_weights,
const int4* bias_0, const int4* bias_1,
const int* combined_rdma_head, const int* combined_nvl_head,
const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
static_cast<int4*>(shifted),
reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head
if (lane_id < NUM_MAX_NVL_PEERS)
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
}
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4*>(bias_0), reinterpret_cast<const int4*>(bias_1), \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \

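In `combine_token`, the bias is folded in before the all-to-all partials: with `kMaybeWithBias` set, the fp32 accumulator starts from `bias_0 + bias_1` (a null pointer contributes zeros) and the per-rank received values are added on top. The NVL forwarding stage instantiates the template with `kMaybeWithBias = false` and null bias pointers, so the bias is applied exactly once, when the received RDMA data is combined into `combined_x`. A hedged torch reference of the per-token math (not the kernel itself):

import torch

def combine_reference(partials, bias_0=None, bias_1=None):
    # `partials`: the [hidden]-sized contributions received from each source rank.
    acc = torch.zeros_like(partials[0], dtype=torch.float32)
    if bias_0 is not None:
        acc += bias_0.float()
    if bias_1 is not None:
        acc += bias_1.float()
    for p in partials:
        acc += p.float()
    return acc.to(partials[0].dtype)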
View File

@@ -587,6 +587,7 @@ template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWa
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
const dtype_t* bias_0, const dtype_t* bias_1,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank,
@@ -602,6 +603,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
auto x_int4 = reinterpret_cast<const int4*>(x);
auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
// TMA stuffs
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += 32) {
// Read bias
// TODO: make this a template
int4 bias_0_value_int4 = bias_0_int4 != nullptr ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
int4 bias_1_value_int4 = bias_1_int4 != nullptr ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
// Read buffers
int4 recv_value_int4[kNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
// Reduce bias
float values[kDtypePerInt4];
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
reinterpret_cast<const dtype*>(bias_0), reinterpret_cast<const dtype*>(bias_1), \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
buffer_ptrs, rank, \

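The intranode kernel follows the same scheme but indexes the bias directly by destination token (`token_idx * hidden_int4 + i`) via `__ldg`, substituting zeros when the pointer is null. Loads are int4-vectorized: with bf16, `sizeof(int4) / sizeof(dtype_t)` gives 8 elements per load, so `hidden_int4 = hidden / 8`. A quick check of that arithmetic (the hidden size is only an example):

hidden = 7168                    # example hidden size; must stay int4-aligned
dtype_bytes = 2                  # bf16
per_int4 = 16 // dtype_bytes     # kDtypePerInt4: 8 elements per int4 load
hidden_int4 = hidden * dtype_bytes // 16
assert hidden_int4 * per_int4 == hidden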
View File

@@ -176,6 +176,16 @@ class Buffer:
assert tensor.numel() >= size.numel()
return tensor[:size.numel()].view(size)
@staticmethod
def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
bias_0, bias_1 = None, None
if isinstance(bias, torch.Tensor):
bias_0 = bias
elif isinstance(bias, tuple):
assert len(bias) == 2
bias_0, bias_1 = bias
return bias_0, bias_1
@staticmethod
def get_dispatch_config(num_ranks: int) -> Config:
"""
@@ -346,6 +356,7 @@ class Buffer:
# noinspection PyTypeChecker
def combine(self, x: torch.Tensor, handle: Tuple,
topk_weights: Optional[torch.Tensor] = None,
bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
@@ -376,14 +387,15 @@ class Buffer:
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream)
# NOTES: the second `_` is for the sending side, so we should use the third one
rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
bias_0, bias_1 = Buffer._unpack_bias(bias)
# Launch the kernel
recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
x, topk_weights,
x, topk_weights, bias_0, bias_1,
src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return recv_x, recv_topk_weights, EventOverlap(event)
@@ -442,6 +454,7 @@ class Buffer:
# noinspection PyTypeChecker
def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
topk_weights: Optional[torch.Tensor] = None,
bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
@@ -452,15 +465,16 @@
"""
assert config is not None
# Unpack handle
# Unpack handle and bias
is_combined_token_in_rank, \
_, _, \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
src_meta, send_rdma_head, send_nvl_head = handle
bias_0, bias_1 = Buffer._unpack_bias(bias)
# Launch the kernel
combined_x, combined_topk_weights, event = self.runtime.internode_combine(
x, topk_weights,
x, topk_weights, bias_0, bias_1,
src_meta, is_combined_token_in_rank,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),

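`_unpack_bias` keeps the Python signature flexible: a single tensor fills only `bias_0` (leaving `bias_1` as `None`), a 2-tuple supplies both, and `None` disables the feature. Illustrative call sites:

# Single tensor: only the first bias slot is used.
combined_x, _, _ = buffer.combine(recv_x, handle, bias=bias_0, config=config)
# Pair of tensors: both are added to the combined output.
combined_x, _, _ = buffer.combine(recv_x, handle, bias=(bias_0, bias_1), config=config)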
View File

@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
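The updated check reflects that each token is combined from `is_token_in_rank.sum(dim=1)` identical copies of the reference input plus the two biases, so subtracting the biases and dividing by that count should recover the input up to bf16 rounding. Equivalently, as a sketch reusing the test's tensors:

count = is_token_in_rank.sum(dim=1).unsqueeze(1).float()
expected_combined = count * ref_x.float() + bias_0.float() + bias_1.float()
assert calc_diff(combined_x.float(), expected_combined) < 5e-6  # same tolerance as the test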