Support bias. (#257)

* Support bias.

* Fix.

* Fix style.
This commit is contained in:
Shangyan Zhou
2025-06-25 13:04:20 +08:00
parent 85adc566e2
commit 4931324861
7 changed files with 101 additions and 16 deletions

View File

@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
// Assign bias pointers
// `bias_0` and `bias_1` are both optional. Each one that is present is validated
// and its raw data pointer is stored in the matching `bias_ptrs` slot; a slot
// whose optional is empty keeps its initial `nullptr`.
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
// The bias must be a contiguous 2-D tensor
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
// The bias must have the same element type as `x`
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
// The bias must be shaped `[num_recv_tokens, hidden]`
EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
<= num_nvl_bytes);
intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
recv_x.data_ptr(), recv_topk_weights_ptr,
x.data_ptr(), topk_weights_ptr,
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
buffer_ptrs_gpu, rank, num_ranks,
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, recv_topk_weights}) {
for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
// Assign bias pointers
// `bias_0` and `bias_1` are both optional. Each one that is present is validated
// and its raw data pointer is stored in the matching `bias_ptrs` slot; a slot
// whose optional is empty keeps its initial `nullptr`.
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
// The bias must be a contiguous 2-D tensor
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
// The bias must have the same element type as `x`
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
// The bias must be shaped `[num_combined_tokens, hidden]`
EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
combined_x.data_ptr(), combined_topk_weights_ptr,
is_combined_token_in_rank.data_ptr<bool>(),
x.data_ptr(), topk_weights_ptr,
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_combined_tokens, hidden, num_topk,
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, combined_topk_weights}) {
for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();