mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CUDA_STANDARD_REQUIRED ON
|
||||
CXX_STANDARD 14
|
||||
CUDA_STANDARD 14
|
||||
CXX_STANDARD 17
|
||||
CUDA_STANDARD 17
|
||||
CUDA_SEPARABLE_COMPILATION ON
|
||||
)
|
||||
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
|
||||
|
||||
@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens);
|
||||
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, bool is_cached_dispatch,
|
||||
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
cudaStream_t stream);
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
bool use_fp8, bool round_scale, bool use_ue8m0,
|
||||
void* workspace, int* usage_flag,
|
||||
cudaStream_t stream, int phases);
|
||||
|
||||
|
||||
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks) {
|
||||
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Copy `x_scales` into symmetric send buffer
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_scales; i += 32) {
|
||||
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
|
||||
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
|
||||
auto value = ld_nc_global(x_scales + offset);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
|
||||
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, bool is_cached_dispatch,
|
||||
cudaStream_t stream, int num_channels, bool low_latency_mode) {
|
||||
constexpr int kNumDispatchRDMASenderWarps = 7;
|
||||
|
||||
// Make sure never OOB
|
||||
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
|
||||
auto dispatch_func = low_latency_mode ? \
|
||||
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
|
||||
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
|
||||
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
|
||||
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
|
||||
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
|
||||
is_token_in_rank, \
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
|
||||
scale_token_stride, scale_hidden_stride, \
|
||||
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
|
||||
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
|
||||
rank, num_ranks); } break
|
||||
|
||||
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
|
||||
}
|
||||
|
||||
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
template <bool kUseFP8, bool kUseUE8M0,
|
||||
int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
int* usage_flag, int phases) {
|
||||
bool round_scale, int* usage_flag, int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// May extract UE8M0 from the scales
|
||||
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
|
||||
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
|
||||
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
|
||||
|
||||
// FP8 staffs
|
||||
constexpr int kNumPerChannels = 128;
|
||||
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
|
||||
const int num_scales = kHidden / kNumPerChannels;
|
||||
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
|
||||
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
|
||||
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
|
||||
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
|
||||
|
||||
// Overlap top-k index read and source token index write
|
||||
// Overlap top-k index read and source token index writes
|
||||
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
|
||||
|
||||
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
// Read
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
|
||||
if (kUseFP8) {
|
||||
if constexpr (kUseFP8) {
|
||||
// Calculate local amax
|
||||
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
|
||||
float fp32_values[kNumElemsPerRead];
|
||||
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
|
||||
// Reduce amax and scale
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
|
||||
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
|
||||
amax = half_warp_reduce_max(amax);
|
||||
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
|
||||
if (lane_id == 0 or lane_id == 16)
|
||||
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
|
||||
|
||||
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
|
||||
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
|
||||
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
|
||||
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
|
||||
const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
|
||||
const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
|
||||
|
||||
// Shared between sub-warps in warp groups
|
||||
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
|
||||
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
|
||||
|
||||
// Copy scales
|
||||
if (kUseFP8) {
|
||||
if constexpr (kUseFP8) {
|
||||
// Equivalent CuTe layout:
|
||||
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
|
||||
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
|
||||
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
|
||||
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
|
||||
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
|
||||
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
|
||||
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
|
||||
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
|
||||
const auto token_idx = recv_token_begin_idx + i;
|
||||
const auto token_stride = num_elems_per_pack;
|
||||
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
|
||||
if (lane_id < num_scales) {
|
||||
const auto pack_idx = lane_id / num_elems_per_pack;
|
||||
const auto elem_idx = lane_id % num_elems_per_pack;
|
||||
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
|
||||
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
|
||||
}
|
||||
if (lane_id + 32 < num_scales) {
|
||||
const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
|
||||
const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
|
||||
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
|
||||
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
bool use_fp8, bool round_scale, bool use_ue8m0,
|
||||
void* workspace, int* usage_flag,
|
||||
cudaStream_t stream, int phases) {
|
||||
constexpr int kNumMaxTopK = 9;
|
||||
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
|
||||
|
||||
// Workspace checks
|
||||
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
|
||||
auto atomic_counter_per_expert = static_cast<int*>(workspace);
|
||||
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
|
||||
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
|
||||
|
||||
// FP8 checks
|
||||
if (use_ue8m0)
|
||||
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(hidden) { \
|
||||
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
|
||||
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
if (use_fp8 and not use_ue8m0) \
|
||||
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
if (use_fp8 and use_ue8m0) \
|
||||
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
|
||||
LAUNCH_KERNEL(&cfg, dispatch_func, \
|
||||
packed_recv_x, packed_recv_x_scales, \
|
||||
packed_recv_src_info, packed_recv_layout_range, \
|
||||
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
|
||||
next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, \
|
||||
num_topk, num_experts, rank, num_ranks, \
|
||||
usage_flag, phases); } break
|
||||
round_scale, usage_flag, phases); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
|
||||
|
||||
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void** buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_scales; i += 32)
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
|
||||
for (int i = lane_id; i < num_scales; i += 32) {
|
||||
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Move token index
|
||||
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
// Make sure never OOB
|
||||
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) { \
|
||||
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
is_token_in_rank, channel_prefix_matrix, \
|
||||
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
|
||||
scale_token_stride, scale_hidden_stride, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
} break
|
||||
|
||||
@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
constexpr float kFP8Margin = 1e-4;
|
||||
constexpr float kFinfoAmaxE4M3 = 448.0f;
|
||||
constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
|
||||
|
||||
__forceinline__ __device__ float fast_pow2(int x) {
|
||||
// We can ensure `-126 <= x and x <= 127`
|
||||
uint32_t bits_x = (x + 127) << 23;
|
||||
return *reinterpret_cast<float*>(&bits_x);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int fast_log2_ceil(float x) {
|
||||
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
|
||||
auto exp_x = (bits_x >> 23) & 0xff;
|
||||
auto man_bits = bits_x & ((1 << 23) - 1);
|
||||
return exp_x - 127 + (man_bits != 0);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
|
||||
if (round_scale) {
|
||||
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
|
||||
scale = fast_pow2(-exp_scale_inv);
|
||||
scale_inv = fast_pow2(exp_scale_inv);
|
||||
} else {
|
||||
scale_inv = amax * kFinfoAmaxInvE4M3;
|
||||
scale = kFinfoAmaxE4M3 / amax;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
|
||||
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
|
||||
if constexpr (kIsUE8M0) {
|
||||
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void
|
||||
barrier_block(int** barrier_signal_ptrs, int rank) {
|
||||
|
||||
Reference in New Issue
Block a user