mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Use TMA instead of LD/ST for intra-node normal kernels (#191)
* Update CMake files * Use TMA instead of LD/ST for intranode dispatch * Use TMA instead of LD/ST for intranode combine * Adjust configs * Test default configs as well * More warps for combine * Add inter-thread fence * Enable more warps * Do not use TMA for senders * Update configs * Remove useless wait
This commit is contained in:
parent
df4debe30c
commit
c8dceba110
@ -9,7 +9,10 @@ set(CUDA_SEPARABLE_COMPILATION ON)
|
||||
list(APPEND CUDA_NVCC_FLAGS "-O3")
|
||||
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
|
||||
|
||||
set(TORCH_CUDA_ARCH_LIST "9.0")
|
||||
set(USE_SYSTEM_NVTX on)
|
||||
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
|
||||
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
find_package(pybind11 REQUIRED)
|
||||
find_package(Torch REQUIRED)
|
||||
@ -19,9 +22,8 @@ add_library(nvshmem ALIAS nvshmem::nvshmem)
|
||||
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
|
||||
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
|
||||
|
||||
# Seems bugs with CMake, NVCC 12 and C++ 17
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD 14)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
|
||||
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
|
||||
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
|
||||
|
@ -19,8 +19,6 @@
|
||||
#ifdef __CLION_IDE__
|
||||
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
|
||||
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
// Remove Torch restrictions
|
||||
|
@ -174,7 +174,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template <int kNumRanks, int kNumThreads>
|
||||
template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
@ -183,7 +183,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
void **buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
EP_DEVICE_ASSERT(num_sms % 2 == 0);
|
||||
|
||||
@ -232,19 +232,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
|
||||
|
||||
// TMA stuffs
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto half_hidden_int4 = hidden_int4 / 2;
|
||||
auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
|
||||
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
|
||||
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
|
||||
uint32_t tma_phase = 0;
|
||||
if (lane_id == 0) {
|
||||
mbarrier_init(tma_mbarrier, 1);
|
||||
fence_view_async_shared();
|
||||
fence_barrier_init();
|
||||
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
constexpr int num_send_warps = kNumThreads / 32;
|
||||
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
|
||||
|
||||
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
|
||||
// NOTES: this is for distinguishing zero tokens
|
||||
if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
|
||||
if (lane_id == 0 and send_warp_id_in_rank == 0) {
|
||||
int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
|
||||
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
|
||||
value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
|
||||
@ -262,7 +276,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
|
||||
// NOTES: the head index received by different warps may not be the same
|
||||
auto start_time = clock64();
|
||||
while (send_lane_id == 0) {
|
||||
while (lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
|
||||
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
|
||||
@ -279,7 +293,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int chunk_token_idx = 0;
|
||||
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
|
||||
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
|
||||
if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
|
||||
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
|
||||
send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
|
||||
|
||||
// Skip if not selected
|
||||
@ -294,30 +308,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
// Copy data
|
||||
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x + token_idx * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
|
||||
__ldg, st_na_global);
|
||||
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);
|
||||
|
||||
// Copy source index
|
||||
if (send_lane_id == 0)
|
||||
if (lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
|
||||
|
||||
// Copy `topk_idx` and `topk_weights` with transformed index
|
||||
if (send_lane_id < num_topk) {
|
||||
if (lane_id < num_topk) {
|
||||
// Top-k index
|
||||
int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
|
||||
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
|
||||
auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
|
||||
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
|
||||
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
|
||||
channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
|
||||
|
||||
// Top-k weights
|
||||
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
|
||||
auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
|
||||
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;
|
||||
}
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll
|
||||
for (int i = send_lane_id; i < num_scales; i += 32)
|
||||
for (int i = lane_id; i < num_scales; i += 32)
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
|
||||
}
|
||||
|
||||
@ -328,7 +341,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
// Move tail index
|
||||
// NOTES: here all warps should share the same new tail
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
|
||||
if (send_warp_id_in_rank == 0 and send_lane_id == 0)
|
||||
if (send_warp_id_in_rank == 0 and lane_id == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
@ -336,7 +349,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
constexpr int num_recv_warps = kNumThreads / 32;
|
||||
constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
|
||||
const auto recv_thread_id = thread_id;
|
||||
const auto recv_lane_id = recv_thread_id % 32;
|
||||
const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
|
||||
const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
@ -348,9 +360,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
|
||||
// Receive channel offset
|
||||
int total_offset, num_tokens_to_recv;
|
||||
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
|
||||
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
|
||||
if (recv_lane_id == 0) {
|
||||
while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
|
||||
while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
|
||||
if (lane_id == 0) {
|
||||
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
|
||||
if (recv_warp_id_in_rank == 0)
|
||||
recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
|
||||
@ -393,8 +405,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
|
||||
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
|
||||
ld_nc_global, st_na_global);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
|
||||
tma_store_wait();
|
||||
tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
|
||||
mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
|
||||
mbarrier_wait(tma_mbarrier, tma_phase);
|
||||
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Copy `src_idx`
|
||||
@ -426,12 +445,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
cached_channel_head_idx += num_recv_tokens;
|
||||
total_offset += num_recv_tokens;
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
|
||||
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
|
||||
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
|
||||
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
|
||||
|
||||
// Exit
|
||||
num_tokens_to_recv -= num_recv_tokens;
|
||||
}
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
}
|
||||
}
|
||||
|
||||
@ -441,17 +464,22 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 512;
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 8192;
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) { \
|
||||
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
is_token_in_rank, channel_prefix_matrix, \
|
||||
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
break
|
||||
} break
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
|
||||
EP_HOST_ASSERT(num_sms % 2 == 0);
|
||||
@ -532,7 +560,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
|
||||
#undef CACHED_NOTIFY_COMBINE
|
||||
}
|
||||
|
||||
template<typename dtype_t, int kNumRanks, int kNumThreads>
|
||||
template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const dtype_t* x, const float* topk_weights,
|
||||
@ -542,7 +570,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
|
||||
const auto num_channels = num_sms / 2;
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
const int responsible_channel = sm_id / 2;
|
||||
@ -553,16 +581,21 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
auto x_int4 = reinterpret_cast<const int4*>(x);
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
// TMA stuffs
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
// Several warps are responsible for a single rank
|
||||
constexpr int num_send_warps = kNumThreads / 32;
|
||||
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
|
||||
constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;
|
||||
constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
|
||||
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
const auto send_warp_id = send_thread_id / 32;
|
||||
const auto send_rank_id = thread_id / num_threads_per_rank;
|
||||
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
|
||||
const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
|
||||
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
|
||||
@ -595,7 +628,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
|
||||
auto start_time = clock64();
|
||||
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
|
||||
while (send_lane_id == 0) {
|
||||
while (lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
|
||||
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
|
||||
@ -618,22 +651,22 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
// Copy data
|
||||
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
|
||||
UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
|
||||
|
||||
// Send source index
|
||||
if (send_lane_id == 0)
|
||||
if (lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
|
||||
|
||||
// Send `topk_weights`
|
||||
if (num_topk > 0 and send_lane_id < num_topk)
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
|
||||
if (num_topk > 0 and lane_id < num_topk)
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
|
||||
}
|
||||
token_idx += num_round_tokens;
|
||||
current_channel_tail_idx += num_round_tokens;
|
||||
|
||||
// Move tail index
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
|
||||
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
|
||||
if (lane_id == 0 and send_warp_id_in_rank == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
@ -641,7 +674,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
// One warp for moving the queue head, others for reduction
|
||||
constexpr int num_recv_warps = kNumThreads / 32;
|
||||
const auto recv_warp_id = thread_id / 32;
|
||||
const auto recv_lane_id = thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
|
||||
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
|
||||
|
||||
@ -651,19 +683,19 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
__shared__ volatile bool warp_retired[num_recv_warps];
|
||||
if (thread_id < num_recv_warps)
|
||||
warp_retired[thread_id] = false;
|
||||
if (recv_lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
|
||||
if (lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][lane_id] = 0;
|
||||
if (thread_id < kNumRanks)
|
||||
channel_tail_idx[thread_id] = 0;
|
||||
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
|
||||
|
||||
if (thread_id < 32) {
|
||||
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
|
||||
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
|
||||
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
|
||||
|
||||
// Queue head updater
|
||||
int last_head = 0;
|
||||
while (recv_lane_id < kNumRanks) {
|
||||
while (lane_id < kNumRanks) {
|
||||
// Check retired
|
||||
bool retired = true;
|
||||
#pragma unroll
|
||||
@ -673,13 +705,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
break;
|
||||
|
||||
// Update queue tail
|
||||
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
|
||||
channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
|
||||
|
||||
// Update minimum head
|
||||
int min_head = std::numeric_limits<int>::max();
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
|
||||
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
|
||||
min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
|
||||
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
|
||||
}
|
||||
@ -716,11 +748,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
|
||||
// Read expected head
|
||||
int expected_head = -1;
|
||||
if (recv_lane_id < kNumRanks)
|
||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
|
||||
if (lane_id < kNumRanks)
|
||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
|
||||
|
||||
auto start_time = clock64();
|
||||
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
|
||||
while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
|
||||
@ -740,9 +772,16 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce data
|
||||
// Wait shared memory release
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
__syncwarp();
|
||||
|
||||
// Reduce data with pipeline
|
||||
constexpr int kNumStages = 8;
|
||||
EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
|
||||
#pragma unroll
|
||||
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
|
||||
for (int i = lane_id; i < hidden_int4; i += 32) {
|
||||
// Read buffers
|
||||
int4 recv_value_int4[kNumRanks];
|
||||
#pragma unroll
|
||||
@ -759,33 +798,55 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
values[k] += static_cast<float>(recv_value_dtypes[k]);
|
||||
}
|
||||
|
||||
// Cast back to `dtype_t` and write
|
||||
// Cast back to `dtype_t`
|
||||
int4 out_int4;
|
||||
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
out_dtypes[j] = static_cast<dtype_t>(values[j]);
|
||||
recv_int4[token_idx * hidden_int4 + i] = out_int4;
|
||||
|
||||
// Wait TMA arrival
|
||||
if (lane_id == 0)
|
||||
tma_store_wait<kNumStages - 1>();
|
||||
__syncwarp();
|
||||
|
||||
// Write into TMA buffer
|
||||
auto tma_stage_idx = (i / 32) % kNumStages;
|
||||
reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
|
||||
|
||||
// Issue TMA
|
||||
tma_store_fence();
|
||||
__syncwarp();
|
||||
if (lane_id == 0) {
|
||||
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
|
||||
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
|
||||
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Reduce `topk_weights`
|
||||
if (recv_lane_id < num_topk) {
|
||||
if (lane_id < num_topk) {
|
||||
float value = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk_ranks; ++ i)
|
||||
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
|
||||
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
|
||||
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);
|
||||
recv_topk_weights[token_idx * num_topk + lane_id] = value;
|
||||
}
|
||||
|
||||
// Update head
|
||||
if (recv_lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
||||
if (lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
||||
}
|
||||
|
||||
// Retired
|
||||
__syncwarp();
|
||||
if (recv_lane_id == 0)
|
||||
if (lane_id == 0)
|
||||
warp_retired[recv_warp_id] = true;
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -799,15 +860,20 @@ void combine(cudaDataType_t type,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 4096;
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
|
||||
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
|
||||
auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
|
||||
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); } \
|
||||
break
|
||||
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
|
||||
|
||||
|
@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
|
||||
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void fence_view_async_shared() {
|
||||
asm volatile("fence.proxy.async.shared::cta; \n" :: );
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void fence_barrier_init() {
|
||||
asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("{\n\t"
|
||||
".reg .pred P1; \n\t"
|
||||
"LAB_WAIT: \n\t"
|
||||
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
|
||||
"@P1 bra DONE; \n\t"
|
||||
"bra LAB_WAIT; \n\t"
|
||||
"DONE: \n\t"
|
||||
"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
|
||||
phase ^= 1;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tma_store_fence() {
|
||||
asm volatile ("fence.proxy.async.shared::cta;");
|
||||
}
|
||||
|
||||
constexpr uint64_t kEvictFirst = 0x12f0000000000000;
|
||||
constexpr uint64_t kEvictNormal = 0x1000000000000000;
|
||||
|
||||
__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
|
||||
bool evict_first = true) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
|
||||
asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
|
||||
:: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
|
||||
bool evict_first = true) {
|
||||
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
|
||||
asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
|
||||
:: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
|
||||
asm volatile("cp.async.bulk.commit_group;");
|
||||
}
|
||||
|
||||
template <int N = 0>
|
||||
__device__ __forceinline__ void tma_store_wait() {
|
||||
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
|
@ -171,8 +171,8 @@ class Buffer:
|
||||
"""
|
||||
|
||||
config_map = {
|
||||
2: Config(Buffer.num_sms, 16, 256, 6, 128),
|
||||
4: Config(Buffer.num_sms, 16, 256, 6, 128),
|
||||
2: Config(Buffer.num_sms, 24, 256, 6, 128),
|
||||
4: Config(Buffer.num_sms, 6, 256, 6, 128),
|
||||
8: Config(Buffer.num_sms, 6, 256, 6, 128),
|
||||
16: Config(Buffer.num_sms, 16, 288, 20, 128),
|
||||
24: Config(Buffer.num_sms, 8, 288, 32, 128),
|
||||
@ -198,9 +198,9 @@ class Buffer:
|
||||
"""
|
||||
|
||||
config_map = {
|
||||
2: Config(Buffer.num_sms, 6, 256, 6, 128),
|
||||
4: Config(Buffer.num_sms, 6, 256, 6, 128),
|
||||
8: Config(Buffer.num_sms, 6, 256, 6, 128),
|
||||
2: Config(Buffer.num_sms, 10, 256, 6, 128),
|
||||
4: Config(Buffer.num_sms, 9, 256, 6, 128),
|
||||
8: Config(Buffer.num_sms, 4, 256, 6, 128),
|
||||
16: Config(Buffer.num_sms, 2, 288, 28, 128),
|
||||
24: Config(Buffer.num_sms, 1, 288, 20, 128),
|
||||
32: Config(Buffer.num_sms, 1, 288, 20, 128),
|
||||
|
@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
|
||||
if nvl_chunk_size > 0:
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
else:
|
||||
# Test default config as well
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
config = deep_ep.Buffer.get_dispatch_config(num_ranks)
|
||||
tune_args = {'x': current_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
if t < best_time and nvl_chunk_size > 0:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
|
||||
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
|
||||
print('', flush=True)
|
||||
@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 7, 1):
|
||||
for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
|
||||
if nvl_chunk_size > 0:
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
else:
|
||||
# Test default config as well
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
config = deep_ep.Buffer.get_combine_config(num_ranks)
|
||||
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if t < best_time:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
|
||||
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
|
||||
if t < best_time and nvl_chunk_size > 0:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
|
||||
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
torch.manual_seed(rank)
|
||||
|
||||
@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
|
||||
|
||||
# Destroy the communication group
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_processes = 8
|
||||
|
Loading…
Reference in New Issue
Block a user