Support Ampere architecture (#204)

* Update README

* Update `setup.py`

* Fix headers

* Add `DISABLE_NVSHMEM` for APIs

* Fix launch

* Fix TMA settings

* Fix TMA usages

* Fix dlink

* Separate layout kernels

* Update version

* Add `is_sm90_compiled` (a sketch of this helper follows the commit metadata below)

* Fix tests

* Add NVLink connection checks

* Update README

* Fix tests

* Add some comments

* Minor fix

* Minor fix

* Fix bugs
Author: Chenggang Zhao
Date: 2025-06-11 15:48:18 +08:00 (committed via GitHub)
Parent: dd13c7145c
Commit: b8d90fb753
16 changed files with 413 additions and 174 deletions
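
The `is_sm90_compiled` addition mentioned in the message above gives callers a way to check, at runtime, whether the extension was built with Hopper features. A minimal sketch of such a helper, assuming it simply reflects the new `DISABLE_SM90_FEATURES` compile flag (the name comes from this commit; the body is illustrative, not the repo's actual definition):

// Hedged sketch: report whether SM90 (Hopper) features were compiled in.
bool is_sm90_compiled() {
#ifndef DISABLE_SM90_FEATURES
    return true;
#else
    return false;
#endif
}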


@@ -227,6 +227,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
 // TMA stuffs
+#ifndef DISABLE_SM90_FEATURES
 extern __shared__ __align__(1024) uint8_t smem_buffer[];
 auto half_hidden_int4 = hidden_int4 / 2;
 auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
@@ -240,6 +241,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
 }
 __syncwarp();
+#endif
 if (is_sender) {
 // Workers for sending
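
The two hunks above carve the dynamic shared memory into per-warp TMA staging slices. A compilable sketch of that layout, assembled only from names and constants appearing in this diff (everything else is elided):

// Per-warp TMA staging layout of the dispatch kernel (illustrative sketch).
#include <cstdint>

constexpr int kNumThreads = 768;           // from the launcher below
constexpr int kNumTMABytesPerWarp = 8192;  // 8 KiB staging slice per warp

__global__ void dispatch_layout_sketch() {
    // Dynamic shared memory, 1024-byte aligned as TMA requires.
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    const auto thread_id = static_cast<int>(threadIdx.x);
    // Each of the 768 / 32 = 24 warps owns a private 8 KiB slice; half of a
    // token's hidden vector plus one uint64_t must fit in it, which is what
    // the EP_DEVICE_ASSERT above enforces.
    auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
    (void) tma_buffer;
}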
@@ -399,6 +401,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
 auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
 auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
+#ifndef DISABLE_SM90_FEATURES
 #pragma unroll
 for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
 tma_store_wait();
@@ -408,6 +411,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
 }
 __syncwarp();
+#else
+UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
+ld_nc_global, st_na_global);
+#endif
 }
 // Copy `src_idx`
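
On Ampere builds the TMA store above is replaced by `UNROLLED_WARP_COPY`, a warp-cooperative copy macro from the repo's kernel utilities whose definition is not part of this diff. A minimal sketch of the same idea, with `__ldg` standing in for the repo's `ld_nc_global`/`st_na_global` wrappers and without the 5-way unrolling of the real macro:

// Illustrative fallback: 32 lanes stride through `n` int4 elements.
__device__ __forceinline__ void warp_copy_int4(int lane_id, int n,
                                               int4* dst, const int4* src) {
    for (int i = lane_id; i < n; i += 32)
        dst[i] = __ldg(src + i);  // non-coherent load, plain store
}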
@@ -447,8 +454,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 }
 // Make TMA store visible to the next kernel
+#ifndef DISABLE_SM90_FEATURES
 if (lane_id == 0)
 tma_store_wait();
+#endif
 }
@@ -473,12 +482,13 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
 cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
 constexpr int kNumThreads = 768;
 constexpr int kNumTMABytesPerWarp = 8192;
+#ifndef DISABLE_SM90_FEATURES
 constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
+#endif
 #define DISPATCH_LAUNCH_CASE(ranks) { \
 auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
-EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
-cfg.dynamicSmemBytes = smem_size; \
+SET_SHARED_MEMORY_FOR_TMA(kernel); \
 LAUNCH_KERNEL(&cfg, kernel, \
 reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
 send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
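
The launcher hunk swaps the explicit `cudaFuncSetAttribute` call and `cfg.dynamicSmemBytes` assignment for a `SET_SHARED_MEMORY_FOR_TMA(kernel)` macro, letting Ampere builds compile the launch path without any TMA shared memory. The attribute matters on the SM90 side: 24 warps times 8 KiB is 192 KiB of dynamic shared memory, far above the 48 KiB default, so the limit must be raised per kernel. A sketch of what the macro could expand to, inferred from the two removed lines (the actual definition lives in the repo's launch utilities and may differ):

#ifndef DISABLE_SM90_FEATURES
#define SET_SHARED_MEMORY_FOR_TMA(kernel) \
    EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
    cfg.dynamicSmemBytes = smem_size;
#else
// Ampere builds stage nothing through shared memory here.
#define SET_SHARED_MEMORY_FOR_TMA(kernel) cfg.dynamicSmemBytes = 0;
#endif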
@@ -587,8 +597,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 auto recv_int4 = reinterpret_cast<int4*>(recv_x);
 // TMA stuffs
+#ifndef DISABLE_SM90_FEATURES
 extern __shared__ __align__(1024) uint8_t smem_buffer[];
 auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+#endif
 if (is_sender) {
 // Workers for sending
@@ -778,9 +790,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 }
 // Wait shared memory release
+#ifndef DISABLE_SM90_FEATURES
 if (lane_id == 0)
 tma_store_wait();
 __syncwarp();
+#endif
 // Reduce data with pipeline
 constexpr int kNumStages = 8;
@@ -810,6 +824,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 for (int j = 0; j < kDtypePerInt4; ++ j)
 out_dtypes[j] = static_cast<dtype_t>(values[j]);
+#ifndef DISABLE_SM90_FEATURES
 // Wait TMA arrival
 if (lane_id == 0)
 tma_store_wait<kNumStages - 1>();
@@ -828,6 +843,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
 }
 __syncwarp();
+#else
+recv_int4[token_idx * hidden_int4 + i] = out_int4;
+#endif
 }
 // Reduce `topk_weights`
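
The combine reduction pipelines its stores through `kNumStages = 8` rotating per-warp slots: `tma_store_wait<kNumStages - 1>()` blocks until at most seven bulk-store groups remain in flight, guaranteeing that the slot about to be refilled has drained; on Ampere the `#else` branch above writes the reduced `int4` straight to global memory instead. A sketch of the wait wrapper as it is commonly implemented over PTX (the repo's real definition may differ):

// Block until at most N bulk async-store groups are still outstanding;
// tma_store_wait<0>() therefore drains every pending TMA store.
template <int N = 0>
__device__ __forceinline__ void tma_store_wait() {
    asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}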
@@ -850,8 +868,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 warp_retired[recv_warp_id] = true;
 // Make TMA store visible to the next kernel
+#ifndef DISABLE_SM90_FEATURES
 if (lane_id == 0)
 tma_store_wait();
+#endif
 }
 }
 }
@@ -866,12 +886,13 @@ void combine(cudaDataType_t type,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 constexpr int kNumThreads = 768;
 constexpr int kNumTMABytesPerWarp = 4096;
+#ifndef DISABLE_SM90_FEATURES
 constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
+#endif
 #define COMBINE_LAUNCH_CASE(dtype, ranks) { \
 auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
-EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
-cfg.dynamicSmemBytes = smem_size; \
+SET_SHARED_MEMORY_FOR_TMA(kernel); \
 LAUNCH_KERNEL(&cfg, kernel, \
 reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
 reinterpret_cast<const dtype*>(x), topk_weights, \