Support CUDA graph for intranode normal kernels (#203)

This commit is contained in:
Chenggang Zhao
2025-06-11 11:08:54 +08:00
committed by GitHub
parent 8da2d7b38d
commit a8299ca7c2
7 changed files with 86 additions and 38 deletions

View File

@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
int expert_alignment, int num_worst_tokens, const Config& config,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
bool cached_mode = cached_rank_prefix_matrix.has_value();
// One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
@@ -412,25 +413,34 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
comm_stream, num_channels);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
if (num_worst_tokens > 0) {
// No CPU sync, just allocate the worst case
num_recv_tokens = num_worst_tokens;
// Read per-expert count
bool ready = (num_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;
// Must be forward with top-k stuffs
EP_HOST_ASSERT(topk_idx.has_value());
EP_HOST_ASSERT(topk_weights.has_value());
} else {
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
if (ready)
break;
// Read per-expert count
bool ready = (num_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
throw std::runtime_error("DeepEP error: CPU recv timeout");
if (ready)
break;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
throw std::runtime_error("DeepEP error: CPU recv timeout");
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
// Allocate new tensors
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);