mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support Ampere architecture (#204)
* Update README * Update `setup.py` * Fix headers * Add `DISABLE_NVSHMEM` for APIs * Fix launch * Fix TMA settings * Fix TMA usages * Fix dlink * Separate layout kernels * Update version * Add `is_sm90_compiled` * Fix tests * Add NVLink connection checks * Update README * Fix tests * Add some comments * Minor fix * Minor fix * Fix bugs
This commit is contained in:
@@ -227,6 +227,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
|
||||
|
||||
// TMA stuffs
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto half_hidden_int4 = hidden_int4 / 2;
|
||||
auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
|
||||
@@ -240,6 +241,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
|
||||
}
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
@@ -399,6 +401,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
|
||||
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
|
||||
tma_store_wait();
|
||||
@@ -408,6 +411,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
#else
|
||||
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
|
||||
ld_nc_global, st_na_global);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Copy `src_idx`
|
||||
@@ -447,8 +454,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
}
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -473,12 +482,13 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 8192;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) { \
|
||||
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
@@ -587,8 +597,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
// TMA stuffs
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
|
||||
#endif
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
@@ -778,9 +790,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
}
|
||||
|
||||
// Wait shared memory release
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduce data with pipeline
|
||||
constexpr int kNumStages = 8;
|
||||
@@ -810,6 +824,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
out_dtypes[j] = static_cast<dtype_t>(values[j]);
|
||||
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
// Wait TMA arrival
|
||||
if (lane_id == 0)
|
||||
tma_store_wait<kNumStages - 1>();
|
||||
@@ -828,6 +843,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
#else
|
||||
recv_int4[token_idx * hidden_int4 + i] = out_int4;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Reduce `topk_weights`
|
||||
@@ -850,8 +868,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
warp_retired[recv_warp_id] = true;
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -866,12 +886,13 @@ void combine(cudaDataType_t type,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 4096;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
|
||||
auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
|
||||
Reference in New Issue
Block a user