mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support CUDA graph for intranode normal kernels (#203)
This commit is contained in:
@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
|
||||
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
|
||||
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
|
||||
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
int expert_alignment, int num_worst_tokens, const Config& config,
|
||||
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
bool cached_mode = cached_rank_prefix_matrix.has_value();
|
||||
|
||||
// One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
|
||||
@@ -412,25 +413,34 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
|
||||
comm_stream, num_channels);
|
||||
|
||||
// Synchronize total received tokens and tokens per expert
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
while (true) {
|
||||
// Read total count
|
||||
num_recv_tokens = static_cast<int>(*moe_recv_counter);
|
||||
if (num_worst_tokens > 0) {
|
||||
// No CPU sync, just allocate the worst case
|
||||
num_recv_tokens = num_worst_tokens;
|
||||
|
||||
// Read per-expert count
|
||||
bool ready = (num_recv_tokens >= 0);
|
||||
for (int i = 0; i < num_local_experts and ready; ++i)
|
||||
ready &= moe_recv_expert_counter[i] >= 0;
|
||||
// Must be forward with top-k stuffs
|
||||
EP_HOST_ASSERT(topk_idx.has_value());
|
||||
EP_HOST_ASSERT(topk_weights.has_value());
|
||||
} else {
|
||||
// Synchronize total received tokens and tokens per expert
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
while (true) {
|
||||
// Read total count
|
||||
num_recv_tokens = static_cast<int>(*moe_recv_counter);
|
||||
|
||||
if (ready)
|
||||
break;
|
||||
// Read per-expert count
|
||||
bool ready = (num_recv_tokens >= 0);
|
||||
for (int i = 0; i < num_local_experts and ready; ++i)
|
||||
ready &= moe_recv_expert_counter[i] >= 0;
|
||||
|
||||
// Timeout check
|
||||
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
|
||||
throw std::runtime_error("DeepEP error: CPU recv timeout");
|
||||
if (ready)
|
||||
break;
|
||||
|
||||
// Timeout check
|
||||
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
|
||||
throw std::runtime_error("DeepEP error: CPU recv timeout");
|
||||
}
|
||||
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
|
||||
}
|
||||
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
|
||||
}
|
||||
|
||||
// Allocate new tensors
|
||||
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
send_head.data_ptr<int>(),
|
||||
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
|
||||
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
|
||||
num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
|
||||
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
|
||||
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
|
||||
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user