mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks) {
|
||||
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Copy `x_scales` into symmetric send buffer
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_scales; i += 32) {
|
||||
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
|
||||
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
|
||||
auto value = ld_nc_global(x_scales + offset);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
|
||||
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
int scale_token_stride, int scale_hidden_stride,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, bool is_cached_dispatch,
|
||||
cudaStream_t stream, int num_channels, bool low_latency_mode) {
|
||||
constexpr int kNumDispatchRDMASenderWarps = 7;
|
||||
|
||||
// Make sure never OOB
|
||||
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
|
||||
auto dispatch_func = low_latency_mode ? \
|
||||
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
|
||||
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
|
||||
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
|
||||
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
|
||||
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
|
||||
is_token_in_rank, \
|
||||
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
|
||||
scale_token_stride, scale_hidden_stride, \
|
||||
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
|
||||
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
|
||||
rank, num_ranks); } break
|
||||
|
||||
Reference in New Issue
Block a user