mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
|
||||
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
|
||||
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
|
||||
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
|
||||
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
|
||||
barrier_signal_ptrs_gpu, rank, num_ranks,
|
||||
comm_stream);
|
||||
|
||||
// Assign bias pointers
|
||||
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
|
||||
void* bias_ptrs[2] = {nullptr, nullptr};
|
||||
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
|
||||
auto bias = bias_opts[i].value();
|
||||
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
|
||||
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
|
||||
EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
|
||||
bias_ptrs[i] = bias.data_ptr();
|
||||
}
|
||||
|
||||
// Combine data
|
||||
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
|
||||
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
<= num_nvl_bytes);
|
||||
intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
|
||||
recv_x.data_ptr(), recv_topk_weights_ptr,
|
||||
x.data_ptr(), topk_weights_ptr,
|
||||
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
|
||||
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
|
||||
send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
|
||||
buffer_ptrs_gpu, rank, num_ranks,
|
||||
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
if (allocate_on_comm_stream)
|
||||
t.record_stream(compute_stream);
|
||||
}
|
||||
for (auto& to: {topk_weights, recv_topk_weights}) {
|
||||
for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) {
|
||||
to.has_value() ? to->record_stream(comm_stream) : void();
|
||||
if (allocate_on_comm_stream)
|
||||
to.has_value() ? to->record_stream(compute_stream) : void();
|
||||
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
|
||||
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
|
||||
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
|
||||
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
|
||||
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
|
||||
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
barrier_signal_ptrs_gpu, rank, comm_stream,
|
||||
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
|
||||
num_nvl_bytes, false, low_latency_mode);
|
||||
|
||||
// Assign bias pointers
|
||||
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
|
||||
void* bias_ptrs[2] = {nullptr, nullptr};
|
||||
for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
|
||||
auto bias = bias_opts[i].value();
|
||||
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
|
||||
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
|
||||
EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
|
||||
bias_ptrs[i] = bias.data_ptr();
|
||||
}
|
||||
|
||||
// Launch data combine
|
||||
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
|
||||
combined_x.data_ptr(), combined_topk_weights_ptr,
|
||||
is_combined_token_in_rank.data_ptr<bool>(),
|
||||
x.data_ptr(), topk_weights_ptr,
|
||||
x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
|
||||
combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
|
||||
src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
|
||||
num_tokens, num_combined_tokens, hidden, num_topk,
|
||||
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
if (allocate_on_comm_stream)
|
||||
t.record_stream(compute_stream);
|
||||
}
|
||||
for (auto& to: {topk_weights, combined_topk_weights}) {
|
||||
for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) {
|
||||
to.has_value() ? to->record_stream(comm_stream) : void();
|
||||
if (allocate_on_comm_stream)
|
||||
to.has_value() ? to->record_stream(compute_stream) : void();
|
||||
|
||||
Reference in New Issue
Block a user