Support CUDA graph for intranode normal kernels (#203)

commit a8299ca7c2
parent 8da2d7b38d
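The change adds a `num_worst_tokens` argument to the intranode dispatch path: when it is positive, the host no longer spins on `moe_recv_counter`, the receive tensors are allocated for the worst case, and the normal intranode kernels become capturable into a CUDA graph. Below is a minimal sketch of how a caller might use this; the buffer and router variables are placeholders assumed to exist, and warm-up/graph-pool details of a real capture are omitted.

    import torch
    import deep_ep

    # Placeholders assumed to exist: `buffer` (deep_ep.Buffer), `x`, `topk_idx`, `topk_weights`,
    # `num_tokens_per_rank`, `is_token_in_rank`, `num_tokens_per_expert`, `num_tokens`, `num_ranks`.
    num_worst_tokens = num_tokens * num_ranks  # upper bound: every rank could send all of its tokens here

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # With `num_worst_tokens > 0` there is no CPU-side wait, so capture does not stall
        recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = buffer.dispatch(
            x, topk_idx=topk_idx, topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank, is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            num_worst_tokens=num_worst_tokens)

    graph.replay()  # re-issues the captured dispatch without touching the CPU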
@@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                                                  allocate_on_comm_stream=previous_event is not None)
     # Do MoE dispatch
     # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
+    # Unless you specify `num_worst_tokens`, but this flag is for intranode only
     # For more advanced usages, please refer to the docs of the `dispatch` function
     recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
         _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
                            const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                            const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
                            int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
-                           int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+                           int expert_alignment, int num_worst_tokens, const Config& config,
+                           std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
     bool cached_mode = cached_rank_prefix_matrix.has_value();
 
     // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
@@ -412,25 +413,34 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
                                    buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
                                    comm_stream, num_channels);
 
-        // Synchronize total received tokens and tokens per expert
-        auto start_time = std::chrono::high_resolution_clock::now();
-        while (true) {
-            // Read total count
-            num_recv_tokens = static_cast<int>(*moe_recv_counter);
-
-            // Read per-expert count
-            bool ready = (num_recv_tokens >= 0);
-            for (int i = 0; i < num_local_experts and ready; ++i)
-                ready &= moe_recv_expert_counter[i] >= 0;
-
-            if (ready)
-                break;
-
-            // Timeout check
-            if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
-                throw std::runtime_error("DeepEP error: CPU recv timeout");
-        }
-        num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
+        if (num_worst_tokens > 0) {
+            // No CPU sync, just allocate the worst case
+            num_recv_tokens = num_worst_tokens;
+
+            // Must be forward with top-k stuffs
+            EP_HOST_ASSERT(topk_idx.has_value());
+            EP_HOST_ASSERT(topk_weights.has_value());
+        } else {
+            // Synchronize total received tokens and tokens per expert
+            auto start_time = std::chrono::high_resolution_clock::now();
+            while (true) {
+                // Read total count
+                num_recv_tokens = static_cast<int>(*moe_recv_counter);
+
+                // Read per-expert count
+                bool ready = (num_recv_tokens >= 0);
+                for (int i = 0; i < num_local_experts and ready; ++i)
+                    ready &= moe_recv_expert_counter[i] >= 0;
+
+                if (ready)
+                    break;
+
+                // Timeout check
+                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
+                    throw std::runtime_error("DeepEP error: CPU recv timeout");
+            }
+            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
+        }
     }
 
     // Allocate new tensors
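In the `num_worst_tokens > 0` branch the host never reads `moe_recv_counter` or `moe_recv_expert_counter`, so `num_recv_tokens_per_expert_list` is left unfilled and every receive tensor is sized for the worst case. If per-expert counts are still needed, they can be recomputed on the GPU from the dispatched top-k indices; one illustrative way (a hypothetical helper, not part of this commit) relies on the `-1` padding of `recv_topk_idx`:

    import torch

    def count_tokens_per_local_expert(recv_topk_idx: torch.Tensor, num_local_experts: int) -> torch.Tensor:
        # `recv_topk_idx` is the dispatch output; entries that do not target a local expert,
        # including the padded worst-case rows, are -1 (see the kernel change further below).
        flat = recv_topk_idx.flatten()
        return torch.bincount(flat[flat >= 0], minlength=num_local_experts)  # stays on the device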
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
                         send_head.data_ptr<int>(),
                         x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
                         is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
-                        num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
+                        num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
                         buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
                         config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
 
@@ -108,7 +108,8 @@ public:
                        const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                        const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
                        int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
-                       int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+                       int expert_alignment, int num_worst_tokens, const Config& config,
+                       std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
 
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
@@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
-              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
               void** buffer_ptrs, int rank, int num_ranks,
               cudaStream_t stream, int num_sms,
               int num_max_send_tokens, int num_recv_buffer_tokens);
@@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 
     int *per_rank_buffer, *per_expert_buffer;
     if (thread_id < kNumRanks) {
-        per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
+        per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);
         per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
     }
 
@@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 
     // Sum per-rank counts and return to CPU
     // Also pre-compute the prefix sum for data sending
-    auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);
     if (thread_id < kNumRanks) {
         #pragma unroll
         for (int i = 1; i < kNumRanks; ++ i)
@@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 
     // Copy and clean
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
-    auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    auto ptr = static_cast<int*>(buffer_ptrs[rank]);
     #pragma unroll
     for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
         ptr[i] = rank_prefix_matrix[i];
@@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
-         int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+         int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
          void** buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
     // Calculate pointers by the specific layout
     // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
-    auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
+    auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
     int target_rank = is_sender ? rank : responsible_rank;
     auto num_channels_total = num_channels * kNumRanks;
     auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
@@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
             int chunk_token_idx = 0;
             while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
-                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
+                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
                 if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
                     send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
 
@@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
         EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
 
         // Calculate offset first
-        auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
+        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
         int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;
 
         // Receive channel offset
@@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
         auto start_time = clock64();
         int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
         while (num_tokens_to_recv > 0) {
-            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
+            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are the same
             while (recv_thread_id_in_rank == 0) {
                 cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());;
 
@@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             if (lane_id == 0)
                 tma_store_wait();
         }
 
+    // Clean unused `recv_topk_idx` as -1
+    if (num_worst_tokens > 0) {
+        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
+        const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
+        const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
+        const auto clean_end = num_worst_tokens * num_topk;
+        const auto clean_stride = num_sms * kNumThreads;
+        #pragma unroll
+        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+            recv_topk_idx[i] = -1;
+    }
 }
 
 void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
-              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
               void** buffer_ptrs, int rank, int num_ranks,
               cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
     constexpr int kNumThreads = 768;
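The loop added at the end of the kernel marks every `recv_topk_idx` slot past the real receive count as -1, so the padded worst-case rows look like ordinary "no selection" entries. Under the assumption that every genuinely received token targets at least one local expert, a caller could recover a validity mask with something like the sketch below (illustrative only, not an API of this commit).

    import torch

    def valid_row_mask(recv_topk_idx: torch.Tensor) -> torch.Tensor:
        # Rows beyond the actual number of received tokens are all -1 after the cleanup above,
        # so treat a row as real if any of its top-k entries points to a local expert.
        return (recv_topk_idx >= 0).any(dim=-1)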
@@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
         reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
         send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
         is_token_in_rank, channel_prefix_matrix, \
-        num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+        num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
         buffer_ptrs, rank, \
         num_max_send_tokens, num_recv_buffer_tokens); \
     } break
@@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
 
     // Clean
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
-    auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    auto ptr = static_cast<int*>(buffer_ptrs[rank]);
     #pragma unroll
     for (int i = thread_id; i < num_memset_int; i += num_threads)
         ptr[i] = 0;
@@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
 
     // Calculate pointers by the specific layout
-    auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
+    auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));
     auto num_channels_total = num_channels * kNumRanks;
     auto channel_rank_offset = responsible_channel * kNumRanks + rank;
 
@@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
 
     if (thread_id < 32) {
-        int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
+        int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
         int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
 
         // Queue head updater
@@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         auto channel_rank_offset = responsible_channel * kNumRanks + i;
         auto num_channels_total = num_channels * kNumRanks;
         // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
-        auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
+        auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
 
         // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
         channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
 
         // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
-        ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
+        ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
 
         // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
         channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
@@ -249,7 +249,8 @@ class Buffer:
                  handle: Optional[Tuple] = None,
                  num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                  is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
-                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
+                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
+                 expert_alignment: int = 1, num_worst_tokens: int = 0,
                  config: Optional[Config] = None,
                  previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                  allocate_on_comm_stream: bool = False) -> \
@@ -276,6 +277,8 @@ class Buffer:
                 `-1` means no selections.
             topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
             expert_alignment: align the number of tokens received by each local expert to this variable.
+            num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
+                will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
             config: the performance tuning config.
             previous_event: the event to wait before actually executing the kernel.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
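As documented above, the flag defaults to 0, so existing callers keep the synchronizing behavior; passing a positive bound switches to worst-case allocation. A rough comparison of the two call shapes, with `dispatch_args` standing in for the usual keyword arguments (an assumed dict, illustrative only):

    # Default path: CPU-synchronized, exact sizes, per-expert counts returned.
    recv_x, recv_topk_idx, recv_topk_weights, counts, handle, event = buffer.dispatch(**dispatch_args)

    # CUDA-graph-friendly path: same call plus a worst-case bound; the receive tensors now have
    # `num_worst_tokens` rows, padded top-k indices are -1, and the per-expert count list is not
    # populated in this mode.
    recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = buffer.dispatch(
        **dispatch_args, num_worst_tokens=num_tokens * num_ranks)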
@@ -296,6 +299,7 @@ class Buffer:
 
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
+            assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
             return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                                            topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
 
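The assertion above makes the intranode-only restriction explicit: the internode path still needs the CPU sync. A caller that runs the same code on both kinds of deployments could gate the flag on the RDMA rank count, for example (illustrative sketch mirroring the assert):

    # Only request the worst-case path when all ranks are within one NVLink domain.
    use_worst_case = buffer.runtime.get_num_rdma_ranks() == 1
    num_worst_tokens = num_tokens * num_ranks if use_worst_case else 0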
@@ -308,14 +312,16 @@ class Buffer:
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None,
                 None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config,
+                getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
                 self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
                     num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
-                    expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                    expert_alignment, num_worst_tokens, config,
+                    getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
 
@@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
         assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
         if current_x is not x_pure_rand:
             check_data(recv_x, rank_prefix_matrix)
+        recv_topk_weights_clone = None
         if with_topk:
             # Check `topk_idx`
             assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
@@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
                 assert recv_topk_idx.eq(i).sum().item() == count
 
             # Check `topk_weights`
+            recv_topk_weights_clone = recv_topk_weights.clone()
             if current_x is not x_pure_rand:
                 recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
                 check_data(recv_topk_weights, rank_prefix_matrix)
 
+        # Test `num_worst_tokens != 0`
+        if with_topk:
+            num_worst_tokens = num_tokens * num_ranks
+            dispatch_args.update({'num_worst_tokens': num_worst_tokens})
+            recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
+            event.current_stream_wait() if async_mode else ()
+            recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
+            assert num_worst_tokens == recv_worst_x.size(0)
+            assert num_worst_tokens == recv_worst_topk_idx.size(0)
+            assert num_worst_tokens == recv_worst_topk_weights.size(0)
+            assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
+            assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
+            assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
+            assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
+
         # Test cached dispatch (must without top-k staffs)
         if not with_topk:
             dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}