Use TMA instead of LD/ST for intra-node normal kernels (#191)

* Update CMake files

* Use TMA instead of LD/ST for intranode dispatch

* Use TMA instead of LD/ST for intranode combine

* Adjust configs

* Test default configs as well

* More warps for combine

* Add inter-thread fence

* Enable more warps

* Do not use TMA for senders

* Update configs

* Remove useless wait

Author: Chenggang Zhao, 2025-06-06 15:40:17 +08:00 (committed by GitHub)
Commit: c8dceba110, parent: df4debe30c
6 changed files with 230 additions and 87 deletions
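
The core of the change: on the receiving side of the intranode `dispatch` (and for the reducer warps in `combine`), token data is now staged through per-warp dynamic shared memory and moved with the Hopper Tensor Memory Accelerator (TMA, `cp.async.bulk`) instead of an unrolled per-lane LD/ST copy; the sender paths keep LD/ST. The sketch below is a condensed, commented version of the receiver-side copy, assuming the `tma_*`/`mbarrier_*` helpers added to the CUDA utilities file later in this diff are in scope; it is illustrative rather than the exact kernel code.

    // Condensed sketch (not the literal kernel): one warp forwards a token's hidden
    // vector from the channel buffer to `recv_x` through its TMA staging buffer.
    __device__ __forceinline__ void recv_token_via_tma(const int4* src_gmem, int4* dst_gmem,
                                                       uint8_t* tma_buffer, uint64_t* tma_mbarrier,
                                                       uint32_t& tma_phase, int half_hidden_int4,
                                                       int half_hidden_bytes, int lane_id) {
        // The hidden dimension is moved in two halves so that one half plus the
        // mbarrier fits into the per-warp shared-memory slice (kNumTMABytesPerWarp).
        #pragma unroll
        for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
            tma_store_wait();                        // previous bulk store must be done reading the buffer
            tma_load_1d(tma_buffer, src_gmem + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
            mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
            mbarrier_wait(tma_mbarrier, tma_phase);  // block until the bulk load has landed in shared memory
            tma_store_1d(tma_buffer, dst_gmem + i * half_hidden_int4, half_hidden_bytes, false);
        }
        __syncwarp();
    }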

[File 1 of 6: CMake build configuration]

@@ -9,7 +9,10 @@ set(CUDA_SEPARABLE_COMPILATION ON)
 list(APPEND CUDA_NVCC_FLAGS "-O3")
 list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
-set(TORCH_CUDA_ARCH_LIST "9.0")
+set(USE_SYSTEM_NVTX on)
+set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
+set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
 find_package(CUDAToolkit REQUIRED)
 find_package(pybind11 REQUIRED)
 find_package(Torch REQUIRED)

@@ -19,9 +19,8 @@ add_library(nvshmem ALIAS nvshmem::nvshmem)
 add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
 add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
-# Seems bugs with CMake, NVCC 12 and C++ 17
 set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD 17)
 include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
 link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})

[File 2 of 6: CUDA header with CLion IDE shims]

@@ -19,8 +19,6 @@
 #ifdef __CLION_IDE__
 #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
 #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
-__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
-#define printf host_device_printf
 #endif
 // Remove Torch restrictions

[File 3 of 6: intranode dispatch/combine CUDA kernels]

@@ -174,7 +174,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
 }
-template <int kNumRanks, int kNumThreads>
+template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
 int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,

@@ -183,7 +183,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 void **buffer_ptrs, int rank,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
-const auto thread_id = static_cast<int>(threadIdx.x);
+const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
 const bool is_sender = sm_id % 2 == 0;
 EP_DEVICE_ASSERT(num_sms % 2 == 0);

@@ -232,19 +232,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
 auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
+// TMA stuffs
+extern __shared__ __align__(1024) uint8_t smem_buffer[];
+auto half_hidden_int4 = hidden_int4 / 2;
+auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
+auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
+uint32_t tma_phase = 0;
+if (lane_id == 0) {
+mbarrier_init(tma_mbarrier, 1);
+fence_view_async_shared();
+fence_barrier_init();
+EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
+}
+__syncwarp();
 if (is_sender) {
 // Workers for sending
 constexpr int num_send_warps = kNumThreads / 32;
 constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
 const auto send_thread_id = thread_id;
-const auto send_lane_id = send_thread_id % 32;
 const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32);
 EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
 // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
 // NOTES: this is for distinguishing zero tokens
-if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
+if (lane_id == 0 and send_warp_id_in_rank == 0) {
 int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
 st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
 value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
@@ -262,7 +276,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Check destination queue emptiness, or wait a buffer to be released (rare cases)
 // NOTES: the head index received by different warps may not be the same
 auto start_time = clock64();
-while (send_lane_id == 0) {
+while (lane_id == 0) {
 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
 int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
 if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)

@@ -279,7 +293,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 int chunk_token_idx = 0;
 while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
 // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
-if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
+if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
 send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
 // Skip if not selected

@@ -294,30 +308,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Copy data
 auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
 auto shifted_x = x + token_idx * hidden_int4;
-UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
-__ldg, st_na_global);
+UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);
 // Copy source index
-if (send_lane_id == 0)
+if (lane_id == 0)
 channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
 // Copy `topk_idx` and `topk_weights` with transformed index
-if (send_lane_id < num_topk) {
+if (lane_id < num_topk) {
 // Top-k index
 int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
-auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
+auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
 idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
-channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
+channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
 // Top-k weights
-auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
+auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
 weight_value = (idx_value >= 0) ? weight_value : 0.0f;
-channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
+channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;
 }
 // Copy `x_scales`
 #pragma unroll
-for (int i = send_lane_id; i < num_scales; i += 32)
+for (int i = lane_id; i < num_scales; i += 32)
 channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
 }
@@ -328,7 +341,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Move tail index
 // NOTES: here all warps should share the same new tail
 asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-if (send_warp_id_in_rank == 0 and send_lane_id == 0)
+if (send_warp_id_in_rank == 0 and lane_id == 0)
 st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
 }
 } else {

@@ -336,7 +349,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 constexpr int num_recv_warps = kNumThreads / 32;
 constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
 const auto recv_thread_id = thread_id;
-const auto recv_lane_id = recv_thread_id % 32;
 const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
 const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32);

@@ -348,9 +360,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Receive channel offset
 int total_offset, num_tokens_to_recv;
-while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
-while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
-if (recv_lane_id == 0) {
+while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
+while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
+if (lane_id == 0) {
 total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
 if (recv_warp_id_in_rank == 0)
 recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
@@ -393,8 +405,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
 auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
 auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
-UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
-ld_nc_global, st_na_global);
+#pragma unroll
+for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
+tma_store_wait();
+tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
+mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
+mbarrier_wait(tma_mbarrier, tma_phase);
+tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
+}
+__syncwarp();
 }
 // Copy `src_idx`

@@ -426,12 +445,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 cached_channel_head_idx += num_recv_tokens;
 total_offset += num_recv_tokens;
 asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
+if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
 st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
 // Exit
 num_tokens_to_recv -= num_recv_tokens;
 }
+// Make TMA store visible to the next kernel
+if (lane_id == 0)
+tma_store_wait();
 }
 }
@@ -441,17 +464,22 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
 int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
 void** buffer_ptrs, int rank, int num_ranks,
 cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
-constexpr int kNumThreads = 512;
+constexpr int kNumThreads = 768;
+constexpr int kNumTMABytesPerWarp = 8192;
+constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
-#define DISPATCH_LAUNCH_CASE(ranks) \
-LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
+#define DISPATCH_LAUNCH_CASE(ranks) { \
+auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
+EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+cfg.dynamicSmemBytes = smem_size; \
+LAUNCH_KERNEL(&cfg, kernel, \
 reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
 send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
 is_token_in_rank, channel_prefix_matrix, \
 num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
 buffer_ptrs, rank, \
 num_max_send_tokens, num_recv_buffer_tokens); \
-break
+} break
 // Even-numbered blocks for sending, odd-numbered blocks for receiving.
 EP_HOST_ASSERT(num_sms % 2 == 0);
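
Note that the dispatch launcher now requests 24 warps x 8 KiB = 192 KiB of dynamic shared memory per block, far beyond the 48 KiB kernels get by default, so the launch macro raises the per-kernel cap with `cudaFuncSetAttribute` before launching. A minimal host-side sketch of that opt-in, with a hypothetical kernel name rather than DeepEP's launch helper:

    // Minimal sketch of the dynamic-shared-memory opt-in; `my_kernel` is hypothetical.
    #include <cuda_runtime.h>

    __global__ void my_kernel(int* data);  // defined elsewhere, uses `extern __shared__`

    void launch_with_large_smem(int* data, cudaStream_t stream) {
        constexpr int kNumThreads = 768;
        constexpr int kNumTMABytesPerWarp = 8192;
        constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);  // 192 KiB

        // Kernels may use more than 48 KiB of dynamic shared memory only after opting in.
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        my_kernel<<<1, kNumThreads, smem_size, stream>>>(data);
    }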
@@ -532,7 +560,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
 #undef CACHED_NOTIFY_COMBINE
 }
-template<typename dtype_t, int kNumRanks, int kNumThreads>
+template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 combine(dtype_t* recv_x, float* recv_topk_weights,
 const dtype_t* x, const float* topk_weights,

@@ -542,7 +570,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 const auto num_sms = static_cast<int>(gridDim.x);
 const auto thread_id = static_cast<int>(threadIdx.x);
-const auto sm_id = static_cast<int>(blockIdx.x);
+const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
 const auto num_channels = num_sms / 2;
 const bool is_sender = sm_id % 2 == 0;
 const int responsible_channel = sm_id / 2;

@@ -553,16 +581,21 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 auto x_int4 = reinterpret_cast<const int4*>(x);
 auto recv_int4 = reinterpret_cast<int4*>(recv_x);
+// TMA stuffs
+extern __shared__ __align__(1024) uint8_t smem_buffer[];
+auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
 if (is_sender) {
 // Workers for sending
 // Several warps are responsible for a single rank
-constexpr int num_send_warps = kNumThreads / 32;
-constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
+constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;
+constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
 const auto num_threads_per_rank = num_send_warps_per_rank * 32;
 const auto send_thread_id = thread_id;
-const auto send_lane_id = send_thread_id % 32;
+const auto send_warp_id = send_thread_id / 32;
 const auto send_rank_id = thread_id / num_threads_per_rank;
-const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
+const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
+EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
 // Calculate pointers by the specific layout
 auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
@@ -595,7 +628,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // Check destination queue emptiness, or wait a buffer to be released (rare cases)
 auto start_time = clock64();
 int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
-while (send_lane_id == 0) {
+while (lane_id == 0) {
 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
 int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
 if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)

@@ -618,22 +651,22 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // Copy data
 auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
 auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
-UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
+UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
 // Send source index
-if (send_lane_id == 0)
+if (lane_id == 0)
 channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
 // Send `topk_weights`
-if (num_topk > 0 and send_lane_id < num_topk)
-channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
+if (num_topk > 0 and lane_id < num_topk)
+channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
 }
 token_idx += num_round_tokens;
 current_channel_tail_idx += num_round_tokens;
 // Move tail index
 asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
-if (send_lane_id == 0 and send_warp_id_in_rank == 0)
+if (lane_id == 0 and send_warp_id_in_rank == 0)
 st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
 }
 } else {

@@ -641,7 +674,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // One warp for moving the queue head, others for reduction
 constexpr int num_recv_warps = kNumThreads / 32;
 const auto recv_warp_id = thread_id / 32;
-const auto recv_lane_id = thread_id % 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
 EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
@@ -651,19 +683,19 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 __shared__ volatile bool warp_retired[num_recv_warps];
 if (thread_id < num_recv_warps)
 warp_retired[thread_id] = false;
-if (recv_lane_id < kNumRanks)
-warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
+if (lane_id < kNumRanks)
+warp_channel_head_idx[recv_warp_id][lane_id] = 0;
 if (thread_id < kNumRanks)
 channel_tail_idx[thread_id] = 0;
 asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
 if (thread_id < 32) {
-int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
+int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
 int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
 // Queue head updater
 int last_head = 0;
-while (recv_lane_id < kNumRanks) {
+while (lane_id < kNumRanks) {
 // Check retired
 bool retired = true;
 #pragma unroll

@@ -673,13 +705,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 break;
 // Update queue tail
-channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
+channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
 // Update minimum head
 int min_head = std::numeric_limits<int>::max();
 #pragma unroll
 for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
-min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
+min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
 if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
 st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
 }

@@ -716,11 +748,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
 // Read expected head
 int expected_head = -1;
-if (recv_lane_id < kNumRanks)
-expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
+if (lane_id < kNumRanks)
+expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
 auto start_time = clock64();
-while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
+while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
 // Timeout check
 if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
 printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
@@ -740,9 +772,16 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 }
 }
-// Reduce data
+// Wait shared memory release
+if (lane_id == 0)
+tma_store_wait();
+__syncwarp();
+// Reduce data with pipeline
+constexpr int kNumStages = 8;
+EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
 #pragma unroll
-for (int i = recv_lane_id; i < hidden_int4; i += 32) {
+for (int i = lane_id; i < hidden_int4; i += 32) {
 // Read buffers
 int4 recv_value_int4[kNumRanks];
 #pragma unroll

@@ -759,33 +798,55 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 values[k] += static_cast<float>(recv_value_dtypes[k]);
 }
-// Cast back to `dtype_t` and write
+// Cast back to `dtype_t`
 int4 out_int4;
 auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
 #pragma unroll
 for (int j = 0; j < kDtypePerInt4; ++ j)
 out_dtypes[j] = static_cast<dtype_t>(values[j]);
-recv_int4[token_idx * hidden_int4 + i] = out_int4;
+// Wait TMA arrival
+if (lane_id == 0)
+tma_store_wait<kNumStages - 1>();
+__syncwarp();
+// Write into TMA buffer
+auto tma_stage_idx = (i / 32) % kNumStages;
+reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
+// Issue TMA
+tma_store_fence();
+__syncwarp();
+if (lane_id == 0) {
+auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
+tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
+recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
+}
+__syncwarp();
 }
 // Reduce `topk_weights`
-if (recv_lane_id < num_topk) {
+if (lane_id < num_topk) {
 float value = 0;
 #pragma unroll
 for (int i = 0; i < num_topk_ranks; ++ i)
-value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
-recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
+value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);
+recv_topk_weights[token_idx * num_topk + lane_id] = value;
 }
 // Update head
-if (recv_lane_id < kNumRanks)
-warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
+if (lane_id < kNumRanks)
+warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
 }
 // Retired
 __syncwarp();
-if (recv_lane_id == 0)
+if (lane_id == 0)
 warp_retired[recv_warp_id] = true;
+// Make TMA store visible to the next kernel
+if (lane_id == 0)
+tma_store_wait();
 }
 }
 }
@@ -799,15 +860,20 @@ void combine(cudaDataType_t type,
 cudaStream_t stream, int num_sms,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 constexpr int kNumThreads = 768;
+constexpr int kNumTMABytesPerWarp = 4096;
+constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
-#define COMBINE_LAUNCH_CASE(dtype, ranks) \
-LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
+#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
+auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
+EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+cfg.dynamicSmemBytes = smem_size; \
+LAUNCH_KERNEL(&cfg, kernel, \
 reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
 reinterpret_cast<const dtype*>(x), topk_weights, \
 src_idx, rank_prefix_matrix, channel_prefix_matrix, \
 send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
 buffer_ptrs, rank, \
-num_max_send_tokens, num_recv_buffer_tokens); \
+num_max_send_tokens, num_recv_buffer_tokens); } \
 break
 #define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break

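In `combine`, the reducer warps no longer write each reduced `int4` straight to `recv_x`; they stage results in a per-warp ring of `kNumStages = 8` shared-memory tiles (32 `int4`, i.e. 512 bytes each, which exactly fills the 4 KiB per-warp budget) and flush every tile with a TMA bulk store, so reducing the next tile overlaps the stores of earlier ones. Below is a condensed, commented sketch of that inner loop, again assuming the helpers from the CUDA utilities file below; `reduce_one_int4` stands in for the per-element sum over source ranks and is hypothetical.

    // Condensed sketch of the combine reducer's staged TMA store path (per warp),
    // not the literal kernel code. `dst_gmem` points at the current output token.
    template <int kNumTMABytesPerWarp>
    __device__ __forceinline__ void reduce_and_store_token(int4* dst_gmem, uint8_t* tma_buffer,
                                                           int hidden_int4, int lane_id) {
        constexpr int kNumStages = 8;  // 8 stages x 32 int4 x 16 B = 4096 B per warp
        EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
        for (int i = lane_id; i < hidden_int4; i += 32) {
            int4 out_int4 = reduce_one_int4(i);          // hypothetical reduction + cast back to dtype_t
            // Throttle: keep at most kNumStages - 1 bulk stores in flight per warp
            if (lane_id == 0)
                tma_store_wait<kNumStages - 1>();
            __syncwarp();
            // Stage the reduced 512-byte tile into this warp's shared-memory ring
            auto tma_stage_idx = (i / 32) % kNumStages;
            reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
            // Make the ordinary shared-memory writes visible to the async (TMA) proxy, then issue the store
            tma_store_fence();
            __syncwarp();
            if (lane_id == 0) {
                auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
                tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
                             dst_gmem + i, tma_bytes, false);
            }
            __syncwarp();
        }
        // Drain outstanding stores before the buffer is reused or the kernel exits
        if (lane_id == 0)
            tma_store_wait();
        __syncwarp();
    }
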
[File 4 of 6: CUDA kernel utilities (PTX wrappers)]

@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
 ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
 }
+__device__ __forceinline__ void fence_view_async_shared() {
+asm volatile("fence.proxy.async.shared::cta; \n" :: );
+}
+__device__ __forceinline__ void fence_barrier_init() {
+asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
+}
+__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
+auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
+}
+__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
+auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+asm volatile("{\n\t"
+".reg .pred P1; \n\t"
+"LAB_WAIT: \n\t"
+"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
+"@P1 bra DONE; \n\t"
+"bra LAB_WAIT; \n\t"
+"DONE: \n\t"
+"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
+phase ^= 1;
+}
+__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
+auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
+}
+__device__ __forceinline__ void tma_store_fence() {
+asm volatile ("fence.proxy.async.shared::cta;");
+}
+constexpr uint64_t kEvictFirst = 0x12f0000000000000;
+constexpr uint64_t kEvictNormal = 0x1000000000000000;
+__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
+bool evict_first = true) {
+auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
+asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
+:: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
+}
+__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
+bool evict_first = true) {
+auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
+asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
+:: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
+asm volatile("cp.async.bulk.commit_group;");
+}
+template <int N = 0>
+__device__ __forceinline__ void tma_store_wait() {
+asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
+}
 template <typename dtype_t>
 __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
 return (a + b - 1) / b;

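The wrappers above are the whole TMA surface the kernels rely on: a 1-D bulk global-to-shared load completed through an mbarrier, a 1-D shared-to-global bulk store tracked by `cp.async.bulk` commit groups, and the fences that order ordinary shared-memory writes against the async proxy. A hypothetical, self-contained smoke-test kernel (not part of this diff) that round-trips one block through shared memory with these helpers could look like the sketch below; it assumes sm_90, a single-warp launch, and a byte count that is a multiple of 16 and no larger than the staging buffer.

    // Hypothetical standalone check for the helpers above (not part of this diff):
    // launch with 32 threads; lane 0 copies `num_bytes` from `src` to `dst` via TMA.
    __global__ void tma_copy_smoke_test(const int4* src, int4* dst, int num_bytes) {
        constexpr int kSmemBytes = 4096;  // num_bytes must be <= kSmemBytes and a multiple of 16
        __shared__ __align__(1024) uint8_t smem_buffer[kSmemBytes + sizeof(uint64_t)];
        auto mbar = reinterpret_cast<uint64_t*>(smem_buffer + kSmemBytes);
        uint32_t phase = 0;
        const int lane_id = threadIdx.x % 32;

        // One-time mbarrier setup, mirroring the dispatch kernel's initialization sequence
        if (lane_id == 0) {
            mbarrier_init(mbar, 1);
            fence_view_async_shared();
            fence_barrier_init();
        }
        __syncwarp();

        if (lane_id == 0) {
            tma_load_1d(smem_buffer, src, mbar, num_bytes);          // bulk load gmem -> smem
            mbarrier_arrive_and_expect_tx(mbar, num_bytes);
            mbarrier_wait(mbar, phase);                              // wait for the load to complete
            tma_store_1d(smem_buffer, dst, num_bytes, false);        // bulk store smem -> gmem
            tma_store_wait();                                        // drain before kernel exit
        }
        __syncwarp();
    }

Apart from the double pass over the two hidden halves and the per-warp buffer slicing, this is the same sequence the dispatch receivers execute.
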
[File 5 of 6: Python buffer default configs]

@@ -171,8 +171,8 @@ class Buffer:
 """
 config_map = {
-2: Config(Buffer.num_sms, 16, 256, 6, 128),
-4: Config(Buffer.num_sms, 16, 256, 6, 128),
+2: Config(Buffer.num_sms, 24, 256, 6, 128),
+4: Config(Buffer.num_sms, 6, 256, 6, 128),
 8: Config(Buffer.num_sms, 6, 256, 6, 128),
 16: Config(Buffer.num_sms, 16, 288, 20, 128),
 24: Config(Buffer.num_sms, 8, 288, 32, 128),

@@ -198,9 +198,9 @@ class Buffer:
 """
 config_map = {
-2: Config(Buffer.num_sms, 6, 256, 6, 128),
-4: Config(Buffer.num_sms, 6, 256, 6, 128),
-8: Config(Buffer.num_sms, 6, 256, 6, 128),
+2: Config(Buffer.num_sms, 10, 256, 6, 128),
+4: Config(Buffer.num_sms, 9, 256, 6, 128),
+8: Config(Buffer.num_sms, 4, 256, 6, 128),
 16: Config(Buffer.num_sms, 2, 288, 28, 128),
 24: Config(Buffer.num_sms, 1, 288, 20, 128),
 32: Config(Buffer.num_sms, 1, 288, 20, 128),

[File 6 of 6: intranode test/tuning script (Python)]

@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
 for current_x in (x_e4m3, x):
 best_time, best_results = 1e10, None
 nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
-for nvl_chunk_size in range(4, 33, 4):
-config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
+if nvl_chunk_size > 0:
+config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+else:
+# Test default config as well
+deep_ep.Buffer.set_num_sms(num_sms)
+config = deep_ep.Buffer.get_dispatch_config(num_ranks)
 tune_args = {'x': current_x, 'handle': handle, 'config': config}
 t = bench(lambda: buffer.dispatch(**tune_args))[0]
-if t < best_time:
+if t < best_time and nvl_chunk_size > 0:
 best_time, best_results = t, (num_sms, nvl_chunk_size)
 if local_rank == 0:
-print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
+f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
 if local_rank == 0:
 print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
 print('', flush=True)
@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
 # Tune combine performance
 best_time, best_results = 1e10, None
-for nvl_chunk_size in range(1, 7, 1):
-config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
+if nvl_chunk_size > 0:
+config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+else:
+# Test default config as well
+deep_ep.Buffer.set_num_sms(num_sms)
+config = deep_ep.Buffer.get_combine_config(num_ranks)
 tune_args = {'x': recv_x, 'handle': handle, 'config': config}
 t = bench(lambda: buffer.combine(**tune_args))[0]
 if local_rank == 0:
-print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
-if t < best_time:
+print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
+f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+if t < best_time and nvl_chunk_size > 0:
 best_time, best_results = t, (num_sms, nvl_chunk_size)
 if local_rank == 0:
@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
-buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
+buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
 num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
 torch.manual_seed(rank)

@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
 buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
 test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
+# Destroy the communication group
+dist.barrier()
+dist.destroy_process_group()
 if __name__ == '__main__':
 num_processes = 8