Support UE8M0 data format. (#206)

Author: Shifang Xu (committed by GitHub)
Date: 2025-06-12 09:38:19 +08:00
commit 21efbe9b48, parent 9ec061204e
14 changed files with 255 additions and 115 deletions

@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
          int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+         int scale_token_stride, int scale_hidden_stride,
          void** buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             // Copy `x_scales`
             #pragma unroll
-            for (int i = lane_id; i < num_scales; i += 32)
-                channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
+            for (int i = lane_id; i < num_scales; i += 32) {
+                auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
+                channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
+            }
         }
         // Move token index
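
The change above generalizes how `x_scales` is indexed. The removed lines hard-code a contiguous row-major (num_tokens, num_scales) layout via `token_idx * num_scales + i`; the new stride pair lets the same kernel also read a transposed, hidden-major scale tensor, such as the layout a UE8M0-packed scale tensor may arrive in. A minimal standalone sketch of the indexing contract (the `scale_offset` helper and the example values are illustrative, not from this commit):

#include <cassert>
#include <cstdint>

// Element offset of scale (token_idx, i) under an arbitrary strided layout.
// Row-major scales use the stride pair (num_scales, 1), which reproduces the
// old `token_idx * num_scales + i`; a transposed tensor carries (1, num_tokens).
inline int64_t scale_offset(int64_t token_idx, int64_t i,
                            int64_t scale_token_stride, int64_t scale_hidden_stride) {
    return token_idx * scale_token_stride + i * scale_hidden_stride;
}

int main() {
    const int64_t num_tokens = 4, num_scales = 3;
    // Row-major (token-major): identical to the removed code path.
    assert(scale_offset(2, 1, num_scales, 1) == 2 * num_scales + 1);
    // Hidden-major (transposed): same element, different memory location.
    assert(scale_offset(2, 1, 1, num_tokens) == 2 + num_tokens);
    return 0;
}
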
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
               int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
               int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int scale_token_stride, int scale_hidden_stride,
               void** buffer_ptrs, int rank, int num_ranks,
               cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
     constexpr int kNumThreads = 768;
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
     constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
 #endif
 
+    // Make sure never OOB
+    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
+
 #define DISPATCH_LAUNCH_CASE(ranks) { \
     auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
     SET_SHARED_MEMORY_FOR_TMA(kernel); \
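
The new host-side assert bounds the largest hidden-dimension contribution to the kernel's offset arithmetic, `num_scales * scale_hidden_stride`, and widens to 64 bits before multiplying so the check itself cannot wrap. The same guard pattern in isolation (a standalone sketch; `check_scale_bounds` is a hypothetical name, and `EP_HOST_ASSERT` is assumed to behave like a fatal host-side assert):

#include <cassert>
#include <cstdint>
#include <limits>

void check_scale_bounds(int num_scales, int scale_hidden_stride) {
    // Widen to int64_t *before* multiplying: a 32-bit signed product such as
    // 65536 * 65536 overflows (undefined behavior) and could silently pass
    // a naive 32-bit comparison.
    assert(static_cast<int64_t>(num_scales) * scale_hidden_stride
           < std::numeric_limits<int>::max());
}

int main() {
    check_scale_bounds(512, 4096);        // 2'097'152 fits in int: passes
    // check_scale_bounds(65536, 65536);  // 2^32 exceeds INT_MAX: would fire
    return 0;
}
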
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
         send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
         is_token_in_rank, channel_prefix_matrix, \
         num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+        scale_token_stride, scale_hidden_stride, \
         buffer_ptrs, rank, \
         num_max_send_tokens, num_recv_buffer_tokens); \
 } break
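
At the call site, the two strides would typically be taken straight from the scale tensor's element strides, so contiguous and transposed (UE8M0-style) layouts flow through the same kernel without an extra copy. A hypothetical caller-side sketch (`Tensor2D` and the wrapper are illustrative; in practice the values would presumably be derived from the framework tensor's strides):

#include <cstdio>

// Hypothetical descriptor for a 2-D scale tensor: shape (num_tokens,
// num_scales) plus element strides (not byte strides).
struct Tensor2D {
    int rows, cols;
    int row_stride, col_stride;
};

void launch_dispatch_with_scales(const Tensor2D& x_scales) {
    // Contiguous row-major scales carry (row_stride, col_stride) == (cols, 1),
    // reproducing the pre-#206 fixed layout; a transposed tensor carries
    // (1, rows) and is now accepted as-is.
    const int scale_token_stride  = x_scales.row_stride;
    const int scale_hidden_stride = x_scales.col_stride;
    std::printf("num_scales=%d token_stride=%d hidden_stride=%d\n",
                x_scales.cols, scale_token_stride, scale_hidden_stride);
    // dispatch(..., x_scales.cols /* num_scales */,
    //          scale_token_stride, scale_hidden_stride, ...);
}

int main() {
    launch_dispatch_with_scales({4, 3, 3, 1});  // contiguous (4 tokens, 3 scales)
    launch_dispatch_with_scales({4, 3, 1, 4});  // transposed / hidden-major
    return 0;
}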