Support UE8M0 data format. (#206)
commit 21efbe9b48 (parent 9ec061204e)
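UE8M0 is an 8-bit, exponent-only scaling-factor format: each scale is a power of two and only its FP32 exponent byte is kept, so four scales pack into one `torch.int`. A minimal PyTorch sketch (ours, not part of the commit) of the encode/decode round trip the kernels and tests below rely on:

```python
import torch

def encode_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # Keep only the FP32 exponent byte (bits 30..23); assumes positive,
    # power-of-two scales such as those produced with `round_scale=True`.
    bits = scale.float().view(torch.int32)
    return ((bits >> 23) & 0xFF).to(torch.uint8)

def decode_ue8m0(e8m0: torch.Tensor) -> torch.Tensor:
    # Rebuild 2^(e - 127) by placing the byte back into the FP32 exponent field.
    return (e8m0.to(torch.int32) << 23).view(torch.float32)

s = torch.tensor([2.0 ** -3, 1.0, 2.0 ** 5])
assert torch.equal(decode_ue8m0(encode_ue8m0(s)), s)
```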
@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
 
     // FP8 scales checks
     float* x_scales_ptr = nullptr;
-    int num_scales = 0;
+    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
     if (x_scales.has_value()) {
         EP_HOST_ASSERT(x.element_size() == 1);
-        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
-        EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
+        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
+        EP_HOST_ASSERT(x_scales->dim() == 2);
         EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
         num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
-        x_scales_ptr = x_scales->data_ptr<float>();
+        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
+        scale_token_stride = static_cast<int>(x_scales->stride(0));
+        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
     }
 
     // Allocate all tensors on comm stream if set
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
         recv_x_scales = x_scales->dim() == 1 ?
                         torch::empty({num_recv_tokens}, x_scales->options()) :
                         torch::empty({num_recv_tokens, num_scales}, x_scales->options());
-        recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
+        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
     }
 
     // Dispatch
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
         send_head.data_ptr<int>(),
         x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
         is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
-        num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
+        num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
+        num_topk, num_experts, num_scales,
+        scale_token_stride, scale_hidden_stride,
         buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
         config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
 
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
 
     // FP8 scales checks
     float* x_scales_ptr = nullptr;
-    int num_scales = 0;
+    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
     if (x_scales.has_value()) {
         EP_HOST_ASSERT(x.element_size() == 1);
-        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
-        EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
+        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
+        EP_HOST_ASSERT(x_scales->dim() == 2);
         EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
         num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
-        x_scales_ptr = x_scales->data_ptr<float>();
+        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
+        scale_token_stride = static_cast<int>(x_scales->stride(0));
+        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
     }
 
     // Allocate all tensors on comm stream if set
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
         recv_x_scales = x_scales->dim() == 1 ?
                         torch::empty({num_recv_tokens}, x_scales->options()) :
                         torch::empty({num_recv_tokens, num_scales}, x_scales->options());
-        recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
+        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
     }
 
     // Launch data dispatch
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
         cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
         rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
         gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
-        num_tokens, hidden_int4, num_scales, num_topk, num_experts,
         is_token_in_rank.data_ptr<bool>(),
+        num_tokens, hidden_int4, num_scales, num_topk, num_experts,
+        scale_token_stride, scale_hidden_stride,
         rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
         rank, num_ranks, cached_mode,
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
-                             bool use_fp8, bool async, bool return_recv_hook) {
+                             bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
 
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 
     auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
     auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
-    int num_local_experts = num_experts / num_ranks;
+    auto num_local_experts = num_experts / num_ranks;
 
     // Buffer control
     LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 
     // Allocate column-majored scales
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
-    float* packed_recv_x_scales_ptr = nullptr;
+    void* packed_recv_x_scales_ptr = nullptr;
+    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
+
     if (use_fp8) {
-        EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
-        packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        // TODO: support unaligned cases
+        EP_HOST_ASSERT(hidden % 512 == 0);
+        if (not use_ue8m0) {
+            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
+                                                torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        } else {
+            EP_HOST_ASSERT(round_scale);
+            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
+                                                torch::dtype(torch::kInt).device(torch::kCUDA));
+        }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
-        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
     }
 
     // Kernel launch
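With `use_ue8m0=True`, the hunk above allocates the receive scales as `torch.int` with a last dimension of `hidden / 512`, instead of `hidden / 128` `torch.float` entries: there is still one scale per 128 hidden channels, but each scale shrinks to a single UE8M0 byte and four bytes share one `int32`. A brief sketch of that shape arithmetic (the `hidden` value below is just an example):

```python
# Illustrative arithmetic only: why the packed UE8M0 scale dimension is hidden // 512.
hidden = 7168                     # example value; the hunk asserts hidden % 512 == 0
num_scales = hidden // 128        # one FP8 scale per 128-channel group
packed_int32 = num_scales // 4    # four 1-byte UE8M0 scales per torch.int element
assert packed_int32 == hidden // 512
```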
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
             x.data_ptr(), topk_idx.data_ptr<int64_t>(),
             next_clean_meta.first, next_clean_meta.second,
             num_tokens, hidden, num_max_dispatch_tokens_per_rank,
-            num_topk, num_experts, rank, num_ranks, use_fp8,
+            num_topk, num_experts, rank, num_ranks,
+            use_fp8, round_scale, use_ue8m0,
             workspace, low_latency_usage_flag_mapped, launch_stream, phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -141,7 +141,8 @@ public:
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
-                         bool use_fp8, bool async, bool return_recv_hook);
+                         bool use_fp8, bool round_scale, bool use_ue8m0,
+                         bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
         POSITION_INDEPENDENT_CODE ON
         CXX_STANDARD_REQUIRED ON
         CUDA_STANDARD_REQUIRED ON
-        CXX_STANDARD 14
-        CUDA_STANDARD 14
+        CXX_STANDARD 17
+        CUDA_STANDARD 17
         CUDA_SEPARABLE_COMPILATION ON
     )
     target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
               int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int scale_token_stride, int scale_hidden_stride,
              void** buffer_ptrs, int rank, int num_ranks,
              cudaStream_t stream, int num_sms,
              int num_max_send_tokens, int num_recv_buffer_tokens);
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
               int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
               const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
               const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
-              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
               const bool* is_token_in_rank,
+              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
+              int scale_token_stride, int scale_hidden_stride,
               void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
               void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
               int rank, int num_ranks, bool is_cached_dispatch,
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                               int* clean_1, int num_clean_int_1,
                               cudaStream_t stream);
 
-void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
+              int num_topk, int num_experts, int rank, int num_ranks,
+              bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
               cudaStream_t stream, int phases);
 
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
          int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
          const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
          const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
-         int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
          const bool* is_token_in_rank,
+         int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
+         int scale_token_stride, int scale_hidden_stride,
          void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
          void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
          int rank, int num_ranks) {
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 // Copy `x_scales` into symmetric send buffer
                 #pragma unroll
                 for (int i = lane_id; i < num_scales; i += 32) {
-                    auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
+                    auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
+                    auto value = ld_nc_global(x_scales + offset);
                     #pragma unroll
                     for (int j = 0; j < num_topk_ranks; ++ j)
                         st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
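The only change to this copy loop is the addressing: `token_idx * num_scales + i` assumed a contiguous `[num_tokens, num_scales]` float tensor, while `token_idx * scale_token_stride + i * scale_hidden_stride` also handles strided layouts such as the transposed `x_e4m3[1].T.contiguous().T` scales used in the tests. A small PyTorch sketch (ours) of why the stride-based offset still lands on the right element:

```python
import torch

# Hypothetical illustration: stride-based addressing of a column-major scales view.
num_tokens, num_scales = 4, 8
row_major = torch.arange(num_tokens * num_scales, dtype=torch.float32).view(num_tokens, num_scales)
base = row_major.T.contiguous()      # [num_scales, num_tokens] row-major buffer
col_major = base.T                   # logical [num_tokens, num_scales], strides (1, num_tokens)
flat = base.view(-1)                 # the raw buffer a kernel pointer would see

for t in range(num_tokens):
    for i in range(num_scales):
        offset = t * col_major.stride(0) + i * col_major.stride(1)
        assert flat[offset] == row_major[t, i]
```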
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
               int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
               const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
               const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
-              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
               const bool* is_token_in_rank,
+              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
+              int scale_token_stride, int scale_hidden_stride,
               void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
               void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
               int rank, int num_ranks, bool is_cached_dispatch,
               cudaStream_t stream, int num_channels, bool low_latency_mode) {
     constexpr int kNumDispatchRDMASenderWarps = 7;
 
+    // Make sure never OOB
+    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
+
 #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
     auto dispatch_func = low_latency_mode ? \
         (is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
         recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
         rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
         gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
-        num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
         is_token_in_rank, \
+        num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
+        scale_token_stride, scale_hidden_stride, \
         rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
         buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
         rank, num_ranks); } break
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
         clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, bool kUseUE8M0,
+          int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
-dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
         int* cumulative_local_expert_recv_stats,
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
-         int* usage_flag, int phases) {
+         bool round_scale, int* usage_flag, int phases) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
     const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
 
+    // May extract UE8M0 from the scales
+    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
+    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
+    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
+
     // FP8 staffs
     constexpr int kNumPerChannels = 128;
-    constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
     const int num_scales = kHidden / kNumPerChannels;
     const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
     const size_t hidden_int4 = hidden_bytes / sizeof(int4);
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
         const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
-        // Overlap top-k index read and source token index write
+        // Overlap top-k index read and source token index writes
         auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
         thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
 
|
||||
// Read
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
|
||||
if (kUseFP8) {
|
||||
if constexpr (kUseFP8) {
|
||||
// Calculate local amax
|
||||
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
|
||||
float fp32_values[kNumElemsPerRead];
|
||||
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 
                 // Reduce amax and scale
                 EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
-                amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
+                amax = half_warp_reduce_max(amax);
+                calculate_fp8_scales(amax, scale, scale_inv, round_scale);
                 if (lane_id == 0 or lane_id == 16)
                     rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
 
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                                 src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
         const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
-        const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
         const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
         const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
+        const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
+        const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
 
         // Shared between sub-warps in warp groups
         __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
 
             // Copy scales
-            if (kUseFP8) {
+            if constexpr (kUseFP8) {
+                // Equivalent CuTe layout:
+                // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                 const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
-                const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
-                const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
-                auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
-                auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
-                lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
-                (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
+                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
+                const auto token_idx = recv_token_begin_idx + i;
+                const auto token_stride = num_elems_per_pack;
+                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
+                if (lane_id < num_scales) {
+                    const auto pack_idx = lane_id / num_elems_per_pack;
+                    const auto elem_idx = lane_id % num_elems_per_pack;
+                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
+                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
+                }
+                if (lane_id + 32 < num_scales) {
+                    const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
+                    const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
+                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
+                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
+                }
             }
         }
     }
 }
 
-void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
              int* cumulative_local_expert_recv_stats,
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
+              int num_topk, int num_experts, int rank, int num_ranks,
+              bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
 
     // Workspace checks
-    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
+    auto atomic_counter_per_expert = static_cast<int*>(workspace);
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
 
+    // FP8 checks
+    if (use_ue8m0)
+        EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
+
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
-                               dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+if (use_fp8 and not use_ue8m0) \
+    dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+if (use_fp8 and use_ue8m0) \
+    dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
               next_clean, num_next_clean_int, \
               num_tokens, num_max_dispatch_tokens_per_rank, \
               num_topk, num_experts, rank, num_ranks, \
-              usage_flag, phases); } break
+              round_scale, usage_flag, phases); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
          int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+         int scale_token_stride, int scale_hidden_stride,
          void** buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
                 // Copy `x_scales`
                 #pragma unroll
-                for (int i = lane_id; i < num_scales; i += 32)
-                    channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
+                for (int i = lane_id; i < num_scales; i += 32) {
+                    auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
+                    channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
+                }
             }
 
             // Move token index
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
               int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int scale_token_stride, int scale_hidden_stride,
              void** buffer_ptrs, int rank, int num_ranks,
              cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
     constexpr int kNumThreads = 768;
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
     constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
 #endif
 
+    // Make sure never OOB
+    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
+
 #define DISPATCH_LAUNCH_CASE(ranks) { \
     auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
     SET_SHARED_MEMORY_FOR_TMA(kernel); \
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
                   send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
                   is_token_in_rank, channel_prefix_matrix, \
                   num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+                  scale_token_stride, scale_hidden_stride, \
                   buffer_ptrs, rank, \
                   num_max_send_tokens, num_recv_buffer_tokens); \
     } break
@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
     return lane_id;
 }
 
+constexpr float kFP8Margin = 1e-4;
+constexpr float kFinfoAmaxE4M3 = 448.0f;
+constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
+
+__forceinline__ __device__ float fast_pow2(int x) {
+    // We can ensure `-126 <= x and x <= 127`
+    uint32_t bits_x = (x + 127) << 23;
+    return *reinterpret_cast<float*>(&bits_x);
+}
+
+__forceinline__ __device__ int fast_log2_ceil(float x) {
+    auto bits_x = *reinterpret_cast<uint32_t*>(&x);
+    auto exp_x = (bits_x >> 23) & 0xff;
+    auto man_bits = bits_x & ((1 << 23) - 1);
+    return exp_x - 127 + (man_bits != 0);
+}
+
+__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
+    if (round_scale) {
+        auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
+        scale = fast_pow2(-exp_scale_inv);
+        scale_inv = fast_pow2(exp_scale_inv);
+    } else {
+        scale_inv = amax * kFinfoAmaxInvE4M3;
+        scale = kFinfoAmaxE4M3 / amax;
+    }
+}
+
+template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
+__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
+    if constexpr (kIsUE8M0) {
+        return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
+    } else {
+        return value;
+    }
+}
+
 template <int kNumRanks>
 __forceinline__ __device__ void
 barrier_block(int** barrier_signal_ptrs, int rank) {
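These helpers implement the two new behaviours: `calculate_fp8_scales` optionally rounds the scale so that `scale_inv` becomes the next power of two, and `extract_required_scale_format` keeps only the exponent byte when UE8M0 output is requested. A host-side Python restatement (ours, for intuition only; the device code uses bit tricks instead of `log2`/`pow`):

```python
import math

FINFO_AMAX_E4M3 = 448.0  # same constant as kFinfoAmaxE4M3 above

def calculate_fp8_scales(amax: float, round_scale: bool):
    if round_scale:
        # Round scale_inv up to the next power of two; scale is its exact inverse.
        exp_scale_inv = math.ceil(math.log2(amax / FINFO_AMAX_E4M3))
        return 2.0 ** -exp_scale_inv, 2.0 ** exp_scale_inv   # (scale, scale_inv)
    return FINFO_AMAX_E4M3 / amax, amax / FINFO_AMAX_E4M3

scale, scale_inv = calculate_fp8_scales(amax=300.0, round_scale=True)
assert (scale, scale_inv) == (1.0, 1.0)   # 300 / 448 rounds up to 2 ** 0
```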
@@ -178,6 +178,7 @@ class Buffer:
             config: the recommended config.
         """
 
+        # TODO: automatically tune
         config_map = {
             2: Config(Buffer.num_sms, 24, 256, 6, 128),
             4: Config(Buffer.num_sms, 6, 256, 6, 128),
@@ -205,6 +206,7 @@ class Buffer:
             config: the recommended config.
         """
 
+        # TODO: automatically tune
         config_map = {
             2: Config(Buffer.num_sms, 10, 256, 6, 128),
             4: Config(Buffer.num_sms, 9, 256, 6, 128),
@@ -486,14 +488,14 @@ class Buffer:
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
                              cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
-                             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
+                             use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
+                             async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """
         A low-latency implementation for dispatching with IBGDA.
         This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
         (specifically, IBGDA must be enabled).
-        Even for ranks in the same node, NVLink are fully disabled for simplicity.
-        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
+        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
         low-latency kernels' result tensors at a single moment.
 
         Arguments:
@@ -507,17 +509,21 @@ class Buffer:
                 `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
                 monitoring.
             use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
+            round_scale: whether to round the scaling factors to powers of 2.
+            use_ue8m0: whether to use UE8M0 as the scaling factor format (available only with `round_scale=True`).
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
-                If you not set this flag, the kernel will ensure the data's arrival.
+                If you do not set this flag, the kernel will ensure the data's arrival.
 
         Returns:
             recv_x: a tensor or tuple with received tokens for each expert.
                 With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
                 The second tensor is the corresponding scales for the first element with shape
-                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
+                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
+                if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
+                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
                 Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
                 With `use_fp8=False`, the result would be a tensor shaped as
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
@@ -533,7 +539,8 @@ class Buffer:
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               cumulative_local_expert_recv_stats,
                                               num_max_dispatch_tokens_per_rank, num_experts,
-                                              use_fp8, async_finish, return_recv_hook)
+                                              use_fp8, round_scale, use_ue8m0,
+                                              async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
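A minimal call sketch for the extended Python API (ours; it assumes `buffer`, `x`, `topk_idx`, `num_tokens`, and `num_experts` are prepared the way `test_low_latency.test_main` prepares them):

```python
import torch

recv_x, recv_count, handle, event, hook = \
    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                use_fp8=True, round_scale=True, use_ue8m0=True,
                                async_finish=False, return_recv_hook=False)
packed_fp8, packed_scales = recv_x
# With use_ue8m0=True the scales arrive packed as torch.int (four UE8M0 bytes per
# element); with use_ue8m0=False they are plain torch.float, as documented above.
assert packed_scales.dtype in (torch.int, torch.float)
```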
@@ -551,9 +558,8 @@ class Buffer:
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
         This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
         (specifically, IBGDA must be enabled).
-        Even for ranks in the same node, NVLink are fully disabled for simplicity.
-        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
-        low-latency kernels' result tensor at a single moment.
+        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
+        low-latency kernels' result tensors at a single moment.
 
         Arguments:
             x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
@@ -569,7 +575,7 @@ class Buffer:
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
-                If you not set this flag, the kernel will ensure the data's arrival.
+                If you do not set this flag, the kernel will ensure the data's arrival.
             out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
 
         Returns:
new file: install.sh (executable, 12 lines)
@@ -0,0 +1,12 @@
+# Change current directory into project root
+original_dir=$(pwd)
+script_dir=$(dirname "$0")
+cd "$script_dir"
+
+# Remove old dist file, build, and install
+rm -rf dist
+python setup.py bdist_wheel
+pip install dist/*.whl
+
+# Open users' original directory
+cd "$original_dir"
@@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
     x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
     x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
     x_e4m3 = per_token_cast_to_fp8(x)
+    x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)
     scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
     group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
     group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
@@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
         buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
         test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
 
+    # Destroy the communication group
+    dist.barrier()
+    dist.destroy_process_group()
+
 
 if __name__ == '__main__':
     num_processes = 8
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
|
||||
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
|
||||
|
@@ -34,61 +34,68 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     hash_value, num_times = 0, 0
     for return_recv_hook in (False, True):
         for dispatch_use_fp8 in (False, True):
-            num_times += 1
-            for i in range((num_times % 2) + 1):
-                cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
-                packed_recv_x, packed_recv_count, handle, event, hook = \
-                    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
-                                                cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                                async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
-                hook() if return_recv_hook else event.current_stream_wait()
-                packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
-                simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
-                    if dispatch_use_fp8 else packed_recv_x.clone()
-                all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
-                dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
-                for i in range(num_local_experts if do_check else 0):
-                    expert_id = rank * num_local_experts + i
-                    recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
-                    recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
+            for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
+                for use_ue8m0 in (False, True) if round_scale else (False, ):
+                    num_times += 1
+                    for i in range((num_times % 2) + 1):
+                        cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
+                        packed_recv_x, packed_recv_count, handle, event, hook = \
+                            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                                        use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
+                                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
+                                                        async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+                        hook() if return_recv_hook else event.current_stream_wait()
+                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
+                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
+                            if dispatch_use_fp8 else packed_recv_x.clone()
+                        all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
+                        dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
+                        for i in range(num_local_experts if do_check else 0):
+                            expert_id = rank * num_local_experts + i
+                            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
+                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
 
-                    # Check expert indices
-                    int_mask = (2 ** 32) - 1
-                    num_valid_tokens = recv_count.item()
-                    assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
-                    assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
-                    assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
+                            # Check expert indices
+                            int_mask = (2 ** 32) - 1
+                            num_valid_tokens = recv_count.item()
+                            assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
+                            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
+                            assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
 
-                    # Check received data
-                    recv_x = recv_x[:num_valid_tokens]
-                    recv_x_amin = recv_x[:, :-128].amin(dim=-1)
-                    recv_src_info = recv_src_info[:num_valid_tokens]
-                    assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
-                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
-                    for j in range(num_ranks):
-                        begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
-                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
-                        assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
-                    if dispatch_use_fp8:
-                        hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
-                        hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
-                    else:
-                        hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
+                            # Check received data
+                            recv_x = recv_x[:num_valid_tokens]
+                            recv_x_amin = recv_x[:, :-128].amin(dim=-1)
+                            recv_src_info = recv_src_info[:num_valid_tokens]
+                            assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
+                            if round_scale:
+                                assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
+                            else:
+                                assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
+                            for j in range(num_ranks):
+                                begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
+                                if not round_scale:
+                                    assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
+                                    assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
+                            if dispatch_use_fp8:
+                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
+                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
+                            else:
+                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
-            # Check combine correctness
-            for zero_copy in (False, True):
-                if zero_copy:
-                    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
-                out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
-                combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                                     async_finish=not return_recv_hook, zero_copy=zero_copy,
-                                                                     return_recv_hook=return_recv_hook, out=out)
-                hook() if return_recv_hook else event.current_stream_wait()
-                if do_check:
-                    diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
-                    assert torch.isnan(combined_x).sum().item() == 0
-                    assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
-                    hash_value ^= hash_tensor(combined_x)
+                    # Check combine correctness
+                    for zero_copy in (False, True):
+                        if zero_copy:
+                            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
+                        out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+                        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+                                                                             async_finish=not return_recv_hook, zero_copy=zero_copy,
+                                                                             return_recv_hook=return_recv_hook, out=out)
+                        hook() if return_recv_hook else event.current_stream_wait()
+                        if do_check:
+                            diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
+                            assert torch.isnan(combined_x).sum().item() == 0
+                            assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
+                            hash_value ^= hash_tensor(combined_x)
 
     def create_test_cast_with_outliers(num_outliers):
         tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
@@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                         cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                        async_finish=False, return_recv_hook=return_recv_hook)
+                                        use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         if zero_copy:
             buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
@@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
         for i in range(20):
             assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
 
+    # Destroy the communication group
+    dist.barrier()
+    dist.destroy_process_group()
+
 
 if __name__ == '__main__':
     # TODO: you may modify NUMA binding for less CPU overhead
@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
 
 
 def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
+    if x_scales.dtype == torch.int:
+        x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
+        x_scales = x_scales.view(dtype=torch.float)
     x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
     x_scales = x_scales.view(x_fp8.size(0), -1, 1)
     return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
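With this change the helper also accepts the packed `torch.int` scales returned by `low_latency_dispatch(..., use_ue8m0=True)`: it reinterprets each int as four exponent bytes and shifts them back into the FP32 exponent field. A round-trip sketch (ours; it keeps exponents below 128 so the `int8` view stays positive):

```python
import torch

scales = torch.tensor([2.0 ** -4, 2.0 ** -2, 0.5, 1.0])               # power-of-two scales
ue8m0 = ((scales.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)      # exponent bytes
packed = ue8m0.view(torch.int32)                                       # four bytes -> one int

# Same decode path as per_token_cast_back's new branch:
unpacked = (packed.view(torch.int8).to(torch.int) << 23).view(torch.float)
assert torch.equal(unpacked, scales)
```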