mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void** buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_scales; i += 32)
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
|
||||
for (int i = lane_id; i < num_scales; i += 32) {
|
||||
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Move token index
|
||||
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
// Make sure never OOB
|
||||
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) { \
|
||||
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
is_token_in_rank, channel_prefix_matrix, \
|
||||
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
|
||||
scale_token_stride, scale_hidden_stride, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
} break
|
||||
|
||||
Reference in New Issue
Block a user