Support BF16 for low-latency kernels

This commit is contained in:
Chenggang Zhao
2025-03-10 17:24:41 +08:00
parent 1fc40d50f3
commit ed7487c15e
8 changed files with 138 additions and 111 deletions

View File

@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
at::cuda::getCurrentCUDAStream());
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook) {
bool use_fp8, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-majored scales
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
auto packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
packed_recv_x_scales = torch::transpose(packed_recv_x_scales, 1, 2);
auto packed_recv_x_scales = std::optional<torch::Tensor>();
float* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
num_topk, num_experts, rank, num_ranks, use_fp8,
workspace, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));