Support UE8M0 data format. (#206)

This commit is contained in:
Shifang Xu
2025-06-12 09:38:19 +08:00
committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions

View File

@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 14
CUDA_STANDARD 14
CXX_STANDARD 17
CUDA_STANDARD 17
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)

View File

@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases);

View File

@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks) {
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `x_scales` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7;
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
is_token_in_rank, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
scale_token_stride, scale_hidden_stride, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break

View File

@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
template <bool kUseFP8, bool kUseUE8M0,
int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int* usage_flag, int phases) {
bool round_scale, int* usage_flag, int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// May extract UE8M0 from the scales
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs
constexpr int kNumPerChannels = 128;
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index write
// Overlap top-k index read and source token index writes
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Read
auto int4_value = __ldg(x_int4 + i);
if (kUseFP8) {
if constexpr (kUseFP8) {
// Calculate local amax
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
amax = half_warp_reduce_max(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if (kUseFP8) {
if constexpr (kUseFP8) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if (lane_id < num_scales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + 32 < num_scales) {
const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
}
}
}
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_counter_per_expert = static_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// FP8 checks
if (use_ue8m0)
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
round_scale, usage_flag, phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);

View File

@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `x_scales`
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32)
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
for (int i = lane_id; i < num_scales; i += 32) {
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
}
}
// Move token index
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#endif
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
scale_token_stride, scale_hidden_stride, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break

View File

@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
return lane_id;
}
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
}
template <int kNumRanks>
__forceinline__ __device__ void
barrier_block(int** barrier_signal_ptrs, int rank) {