Mirror of https://github.com/deepseek-ai/DeepEP
Commit bd429ffefc (parent b80e55e21f)
@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te

 std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
 Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                           const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
                           const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
     EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
                                num_channels, num_recv_tokens, num_channels * num_ranks * 2,
                                barrier_signal_ptrs_gpu, rank, num_ranks,
                                comm_stream);

+    // Assign bias pointers
+    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
+    void* bias_ptrs[2] = {nullptr, nullptr};
+    for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
+        auto bias = bias_opts[i].value();
+        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
+        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
+        EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
+        bias_ptrs[i] = bias.data_ptr();
+    }
+
     // Combine data
     auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
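
Note: the asserts in this hunk pin down the bias contract on the host side before any raw pointer reaches the kernel. A minimal Python sketch of the same checks (illustrative helper, not part of the commit):

    import torch

    def check_bias(bias, num_recv_tokens: int, hidden: int, dtype: torch.dtype):
        # `None` simply means "no bias"; otherwise the tensor must be a
        # contiguous [num_recv_tokens, hidden] matrix with x's dtype.
        if bias is None:
            return
        assert bias.dim() == 2 and bias.is_contiguous()
        assert bias.dtype == dtype
        assert bias.size(0) == num_recv_tokens and bias.size(1) == hidden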
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
                    <= num_nvl_bytes);
     intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
                        recv_x.data_ptr(), recv_topk_weights_ptr,
-                       x.data_ptr(), topk_weights_ptr,
+                       x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
                        src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
                        send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
                        buffer_ptrs_gpu, rank, num_ranks,
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
         if (allocate_on_comm_stream)
             t.record_stream(compute_stream);
     }
-    for (auto& to: {topk_weights, recv_topk_weights}) {
+    for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) {
         to.has_value() ? to->record_stream(comm_stream) : void();
         if (allocate_on_comm_stream)
             to.has_value() ? to->record_stream(compute_stream) : void();
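
Note: bias_0 and bias_1 join the record_stream loop because the combine kernel reads them on the communication stream; without the call, PyTorch's caching allocator could recycle their storage as soon as the Python references die. The underlying idiom, sketched (tensor names are illustrative):

    import torch

    comm_stream = torch.cuda.Stream()
    bias = torch.randn(16, 128, device='cuda')  # allocated on the current stream
    with torch.cuda.stream(comm_stream):
        out = bias * 2                          # kernel enqueued on comm_stream
    # Tell the allocator that `bias` is still in use on comm_stream, so its
    # memory is not handed out again before that kernel has run.
    bias.record_stream(comm_stream)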
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te

 std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
 Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                          const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                           const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
                           const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
                           const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
                            barrier_signal_ptrs_gpu, rank, comm_stream,
                            config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
                            num_nvl_bytes, false, low_latency_mode);

+    // Assign bias pointers
+    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
+    void* bias_ptrs[2] = {nullptr, nullptr};
+    for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) {
+        auto bias = bias_opts[i].value();
+        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
+        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
+        EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
+        bias_ptrs[i] = bias.data_ptr();
+    }
+
     // Launch data combine
     auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
     internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
                        combined_x.data_ptr(), combined_topk_weights_ptr,
                        is_combined_token_in_rank.data_ptr<bool>(),
-                       x.data_ptr(), topk_weights_ptr,
+                       x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
                        combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
                        src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
                        num_tokens, num_combined_tokens, hidden, num_topk,
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
         if (allocate_on_comm_stream)
             t.record_stream(compute_stream);
     }
-    for (auto& to: {topk_weights, combined_topk_weights}) {
+    for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) {
         to.has_value() ? to->record_stream(comm_stream) : void();
         if (allocate_on_comm_stream)
             to.has_value() ? to->record_stream(compute_stream) : void();
@@ -112,6 +112,7 @@ public:

     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                      const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                       const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
                       const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

@@ -127,6 +128,7 @@ public:

     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
+                      const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
                       const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
                       const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
                       const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
 void combine(cudaDataType_t type,
              void* recv_x, float* recv_topk_weights,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
              int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
              void** buffer_ptrs, int rank, int num_ranks,
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
              void* combined_x, float* combined_topk_weights,
              const bool* is_combined_token_in_rank,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* combined_rdma_head, const int* combined_nvl_head,
              const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
              int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
                    is_cached_dispatch, cpu_rdma_team);
 }

-template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
+template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
 __device__ int combine_token(bool is_token_in_rank, int head_idx,
                              int lane_id, int hidden_int4, int num_topk,
                              int4* combined_row, float* combined_topk_weights,
+                             const int4* bias_0_int4, const int4* bias_1_int4,
                              int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
     constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
     // Reduce data
     #pragma unroll
     for (int i = lane_id; i < hidden_int4; i += 32) {
+        // Read bias
+        // TODO: make it as a finer-grained template
+        int4 bias_0_value_int4, bias_1_value_int4;
+        if (kMaybeWithBias) {
+            bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
+            bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
+        }
+
         // Read buffers
         // TODO: maybe too many registers here
         int4 recv_value_int4[kMaxNumRanks];
         #pragma unroll
         for (int j = 0; j < num_topk_ranks; ++ j)
             recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

+        // Clean
+        // Reduce bias
+        float values[kDtypePerInt4] = {0};
+        if (kMaybeWithBias) {
+            auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
+            auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
+            #pragma unroll
+            for (int j = 0; j < kDtypePerInt4; ++ j)
+                values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
+        }
+
         // Reduce all-to-all results
-        float values[kDtypePerInt4] = {0};
         #pragma unroll
         for (int j = 0; j < num_topk_ranks; ++ j) {
             auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
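
Note: with kMaybeWithBias set, the fp32 accumulator is seeded with bias_0 + bias_1 before the per-rank partials are folded in, so the bias costs no extra pass over the hidden dimension. The per-chunk semantics, as a torch sketch (hypothetical helper, not part of the commit):

    import torch

    def combine_chunk(recv_values: torch.Tensor, bias_0=None, bias_1=None):
        # recv_values: [num_topk_ranks, chunk] bf16 partials from peer ranks
        acc = torch.zeros(recv_values.size(1), dtype=torch.float32)
        if bias_0 is not None:
            acc += bias_0.float()              # seed the accumulator with bias
        if bias_1 is not None:
            acc += bias_1.float()
        acc += recv_values.float().sum(dim=0)  # reduce all-to-all results
        return acc.to(recv_values.dtype)       # cast back to the wire dtype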
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
 combine(int4* combined_x, float* combined_topk_weights,
         const bool* is_combined_token_in_rank,
         const int4* x, const float* topk_weights,
+        const int4* bias_0, const int4* bias_1,
         const int* combined_rdma_head, const int* combined_nvl_head,
         const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
         int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
                 void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                 auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                 auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
-                combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
+                combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                              expected_head, lane_id,
                                                                              hidden_int4, num_topk,
                                                                              static_cast<int4*>(shifted),
                                                                              reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
-                                                                             num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
+                                                                             nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);

                 // Update head
                 if (lane_id < NUM_MAX_NVL_PEERS)
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
             // Combine current token
             auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
             auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
-            combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
+            combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                      expected_head, lane_id,
                                                                      hidden_int4, num_topk,
                                                                      combined_x + token_idx * hidden_int4,
                                                                      combined_topk_weights + token_idx * num_topk,
+                                                                     bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
+                                                                     bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
                                                                      num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
         }

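
Note the asymmetry between the two call sites: the NVL-forwarding stage above instantiates combine_token with kMaybeWithBias = false and null bias pointers, since it only merges local partial sums into the RDMA send buffer, while this final RDMA reduction passes true plus the per-token bias rows. Each token therefore receives bias_0/bias_1 exactly once, at the last stage. Schematically (an illustrative sketch, not the kernel itself):

    def two_stage_combine(nvl_partials, bias_0_row, bias_1_row):
        # Stage 1, per RDMA peer: reduce NVL-local partials only, no bias.
        rdma_partials = [p.float().sum(dim=0) for p in nvl_partials]
        # Stage 2, final reduction across RDMA peers: bias added exactly once.
        acc = bias_0_row.float() + bias_1_row.float()
        for p in rdma_partials:
            acc += p
        return acc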
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
              void* combined_x, float* combined_topk_weights,
              const bool* is_combined_token_in_rank,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* combined_rdma_head, const int* combined_nvl_head,
              const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
              int num_tokens, int num_combined_tokens, int hidden, int num_topk,
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
     LAUNCH_KERNEL(&cfg, combine_func, \
                   reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
                   reinterpret_cast<const int4*>(x), topk_weights, \
+                  reinterpret_cast<const int4*>(bias_0), reinterpret_cast<const int4*>(bias_1), \
                   combined_rdma_head, combined_nvl_head, \
                   reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
                   num_tokens, num_combined_tokens, hidden, num_topk, \
@@ -587,6 +587,7 @@ template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWa
 __global__ void __launch_bounds__(kNumThreads, 1)
 combine(dtype_t* recv_x, float* recv_topk_weights,
         const dtype_t* x, const float* topk_weights,
+        const dtype_t* bias_0, const dtype_t* bias_1,
         const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
         int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
         void** buffer_ptrs, int rank,
@@ -602,6 +603,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
     int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
     auto x_int4 = reinterpret_cast<const int4*>(x);
+    auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
+    auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);
     auto recv_int4 = reinterpret_cast<int4*>(recv_x);

     // TMA stuffs
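
Note: the bias pointers are re-viewed as int4 so the kernel can walk them with the same 16-byte vectorized stride it already uses for x. For bf16 that packs kDtypePerInt4 = 8 elements per load. A quick arithmetic check (hidden = 7168 is only an example value, not taken from the diff):

    hidden, dtype_bytes, int4_bytes = 7168, 2, 16     # bf16 elements, 16 B vectors
    k_dtype_per_int4 = int4_bytes // dtype_bytes      # 8 bf16 values per int4
    hidden_int4 = hidden * dtype_bytes // int4_bytes  # int4 chunks per token row
    assert hidden_int4 == hidden // k_dtype_per_int4 == 896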
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                 EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
                 #pragma unroll
                 for (int i = lane_id; i < hidden_int4; i += 32) {
+                    // Read bias
+                    // TODO: make it as a template
+                    int4 bias_0_value_int4 = bias_0_int4 != nullptr ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
+                    int4 bias_1_value_int4 = bias_1_int4 != nullptr ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
+
                     // Read buffers
                     int4 recv_value_int4[kNumRanks];
                     #pragma unroll
                     for (int j = 0; j < num_topk_ranks; ++ j)
                         recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);

+                    // Reduce bias
+                    float values[kDtypePerInt4];
+                    auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
+                    auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
+                    #pragma unroll
+                    for (int j = 0; j < kDtypePerInt4; ++ j)
+                        values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
+
                     // Reduce all-to-all results
-                    float values[kDtypePerInt4] = {0};
                     #pragma unroll
                     for (int j = 0; j < num_topk_ranks; ++ j) {
                         auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 void combine(cudaDataType_t type,
              void* recv_x, float* recv_topk_weights,
              const void* x, const float* topk_weights,
+             const void* bias_0, const void* bias_1,
              const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
              int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
              void** buffer_ptrs, int rank, int num_ranks,
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
     LAUNCH_KERNEL(&cfg, kernel, \
                   reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
                   reinterpret_cast<const dtype*>(x), topk_weights, \
+                  reinterpret_cast<const dtype*>(bias_0), reinterpret_cast<const dtype*>(bias_1), \
                   src_idx, rank_prefix_matrix, channel_prefix_matrix, \
                   send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
                   buffer_ptrs, rank, \
@@ -176,6 +176,16 @@ class Buffer:
         assert tensor.numel() >= size.numel()
         return tensor[:size.numel()].view(size)

+    @staticmethod
+    def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
+        bias_0, bias_1 = None, None
+        if isinstance(bias, torch.Tensor):
+            bias_0 = bias
+        elif isinstance(bias, tuple):
+            assert len(bias) == 2
+            bias_0, bias_1 = bias
+        return bias_0, bias_1
+
     @staticmethod
     def get_dispatch_config(num_ranks: int) -> Config:
         """
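
Note: _unpack_bias normalizes the user-facing bias argument into the (bias_0, bias_1) pair the C++ bindings take; a bare tensor becomes (bias, None) and anything else, including None, falls through to (None, None). For example (shapes are illustrative):

    import torch

    b = torch.randn(8, 16, dtype=torch.bfloat16)
    assert Buffer._unpack_bias(None) == (None, None)
    b0, b1 = Buffer._unpack_bias(b)
    assert b0 is b and b1 is None
    b0, b1 = Buffer._unpack_bias((b, b))
    assert b0 is b and b1 is b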
@@ -346,6 +356,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
                 topk_weights: Optional[torch.Tensor] = None,
+                bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
                 config: Optional[Config] = None,
                 previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                 allocate_on_comm_stream: bool = False) -> \
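
Note: a natural use of the new argument is fusing an extra additive term, such as a shared-expert output, into the combine reduction instead of a separate elementwise add afterwards. A hedged usage sketch (tensor names are illustrative; the bias must match the combined output's [num_tokens, hidden] layout and dtype):

    # Single bias tensor, or a pair as bias=(shared_out, residual)
    combined_x, combined_topk_weights, event = buffer.combine(
        x=recv_x, handle=handle,
        topk_weights=recv_topk_weights,
        bias=shared_out,
        config=config, async_finish=True)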
@@ -376,14 +387,15 @@ class Buffer:

         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
+            return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream)

         # NOTES: the second `_` is for the sending side, so we should use the third one
         rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
+        bias_0, bias_1 = Buffer._unpack_bias(bias)

         # Launch the kernel
         recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
-            x, topk_weights,
+            x, topk_weights, bias_0, bias_1,
             src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
             getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
         return recv_x, recv_topk_weights, EventOverlap(event)
@@ -442,6 +454,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                           topk_weights: Optional[torch.Tensor] = None,
+                          bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
                           config: Optional[Config] = None,
                           previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                           allocate_on_comm_stream: bool = False) -> \
@@ -452,15 +465,16 @@ class Buffer:
         """
         assert config is not None

-        # Unpack handle
+        # Unpack handle and bias
         is_combined_token_in_rank, \
             _, _, \
             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
             src_meta, send_rdma_head, send_nvl_head = handle
+        bias_0, bias_1 = Buffer._unpack_bias(bias)

         # Launch the kernel
         combined_x, combined_topk_weights, event = self.runtime.internode_combine(
-            x, topk_weights,
+            x, topk_weights, bias_0, bias_1,
             src_meta, is_combined_token_in_rank,
             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
             send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
     check_data(recv_x, recv_gbl_rank_prefix_sum)

     # Test combine
-    combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
+    bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+    bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+    combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
     if with_topk:
         combine_args.update({'topk_weights': recv_topk_weights})
     if previous_mode:
         combine_args.update({'previous_event': buffer.capture()})
     combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
     event.current_stream_wait() if async_mode else ()
-    check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
+    check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
     ref_x = x_pure_rand if current_x is x_pure_rand else x
     assert calc_diff(check_x, ref_x) < 5e-6
     if with_topk:
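
Note on the check's math: dispatch replicated each token to is_token_in_rank.sum(dim=1) ranks, and combine now returns count * x + bias_0 + bias_1, so the test subtracts both biases before dividing by the replica count to recover x. In sketch form (names taken from the test):

    count = is_token_in_rank.sum(dim=1).unsqueeze(1)   # replicas per token
    # combined_x ≈ count * x + bias_0 + bias_1, hence:
    check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / count
    # check_x should match x up to bf16 rounding (calc_diff < 5e-6)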