Support CUDA graph for intranode normal kernels (#203)

Chenggang Zhao 2025-06-11 11:08:54 +08:00 committed by GitHub
parent 8da2d7b38d
commit a8299ca7c2
7 changed files with 86 additions and 38 deletions

View File

@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
allocate_on_comm_stream=previous_event is not None)
# Do MoE dispatch
# NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graph
# unless you specify `num_worst_tokens`; note that this flag is for intranode only
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
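The two NOTES lines above are the motivation for this commit: the default dispatch makes the CPU poll a GPU-written counter, so the call cannot sit inside a CUDA graph. Below is a minimal, hedged sketch of how `num_worst_tokens` could be combined with `torch.cuda.graph`; it reuses the `_buffer`/`x`/`topk_idx`/`topk_weights` names from the snippet, assumes the usual layout tensors have been prepared beforehand (not shown), and assumes static shapes across replays. It is an illustration, not the project's documented recipe.

```python
import torch

# Assumption: `_buffer` is a deep_ep.Buffer; `num_ranks` is the expert-parallel group size;
# `x`, `topk_idx`, `topk_weights`, `num_tokens_per_rank`, `is_token_in_rank` and
# `num_tokens_per_expert` are preallocated and keep the same shapes for every replay.
num_worst_tokens = x.size(0) * num_ranks  # worst case: every rank sends all of its tokens here

dispatch_kwargs = dict(topk_idx=topk_idx, topk_weights=topk_weights,
                       num_tokens_per_rank=num_tokens_per_rank,
                       is_token_in_rank=is_token_in_rank,
                       num_tokens_per_expert=num_tokens_per_expert,
                       num_worst_tokens=num_worst_tokens)

# Warm-up on a side stream before capture (standard CUDA graph practice)
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    _buffer.dispatch(x, **dispatch_kwargs)
torch.cuda.current_stream().wait_stream(side_stream)

# With `num_worst_tokens` set, dispatch performs no CPU sync, so the call can be captured
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    recv_x, recv_topk_idx, recv_topk_weights, _, handle, _ = _buffer.dispatch(x, **dispatch_kwargs)

graph.replay()  # replays reuse the worst-case-sized `recv_*` tensors
```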

View File

@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
int expert_alignment, int num_worst_tokens, const Config& config,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
bool cached_mode = cached_rank_prefix_matrix.has_value();
// One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
@ -412,6 +413,14 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
comm_stream, num_channels);
if (num_worst_tokens > 0) {
// No CPU sync, just allocate for the worst case
num_recv_tokens = num_worst_tokens;
// Must be a forward (non-cached) dispatch with top-k tensors
EP_HOST_ASSERT(topk_idx.has_value());
EP_HOST_ASSERT(topk_weights.has_value());
} else {
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
@ -432,6 +441,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
}
// Allocate new tensors
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
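Restating the host-side branch above in plain Python (hypothetical helper names, no real synchronization primitives): with `num_worst_tokens > 0` the receive count is fixed up front, otherwise the CPU busy-waits for the notify kernel to publish the real count into host-mapped memory, and that wait is exactly what makes the default path incompatible with CUDA graph capture.

```python
def decide_num_recv_tokens(num_worst_tokens: int, read_mapped_counter) -> int:
    # `read_mapped_counter` is a hypothetical stand-in for polling the
    # host-mapped `moe_recv_counter` that `notify_dispatch` writes.
    if num_worst_tokens > 0:
        # CUDA-graph-friendly path: no CPU/GPU handshake; outputs are sized for
        # the worst case and the kernel later marks the unused tail with -1.
        return num_worst_tokens
    # Default path: busy-wait until the GPU publishes the real count (host sync).
    count = read_mapped_counter()
    while count < 0:          # illustrative "not yet written" sentinel
        count = read_mapped_counter()
    return count

# Worst-case sizing used by the tests below: every peer may forward all of its
# tokens to this rank, so `num_tokens * num_ranks` is a safe upper bound
# (assuming every rank holds `num_tokens` tokens).
```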

View File

@ -108,7 +108,8 @@ public:
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
int expert_alignment, int num_worst_tokens, const Config& config,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,

View File

@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);

View File

@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
}
@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 1; i < kNumRanks; ++ i)
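The comment above ("sum per-rank counts and return to CPU, also pre-compute the prefix sum") describes the quantity the CPU normally waits for: the `rank_prefix_matrix`, whose last row is later read as the total number of tokens received per rank. A small NumPy sketch of that bookkeeping, with the exact layout treated as an inference from how the matrix is indexed elsewhere in this diff:

```python
import numpy as np

def build_rank_prefix_matrix(tokens_from_to: np.ndarray) -> np.ndarray:
    """tokens_from_to[src, dst] = number of tokens rank `src` sends to rank `dst`.
    Returns the inclusive prefix sum over source ranks, so that result[-1, dst]
    is the total number of tokens rank `dst` receives (the value the CPU polls for)."""
    return np.cumsum(tokens_from_to, axis=0)

counts = np.array([[3, 1],
                   [2, 4]])          # 2 ranks, illustrative counts
prefix = build_rank_prefix_matrix(counts)
print(prefix[-1])                    # total received per rank: [5, 5]
```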
@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto ptr = static_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
ptr[i] = rank_prefix_matrix[i];
@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
int target_rank = is_sender ? rank : responsible_rank;
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
// Calculate offset first
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;
// Receive channel offset
@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto start_time = clock64();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are the same
while (recv_thread_id_in_rank == 0) {
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());
@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
if (lane_id == 0)
tma_store_wait();
}
// Fill the unused `recv_topk_idx` entries with -1
if (num_worst_tokens > 0) {
auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
const auto clean_end = num_worst_tokens * num_topk;
const auto clean_stride = num_sms * kNumThreads;
#pragma unroll
for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
recv_topk_idx[i] = -1;
}
}
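The cleanup block just above is the other half of the `num_worst_tokens` contract: the real receive count is read from the last row of `rank_prefix_matrix` on the GPU, and every `recv_topk_idx` slot beyond it is set to `-1` with a grid-stride loop (SM `sm_id` starts at `num_recv_tokens * num_topk + sm_id * kNumThreads`, each thread stepping by `num_sms * kNumThreads`). A quick Python emulation with made-up sizes, checking that the union of all (SM, thread) pairs covers the padded range exactly once:

```python
def cleaned_indices(num_recv_tokens, num_worst_tokens, num_topk, num_sms, num_threads):
    """Emulate the per-thread loops of the cleanup block above."""
    clean_end = num_worst_tokens * num_topk
    clean_stride = num_sms * num_threads
    hit = []
    for sm_id in range(num_sms):
        clean_start = num_recv_tokens * num_topk + sm_id * num_threads
        for thread_id in range(num_threads):
            hit.extend(range(clean_start + thread_id, clean_end, clean_stride))
    return sorted(hit)

# Made-up sizes, just for the coverage check
idx = cleaned_indices(num_recv_tokens=5, num_worst_tokens=16, num_topk=4,
                      num_sms=3, num_threads=8)
assert idx == list(range(5 * 4, 16 * 4))  # every padded slot is set to -1 exactly once
```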
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto ptr = static_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
// Calculate pointers by the specific layout
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
if (thread_id < 32) {
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
// Queue head updater
@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
auto channel_rank_offset = responsible_channel * kNumRanks + i;
auto num_channels_total = num_channels * kNumRanks;
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);

View File

@ -249,7 +249,8 @@ class Buffer:
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
expert_alignment: int = 1, num_worst_tokens: int = 0,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
@ -276,6 +277,8 @@ class Buffer:
`-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this value.
num_worst_tokens: the worst-case number of tokens to receive; if specified, there will be no CPU sync, and
the dispatch will be CUDA-graph compatible. Note that this flag is for intranode dispatch only.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
@ -296,6 +299,7 @@ class Buffer:
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
@ -308,14 +312,16 @@ class Buffer:
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
x, x_scales, None, None,
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
expert_alignment, num_worst_tokens, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
expert_alignment, num_worst_tokens, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
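One practical consequence of the docstring above: with `num_worst_tokens` set, the host never learns the real receive count, so `recv_x`, `recv_topk_idx` and `recv_topk_weights` come back padded to `num_worst_tokens` rows, with the padded rows of `recv_topk_idx` cleaned to `-1` by the kernel. A hedged sketch of handling that on the GPU side, assuming `recv_x` is the plain bf16 tensor (not the FP8 tuple variant) and that every genuinely received token has at least one non-negative top-k entry for this rank (which is the condition that dispatched it here in the first place):

```python
import torch

def valid_row_mask(recv_topk_idx: torch.Tensor) -> torch.Tensor:
    # Padded rows are all -1 after the cleanup kernel; any non-negative entry
    # therefore marks a genuinely received token (see the assumption above).
    return recv_topk_idx.ne(-1).any(dim=1)

mask = valid_row_mask(recv_topk_idx)   # shape [num_worst_tokens], stays on the GPU
num_valid = mask.sum()                 # GPU-side count, no host sync, graph-safe

# Keep shapes static for graph replay: zero the padded rows instead of slicing.
recv_x = recv_x.masked_fill(~mask.unsqueeze(-1), 0)
```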

View File

@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
recv_topk_weights_clone = None
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
recv_topk_weights_clone = recv_topk_weights.clone()
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, rank_prefix_matrix)
# Test `num_worst_tokens != 0`
if with_topk:
num_worst_tokens = num_tokens * num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0)
assert num_worst_tokens == recv_worst_topk_weights.size(0)
assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
# Test cached dispatch (must be without top-k stuff)
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}