mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
|
||||
}
|
||||
|
||||
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
template <bool kUseFP8, bool kUseUE8M0,
|
||||
int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
int* usage_flag, int phases) {
|
||||
bool round_scale, int* usage_flag, int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// May extract UE8M0 from the scales
|
||||
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
|
||||
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
|
||||
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
|
||||
|
||||
// FP8 staffs
|
||||
constexpr int kNumPerChannels = 128;
|
||||
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
|
||||
const int num_scales = kHidden / kNumPerChannels;
|
||||
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
|
||||
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
|
||||
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
|
||||
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
|
||||
|
||||
// Overlap top-k index read and source token index write
|
||||
// Overlap top-k index read and source token index writes
|
||||
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
|
||||
|
||||
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
// Read
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
|
||||
if (kUseFP8) {
|
||||
if constexpr (kUseFP8) {
|
||||
// Calculate local amax
|
||||
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
|
||||
float fp32_values[kNumElemsPerRead];
|
||||
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
|
||||
// Reduce amax and scale
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
|
||||
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
|
||||
amax = half_warp_reduce_max(amax);
|
||||
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
|
||||
if (lane_id == 0 or lane_id == 16)
|
||||
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
|
||||
|
||||
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
|
||||
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
|
||||
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
|
||||
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
|
||||
const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
|
||||
const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
|
||||
|
||||
// Shared between sub-warps in warp groups
|
||||
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
|
||||
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
|
||||
|
||||
// Copy scales
|
||||
if (kUseFP8) {
|
||||
if constexpr (kUseFP8) {
|
||||
// Equivalent CuTe layout:
|
||||
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
|
||||
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
|
||||
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
|
||||
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
|
||||
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
|
||||
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
|
||||
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
|
||||
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
|
||||
const auto token_idx = recv_token_begin_idx + i;
|
||||
const auto token_stride = num_elems_per_pack;
|
||||
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
|
||||
if (lane_id < num_scales) {
|
||||
const auto pack_idx = lane_id / num_elems_per_pack;
|
||||
const auto elem_idx = lane_id % num_elems_per_pack;
|
||||
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
|
||||
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
|
||||
}
|
||||
if (lane_id + 32 < num_scales) {
|
||||
const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
|
||||
const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
|
||||
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
|
||||
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
bool use_fp8, bool round_scale, bool use_ue8m0,
|
||||
void* workspace, int* usage_flag,
|
||||
cudaStream_t stream, int phases) {
|
||||
constexpr int kNumMaxTopK = 9;
|
||||
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
|
||||
|
||||
// Workspace checks
|
||||
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
|
||||
auto atomic_counter_per_expert = static_cast<int*>(workspace);
|
||||
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
|
||||
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
|
||||
|
||||
// FP8 checks
|
||||
if (use_ue8m0)
|
||||
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(hidden) { \
|
||||
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
|
||||
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
if (use_fp8 and not use_ue8m0) \
|
||||
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
if (use_fp8 and use_ue8m0) \
|
||||
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
LAUNCH_KERNEL(&cfg, dispatch_func, \
|
||||
packed_recv_x, packed_recv_x_scales, \
|
||||
packed_recv_src_info, packed_recv_layout_range, \
|
||||
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
|
||||
next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, \
|
||||
num_topk, num_experts, rank, num_ranks, \
|
||||
usage_flag, phases); } break
|
||||
round_scale, usage_flag, phases); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
|
||||
|
||||
Reference in New Issue
Block a user