Support UE8M0 data format. (#206)

Shifang Xu 2025-06-12 09:38:19 +08:00 committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions


@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
}
// Dispatch
@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
num_topk, num_experts, num_scales,
scale_token_stride, scale_hidden_stride,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
}
// Launch data dispatch
@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
is_token_in_rank.data_ptr<bool>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
scale_token_stride, scale_hidden_stride,
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, cached_mode,
@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) {
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
int num_local_experts = num_experts / num_ranks;
auto num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
float* packed_recv_x_scales_ptr = nullptr;
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}
// Kernel launch
@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));


@ -141,7 +141,8 @@ public:
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook);
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,


@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 14
CUDA_STANDARD 14
CXX_STANDARD 17
CUDA_STANDARD 17
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)


@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases);


@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks) {
@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `x_scales` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7;
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
is_token_in_rank, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
scale_token_stride, scale_hidden_stride, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break


@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}
template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
template <bool kUseFP8, bool kUseUE8M0,
int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int* usage_flag, int phases) {
bool round_scale, int* usage_flag, int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// May extract UE8M0 from the scales
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 stuff
constexpr int kNumPerChannels = 128;
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index write
// Overlap top-k index read and source token index writes
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Read
auto int4_value = __ldg(x_int4 + i);
if (kUseFP8) {
if constexpr (kUseFP8) {
// Calculate local amax
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
amax = half_warp_reduce_max(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if (kUseFP8) {
if constexpr (kUseFP8) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if (lane_id < num_scales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + 32 < num_scales) {
const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
}
}
}
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_counter_per_expert = static_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// FP8 checks
if (use_ue8m0)
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
round_scale, usage_flag, phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
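For reference, the receiver-side scale packing above (see the CuTe layout comment) stores one UE8M0 byte per 128-channel block and packs four consecutive blocks into a single int32. A minimal Python sketch of the index arithmetic, per local expert; the helper name `packed_scale_offset` and its arguments are illustrative, not part of the kernels:

def packed_scale_offset(token_idx: int, channel_block_idx: int, num_tokens_total: int) -> int:
    # Byte offset of the UE8M0 scale for (token, 128-channel block) inside one local expert's
    # packed buffer, mirroring `recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx]`.
    num_elems_per_pack = 4  # sizeof(uint32_t) / sizeof(uint8_t)
    pack_idx, elem_idx = divmod(channel_block_idx, num_elems_per_pack)
    token_stride = num_elems_per_pack
    pack_stride = num_tokens_total * num_elems_per_pack
    return token_idx * token_stride + pack_idx * pack_stride + elem_idx

Interpreted as int32, this is a `[hidden // 512, num_tokens_total]` buffer per local expert, which is why the host side allocates that shape and transposes the last two dimensions.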


@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `x_scales`
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32)
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
for (int i = lane_id; i < num_scales; i += 32) {
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
}
}
// Move token index
@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#endif
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
scale_token_stride, scale_hidden_stride, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break


@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
return lane_id;
}
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
}
template <int kNumRanks>
__forceinline__ __device__ void
barrier_block(int** barrier_signal_ptrs, int rank) {
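A host-side Python sketch of the same scale math, assuming the FP8 E4M3 amax of 448 and using `math.log2` in place of the bit tricks; `fp8_scales` and `to_ue8m0` are illustrative names, not library functions:

import math
import struct

FP8_E4M3_AMAX = 448.0

def fp8_scales(amax: float, round_scale: bool):
    # Mirrors `calculate_fp8_scales`: returns (scale, scale_inv)
    if round_scale:
        # Round `scale_inv` up to the next power of 2 (`fast_log2_ceil` + `fast_pow2`)
        exp_scale_inv = math.ceil(math.log2(amax / FP8_E4M3_AMAX))
        return 2.0 ** -exp_scale_inv, 2.0 ** exp_scale_inv
    return FP8_E4M3_AMAX / amax, amax / FP8_E4M3_AMAX

def to_ue8m0(scale_inv: float) -> int:
    # UE8M0 keeps only the 8 exponent bits of the FP32 scale,
    # as in `extract_required_scale_format<true>`
    bits = struct.unpack('<I', struct.pack('<f', scale_inv))[0]
    return (bits >> 23) & 0xFF

# Example: amax = 1.0 rounds to scale_inv = 2 ** -8, whose UE8M0 byte is 127 - 8 = 119
scale, scale_inv = fp8_scales(1.0, round_scale=True)
assert scale_inv == 2.0 ** -8 and to_ue8m0(scale_inv) == 119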


@ -178,6 +178,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map = {
2: Config(Buffer.num_sms, 24, 256, 6, 128),
4: Config(Buffer.num_sms, 6, 256, 6, 128),
@ -205,6 +206,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map = {
2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 9, 256, 6, 128),
@ -486,14 +488,14 @@ class Buffer:
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
A low-latency implementation for dispatching with IBGDA.
This kernel requires that all ranks (intranode or internode) be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink is fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
@ -507,17 +509,21 @@ class Buffer:
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
use_fp8: whether to enable FP8 casting; if enabled, the received data will be a tuple of an FP8 tensor and its scaling factors.
round_scale: whether to round the scaling factors to powers of 2.
use_ue8m0: whether to use UE8M0 as the scaling-factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
but **without actually receiving the data**. You must call the returned hook to ensure the data has arrived.
If you not set this flag, the kernel will ensure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`
if `use_ue8m0=False`. With `use_ue8m0=True`, the second tensor is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that the last two dimensions of the scaling tensors are in column-major order for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
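An illustrative usage sketch (not from this commit; `buffer`, `x`, `topk_idx`, and the usual sizes are assumed to be set up already, with `torch` imported) showing how the packed UE8M0 scales can be decoded back to FP32 factors, the same way `per_token_cast_back` in `tests/utils.py` does:

recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
    x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
    use_fp8=True, round_scale=True, use_ue8m0=True)
fp8_x, packed_scales = recv_x  # packed_scales: [..., hidden // 512], torch.int
# Each int32 packs 4 UE8M0 exponent bytes; shift them into the FP32 exponent field
scales = (packed_scales.contiguous().view(torch.uint8).to(torch.int32) << 23).view(torch.float32)
# `scales` now matches the `use_ue8m0=False` layout: [..., hidden // 128] in torch.float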
@ -533,7 +539,8 @@ class Buffer:
self.runtime.low_latency_dispatch(x, topk_idx,
cumulative_local_expert_recv_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, async_finish, return_recv_hook)
use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
@ -551,9 +558,8 @@ class Buffer:
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires that all ranks (intranode or internode) be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink is fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
@ -569,7 +575,7 @@ class Buffer:
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
but **without actually receiving the data**. You must call the returned hook to ensure the data has arrived.
If you not set this flag, the kernel will ensure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
Returns:

install.sh (new executable file, 12 lines)

@ -0,0 +1,12 @@
# Change current directory into project root
original_dir=$(pwd)
script_dir=$(dirname "$0")
cd "$script_dir"
# Remove old dist file, build, and install
rm -rf dist
python setup.py bdist_wheel
pip install dist/*.whl
# Return to the user's original directory
cd "$original_dir"


@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
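The `x_e4m3[1].T.contiguous().T` trick above (also used in the intranode test below) keeps the same values but gives the scale tensor column-major strides, so dispatch exercises the new `scale_token_stride` / `scale_hidden_stride` handling instead of assuming a contiguous `[num_tokens, num_scales]` layout. A small illustration with arbitrary shapes:

import torch

scales = torch.rand(4, 7)
col_major = scales.T.contiguous().T  # same values, transposed memory layout
assert torch.equal(scales, col_major) and not col_major.is_contiguous()
print(scales.stride(), col_major.stride())  # (7, 1) vs (1, 4): the token and hidden strides swap roles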
@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
num_processes = 8


@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank


@ -34,61 +34,68 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value, num_times = 0, 0
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
num_times += 1
for i in range((num_times % 2) + 1):
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone()
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
for use_ue8m0 in (False, True) if round_scale else (False, ):
num_times += 1
for i in range((num_times % 2) + 1):
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone()
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
if dispatch_use_fp8:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check received data
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
if dispatch_use_fp8:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, True):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
# Check combine correctness
for zero_copy in (False, True):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=False, return_recv_hook=return_recv_hook)
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the communication group
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead


@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float)
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
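A quick sanity check for the new `torch.int` branch, using the function above with power-of-2 scaling factors like those produced under `round_scale=True`; it assumes a PyTorch build with `torch.float8_e4m3fn` support, and all tensor names are local to this sketch:

import torch

x_fp8 = torch.randn(16, 512).to(torch.float8_e4m3fn)
scales = torch.exp2(torch.randint(-10, -2, (16, 4)).float())                 # [num_tokens, hidden // 128]
packed = (scales.view(torch.int32) >> 23).to(torch.uint8).view(torch.int32)  # UE8M0, [num_tokens, hidden // 512]
assert torch.equal(per_token_cast_back(x_fp8, packed), per_token_cast_back(x_fp8, scales))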