mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support BF16 for low-latency kernels
This commit is contained in:
@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook) {
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
stream_wait(launch_stream, compute_stream);
|
||||
|
||||
// Allocate packed tensors
|
||||
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
|
||||
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
|
||||
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
|
||||
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
// Allocate column-majored scales
|
||||
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
|
||||
auto packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
packed_recv_x_scales = torch::transpose(packed_recv_x_scales, 1, 2);
|
||||
auto packed_recv_x_scales = std::optional<torch::Tensor>();
|
||||
float* packed_recv_x_scales_ptr = nullptr;
|
||||
if (use_fp8) {
|
||||
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
|
||||
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
|
||||
}
|
||||
|
||||
// Kernel launch
|
||||
auto next_clean_meta = next_buffer.clean_meta();
|
||||
auto launcher = [=](int phases) {
|
||||
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
|
||||
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
|
||||
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
|
||||
packed_recv_count.data_ptr<int>(),
|
||||
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
|
||||
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
|
||||
next_clean_meta.first, next_clean_meta.second,
|
||||
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
|
||||
num_topk, num_experts, rank, num_ranks,
|
||||
num_topk, num_experts, rank, num_ranks, use_fp8,
|
||||
workspace, launch_stream, phases);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
Reference in New Issue
Block a user