Support UE8M0 data format. (#206)

This commit is contained in:
Shifang Xu
2025-06-12 09:38:19 +08:00
committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions

View File

@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
}
// Dispatch
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
num_topk, num_experts, num_scales,
scale_token_stride, scale_hidden_stride,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
}
// Launch data dispatch
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
is_token_in_rank.data_ptr<bool>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
scale_token_stride, scale_hidden_stride,
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, cached_mode,
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) {
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
int num_local_experts = num_experts / num_ranks;
auto num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
float* packed_recv_x_scales_ptr = nullptr;
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}
// Kernel launch
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));