mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
|
||||
// FP8 scales checks
|
||||
float* x_scales_ptr = nullptr;
|
||||
int num_scales = 0;
|
||||
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
|
||||
if (x_scales.has_value()) {
|
||||
EP_HOST_ASSERT(x.element_size() == 1);
|
||||
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
|
||||
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
|
||||
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
|
||||
EP_HOST_ASSERT(x_scales->dim() == 2);
|
||||
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
|
||||
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
|
||||
x_scales_ptr = x_scales->data_ptr<float>();
|
||||
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
|
||||
scale_token_stride = static_cast<int>(x_scales->stride(0));
|
||||
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
|
||||
}
|
||||
|
||||
// Allocate all tensors on comm stream if set
|
||||
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
recv_x_scales = x_scales->dim() == 1 ?
|
||||
torch::empty({num_recv_tokens}, x_scales->options()) :
|
||||
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
|
||||
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
|
||||
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
|
||||
}
|
||||
|
||||
// Dispatch
|
||||
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
send_head.data_ptr<int>(),
|
||||
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
|
||||
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
|
||||
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
|
||||
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
|
||||
num_topk, num_experts, num_scales,
|
||||
scale_token_stride, scale_hidden_stride,
|
||||
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
|
||||
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
|
||||
|
||||
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
|
||||
// FP8 scales checks
|
||||
float* x_scales_ptr = nullptr;
|
||||
int num_scales = 0;
|
||||
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
|
||||
if (x_scales.has_value()) {
|
||||
EP_HOST_ASSERT(x.element_size() == 1);
|
||||
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
|
||||
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
|
||||
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
|
||||
EP_HOST_ASSERT(x_scales->dim() == 2);
|
||||
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
|
||||
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
|
||||
x_scales_ptr = x_scales->data_ptr<float>();
|
||||
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
|
||||
scale_token_stride = static_cast<int>(x_scales->stride(0));
|
||||
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
|
||||
}
|
||||
|
||||
// Allocate all tensors on comm stream if set
|
||||
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
recv_x_scales = x_scales->dim() == 1 ?
|
||||
torch::empty({num_recv_tokens}, x_scales->options()) :
|
||||
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
|
||||
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
|
||||
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
|
||||
}
|
||||
|
||||
// Launch data dispatch
|
||||
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
|
||||
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
|
||||
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
|
||||
is_token_in_rank.data_ptr<bool>(),
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
|
||||
scale_token_stride, scale_hidden_stride,
|
||||
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
|
||||
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
|
||||
rank, num_ranks, cached_mode,
|
||||
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
|
||||
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
bool use_fp8, bool round_scale, bool use_ue8m0,
|
||||
bool async, bool return_recv_hook) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
|
||||
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
|
||||
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
|
||||
int num_local_experts = num_experts / num_ranks;
|
||||
auto num_local_experts = num_experts / num_ranks;
|
||||
|
||||
// Buffer control
|
||||
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
|
||||
// Allocate column-major scales
|
||||
auto packed_recv_x_scales = std::optional<torch::Tensor>();
|
||||
float* packed_recv_x_scales_ptr = nullptr;
|
||||
void* packed_recv_x_scales_ptr = nullptr;
|
||||
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
|
||||
|
||||
if (use_fp8) {
|
||||
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
// TODO: support unaligned cases
|
||||
EP_HOST_ASSERT(hidden % 512 == 0);
|
||||
if (not use_ue8m0) {
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
|
||||
torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
} else {
|
||||
EP_HOST_ASSERT(round_scale);
|
||||
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
|
||||
torch::dtype(torch::kInt).device(torch::kCUDA));
|
||||
}
|
||||
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
|
||||
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
|
||||
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
|
||||
}
|
||||
|
||||
// Kernel launch
|
||||
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
|
||||
next_clean_meta.first, next_clean_meta.second,
|
||||
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
|
||||
num_topk, num_experts, rank, num_ranks, use_fp8,
|
||||
num_topk, num_experts, rank, num_ranks,
|
||||
use_fp8, round_scale, use_ue8m0,
|
||||
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
Reference in New Issue
Block a user