diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index db1f0b2..005607a 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -9,7 +9,10 @@ set(CUDA_SEPARABLE_COMPILATION ON)
 list(APPEND CUDA_NVCC_FLAGS "-O3")
 list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
 
-set(TORCH_CUDA_ARCH_LIST "9.0")
+set(USE_SYSTEM_NVTX on)
+set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
+set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
+
 find_package(CUDAToolkit REQUIRED)
 find_package(pybind11 REQUIRED)
 find_package(Torch REQUIRED)
@@ -19,9 +22,8 @@ add_library(nvshmem ALIAS nvshmem::nvshmem)
 add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
 add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
 
-# Seems bugs with CMake, NVCC 12 and C++ 17
 set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD 17)
 
 include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
 link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh
index 757e8a3..f6937da 100644
--- a/csrc/kernels/configs.cuh
+++ b/csrc/kernels/configs.cuh
@@ -19,8 +19,6 @@
 #ifdef __CLION_IDE__
 #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
 #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
-__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
-#define printf host_device_printf
 #endif
 
 // Remove Torch restrictions
diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu
index 4280166..2954ac2 100644
--- a/csrc/kernels/intranode.cu
+++ b/csrc/kernels/intranode.cu
@@ -174,7 +174,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
 }
 
-template <int kNumRanks, int kNumThreads>
+template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
@@ -183,7 +183,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
          void **buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
-    const auto thread_id = static_cast<int>(threadIdx.x);
+    const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
     const bool is_sender = sm_id % 2 == 0;
 
     EP_DEVICE_ASSERT(num_sms % 2 == 0);
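For orientation while reading the kernel changes: every NVLink channel is driven by a pair of thread blocks, the even-indexed block of the pair sends and the odd one receives, which is why the kernel asserts an even SM count. A minimal sketch of that mapping (illustration only, not part of the patch):

```cuda
// Sketch: how a block derives its role and channel, mirroring the
// `is_sender` / `responsible_channel` logic used by dispatch and combine.
__device__ __forceinline__ void block_role(bool& is_sender, int& responsible_channel) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    is_sender = (sm_id % 2 == 0);      // even blocks produce into the peers' buffers
    responsible_channel = sm_id / 2;   // each consecutive block pair serves one channel
}
```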
@@ -232,19 +232,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
     auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
     auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
 
+    // TMA stuffs
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+    auto half_hidden_int4 = hidden_int4 / 2;
+    auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
+    auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+    auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
+    uint32_t tma_phase = 0;
+    if (lane_id == 0) {
+        mbarrier_init(tma_mbarrier, 1);
+        fence_view_async_shared();
+        fence_barrier_init();
+        EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
+    }
+    __syncwarp();
+
     if (is_sender) {
         // Workers for sending
         constexpr int num_send_warps = kNumThreads / 32;
         constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
         const auto send_thread_id = thread_id;
-        const auto send_lane_id = send_thread_id % 32;
         const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
         EP_DEVICE_ASSERT(kNumRanks <= 32);
         EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
 
         // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
         // NOTES: this is for distinguishing zero tokens
-        if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
+        if (lane_id == 0 and send_warp_id_in_rank == 0) {
            int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
            st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
            value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
@@ -262,7 +276,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             // Check destination queue emptiness, or wait a buffer to be released (rare cases)
             // NOTES: the head index received by different warps may not be the same
             auto start_time = clock64();
-            while (send_lane_id == 0) {
+            while (lane_id == 0) {
                 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
                 int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                 if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
@@ -279,7 +293,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             int chunk_token_idx = 0;
             while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
                 // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
-                if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
+                if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
                     send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
 
                 // Skip if not selected
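The per-role `send_lane_id` / `recv_lane_id` variables are folded into a single `lane_id = get_lane_id()` computed at kernel entry. `get_lane_id()` is an existing helper in the kernel utilities (not shown in this diff); a minimal equivalent, shown only for context and assumed rather than copied from the repo, reads the `%laneid` special register:

```cuda
// Assumed equivalent of the repo's get_lane_id(): the lane index of the calling thread.
__device__ __forceinline__ int get_lane_id_sketch() {
    int lane_id;
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
}
```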
@@ -294,30 +308,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
                 // Copy data
                 auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                 auto shifted_x = x + token_idx * hidden_int4;
-                UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
-                                   __ldg, st_na_global);
+                UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);
 
                 // Copy source index
-                if (send_lane_id == 0)
+                if (lane_id == 0)
                     channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
 
                 // Copy `topk_idx` and `topk_weights` with transformed index
-                if (send_lane_id < num_topk) {
+                if (lane_id < num_topk) {
                     // Top-k index
                     int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
-                    auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
+                    auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
                     idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
-                    channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
+                    channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
 
                     // Top-k weights
-                    auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
+                    auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
                     weight_value = (idx_value >= 0) ? weight_value : 0.0f;
-                    channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
+                    channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;
                 }
 
                 // Copy `x_scales`
                 #pragma unroll
-                for (int i = send_lane_id; i < num_scales; i += 32)
+                for (int i = lane_id; i < num_scales; i += 32)
                     channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
             }
 
@@ -328,7 +341,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             // Move tail index
             // NOTES: here all warps should share the same new tail
             asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-            if (send_warp_id_in_rank == 0 and send_lane_id == 0)
+            if (send_warp_id_in_rank == 0 and lane_id == 0)
                 st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
         }
     } else {
@@ -336,7 +349,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
         constexpr int num_recv_warps = kNumThreads / 32;
         constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
         const auto recv_thread_id = thread_id;
-        const auto recv_lane_id = recv_thread_id % 32;
         const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
         const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
         EP_DEVICE_ASSERT(kNumRanks <= 32);
@@ -348,9 +360,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
         // Receive channel offset
         int total_offset, num_tokens_to_recv;
-        while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
-        while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
-        if (recv_lane_id == 0) {
+        while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
+        while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
+        if (lane_id == 0) {
             total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
             if (recv_warp_id_in_rank == 0)
                 recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
@@ -393,8 +405,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
                 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                 auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
                 auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
-                UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
-                                   ld_nc_global, st_na_global);
+                #pragma unroll
+                for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
+                    tma_store_wait();
+                    tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
+                    mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
+                    mbarrier_wait(tma_mbarrier, tma_phase);
+                    tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
+                }
+                __syncwarp();
             }
 
             // Copy `src_idx`
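The removed UNROLLED_WARP_COPY call was a lane-strided load/store loop issued by the whole warp; the TMA path above instead lets lane 0 drive two bulk copies of half the hidden dimension each through shared memory. For contrast, a simplified, un-unrolled stand-in for what the macro did (hypothetical helper, not the macro itself):

```cuda
// Simplified stand-in for UNROLLED_WARP_COPY: each of the 32 lanes copies a strided
// slice of `n` int4 elements from `src` to `dst` (unrolling and the pluggable
// load/store operations of the real macro are omitted).
__device__ __forceinline__ void warp_copy_int4(int4* dst, const int4* src, int n, int lane_id) {
    for (int i = lane_id; i < n; i += 32)
        dst[i] = __ldg(src + i);
}
```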
@@ -426,12 +445,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             cached_channel_head_idx += num_recv_tokens;
             total_offset += num_recv_tokens;
             asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
+            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
                 st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
 
             // Exit
             num_tokens_to_recv -= num_recv_tokens;
         }
+
+        // Make TMA store visible to the next kernel
+        if (lane_id == 0)
+            tma_store_wait();
     }
 }
 
@@ -441,17 +464,22 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
               int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
               void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream,
               int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
-    constexpr int kNumThreads = 512;
+    constexpr int kNumThreads = 768;
+    constexpr int kNumTMABytesPerWarp = 8192;
+    constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
 
-#define DISPATCH_LAUNCH_CASE(ranks) \
-LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
-              reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
-              send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
-              is_token_in_rank, channel_prefix_matrix, \
-              num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
-              buffer_ptrs, rank, \
-              num_max_send_tokens, num_recv_buffer_tokens); \
-break
+#define DISPATCH_LAUNCH_CASE(ranks) { \
+    auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
+    EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+    cfg.dynamicSmemBytes = smem_size; \
+    LAUNCH_KERNEL(&cfg, kernel, \
+                  reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
+                  send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
+                  is_token_in_rank, channel_prefix_matrix, \
+                  num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+                  buffer_ptrs, rank, \
+                  num_max_send_tokens, num_recv_buffer_tokens); \
+    } break
 
     // Even-numbered blocks for sending, odd-numbered blocks for receiving.
     EP_HOST_ASSERT(num_sms % 2 == 0);
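The launcher now requests smem_size = 8192 B x 24 warps = 192 KiB of dynamic shared memory per block, which is well above the 48 KiB default and therefore has to be opted into per kernel. The patch does this through DeepEP's cfg / LAUNCH_KERNEL path; a generic runtime-API sketch of the same pattern (hypothetical kernel and launcher, illustration only):

```cuda
#include <cuda_runtime.h>

// Hypothetical kernel that consumes a large dynamic shared-memory buffer.
__global__ void big_smem_kernel(const float* in, float* out) {
    extern __shared__ __align__(1024) unsigned char smem_buffer[];
    // ... body elided ...
}

// Opt-in sketch: request more than 48 KiB of dynamic shared memory before launching.
cudaError_t launch_big_smem_kernel(const float* in, float* out, cudaStream_t stream) {
    constexpr int kNumThreads = 768, kNumTMABytesPerWarp = 8192;
    constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);  // 192 KiB
    auto err = cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (err != cudaSuccess)
        return err;
    big_smem_kernel<<<1, kNumThreads, smem_size, stream>>>(in, out);
    return cudaGetLastError();
}
```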
@@ -532,7 +560,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
 #undef CACHED_NOTIFY_COMBINE
 }
 
-template <typename dtype_t, int kNumRanks, int kNumThreads>
+template <typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 combine(dtype_t* recv_x, float* recv_topk_weights,
         const dtype_t* x, const float* topk_weights,
@@ -542,7 +570,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
-    const auto sm_id = static_cast<int>(blockIdx.x);
+    const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
     const auto num_channels = num_sms / 2;
     const bool is_sender = sm_id % 2 == 0;
     const int responsible_channel = sm_id / 2;
@@ -553,16 +581,21 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     auto x_int4 = reinterpret_cast<const int4*>(x);
     auto recv_int4 = reinterpret_cast<int4*>(recv_x);
 
+    // TMA stuffs
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+    auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+
     if (is_sender) {
         // Workers for sending
         // Several warps are responsible for a single rank
-        constexpr int num_send_warps = kNumThreads / 32;
-        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
+        constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;
+        constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
         const auto num_threads_per_rank = num_send_warps_per_rank * 32;
         const auto send_thread_id = thread_id;
-        const auto send_lane_id = send_thread_id % 32;
+        const auto send_warp_id = send_thread_id / 32;
         const auto send_rank_id = thread_id / num_threads_per_rank;
-        const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
+        const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
+        EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
 
         // Calculate pointers by the specific layout
         auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
@@ -595,7 +628,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
             // Check destination queue emptiness, or wait a buffer to be released (rare cases)
             auto start_time = clock64();
             int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
-            while (send_lane_id == 0) {
+            while (lane_id == 0) {
                 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
                 int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                 if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
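Both send paths throttle on the same ring-buffer condition: the receiver publishes a head index for consumed slots, the sender keeps a cached tail for produced slots, and a chunk is only sent once enough free slots remain. A condensed, self-contained sketch of the condition being polled (the kernels refresh the head with the ld_volatile_global helper rather than the raw volatile load shown here):

```cuda
// Producer-side back-pressure check for a ring buffer of `num_recv_buffer_tokens`
// slots: `cached_tail` counts tokens produced so far, `*head_ptr` tokens consumed.
__device__ __forceinline__ bool has_room_for_chunk(int cached_tail, const int* head_ptr,
                                                   int num_recv_buffer_tokens, int chunk_tokens) {
    // Worst-case view: everything between the published head and our cached tail
    // is still treated as occupied.
    int num_used_slots = cached_tail - *reinterpret_cast<const volatile int*>(head_ptr);
    return num_recv_buffer_tokens - num_used_slots >= chunk_tokens;
}
```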
@@ -618,22 +651,22 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                 // Copy data
                 auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                 auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
-                UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
+                UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
 
                 // Send source index
-                if (send_lane_id == 0)
+                if (lane_id == 0)
                     channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
 
                 // Send `topk_weights`
-                if (num_topk > 0 and send_lane_id < num_topk)
-                    channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
+                if (num_topk > 0 and lane_id < num_topk)
+                    channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
             }
             token_idx += num_round_tokens;
             current_channel_tail_idx += num_round_tokens;
 
             // Move tail index
             asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
-            if (send_lane_id == 0 and send_warp_id_in_rank == 0)
+            if (lane_id == 0 and send_warp_id_in_rank == 0)
                 st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
         }
     } else {
@@ -641,7 +674,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         // One warp for moving the queue head, others for reduction
         constexpr int num_recv_warps = kNumThreads / 32;
         const auto recv_warp_id = thread_id / 32;
-        const auto recv_lane_id = thread_id % 32;
         EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
         EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
 
@@ -651,19 +683,19 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         __shared__ volatile bool warp_retired[num_recv_warps];
         if (thread_id < num_recv_warps)
             warp_retired[thread_id] = false;
-        if (recv_lane_id < kNumRanks)
-            warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
+        if (lane_id < kNumRanks)
+            warp_channel_head_idx[recv_warp_id][lane_id] = 0;
         if (thread_id < kNumRanks)
             channel_tail_idx[thread_id] = 0;
         asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
 
         if (thread_id < 32) {
-            int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
+            int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
             int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
 
             // Queue head updater
             int last_head = 0;
-            while (recv_lane_id < kNumRanks) {
+            while (lane_id < kNumRanks) {
                 // Check retired
                 bool retired = true;
                 #pragma unroll
@@ -673,13 +705,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                     break;
 
                 // Update queue tail
-                channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
+                channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
 
                 // Update minimum head
                 int min_head = std::numeric_limits<int>::max();
                 #pragma unroll
                 for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
-                    min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
+                    min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
                 if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
                     st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
             }
@@ -716,11 +748,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
             for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
                 // Read expected head
                 int expected_head = -1;
-                if (recv_lane_id < kNumRanks)
-                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
+                if (lane_id < kNumRanks)
+                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
 
                 auto start_time = clock64();
-                while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
+                while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
                     // Timeout check
                     if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                         printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
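The reduction that follows reads one int4-packed group per source rank, accumulates it in fp32, and only then casts back to dtype_t. A standalone sketch of that per-element accumulation, assuming dtype_t is nv_bfloat16 (eight values per int4); function and variable names here are illustrative:

```cuda
#include <cuda_bf16.h>

// Sketch: sum one int4-packed group of bf16 values gathered from `num_ranks` ranks,
// accumulating in fp32 before casting back, as the combine reduction loop does.
__device__ __forceinline__ int4 reduce_packed_bf16(const int4* recv_value_int4, int num_ranks) {
    constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(nv_bfloat16);  // 8
    float values[kDtypePerInt4] = {0.0f};
    for (int j = 0; j < num_ranks; ++ j) {
        auto recv_value_dtypes = reinterpret_cast<const nv_bfloat16*>(recv_value_int4 + j);
        #pragma unroll
        for (int k = 0; k < kDtypePerInt4; ++ k)
            values[k] += static_cast<float>(recv_value_dtypes[k]);
    }
    int4 out_int4;
    auto out_dtypes = reinterpret_cast<nv_bfloat16*>(&out_int4);
    #pragma unroll
    for (int k = 0; k < kDtypePerInt4; ++ k)
        out_dtypes[k] = static_cast<nv_bfloat16>(values[k]);
    return out_int4;
}
```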
@@ -740,9 +772,16 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                     }
                 }
 
-                // Reduce data
+                // Wait shared memory release
+                if (lane_id == 0)
+                    tma_store_wait();
+                __syncwarp();
+
+                // Reduce data with pipeline
+                constexpr int kNumStages = 8;
+                EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
                 #pragma unroll
-                for (int i = recv_lane_id; i < hidden_int4; i += 32) {
+                for (int i = lane_id; i < hidden_int4; i += 32) {
                     // Read buffers
                     int4 recv_value_int4[kNumRanks];
                     #pragma unroll
@@ -759,33 +798,55 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                         values[k] += static_cast<float>(recv_value_dtypes[k]);
                     }
 
-                    // Cast back to `dtype_t` and write
+                    // Cast back to `dtype_t`
                     int4 out_int4;
                     auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
                     #pragma unroll
                     for (int j = 0; j < kDtypePerInt4; ++ j)
                         out_dtypes[j] = static_cast<dtype_t>(values[j]);
-                    recv_int4[token_idx * hidden_int4 + i] = out_int4;
+
+                    // Wait TMA arrival
+                    if (lane_id == 0)
+                        tma_store_wait();
+                    __syncwarp();
+
+                    // Write into TMA buffer
+                    auto tma_stage_idx = (i / 32) % kNumStages;
+                    reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
+
+                    // Issue TMA
+                    tma_store_fence();
+                    __syncwarp();
+                    if (lane_id == 0) {
+                        auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
+                        tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
+                                     recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
+                    }
+                    __syncwarp();
                 }
 
                 // Reduce `topk_weights`
-                if (recv_lane_id < num_topk) {
+                if (lane_id < num_topk) {
                     float value = 0;
                     #pragma unroll
                     for (int i = 0; i < num_topk_ranks; ++ i)
-                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
-                    recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
+                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);
+                    recv_topk_weights[token_idx * num_topk + lane_id] = value;
                 }
 
                 // Update head
-                if (recv_lane_id < kNumRanks)
-                    warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
+                if (lane_id < kNumRanks)
+                    warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
             }
 
             // Retired
             __syncwarp();
-            if (recv_lane_id == 0)
+            if (lane_id == 0)
                 warp_retired[recv_warp_id] = true;
+
+            // Make TMA store visible to the next kernel
+            if (lane_id == 0)
+                tma_store_wait();
         }
     }
 }
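Each tma_store_1d ends with cp.async.bulk.commit_group, and tma_store_wait (cp.async.bulk.wait_group.read) blocks until at most N committed groups may still be reading shared memory; the bare tma_store_wait() calls above use the helper's default template argument and drain everything. A hedged sketch of the same stage-reuse discipline in isolation, built on the helpers this patch adds to utils.cuh (kNumStages and the stored value are illustrative only):

```cuda
// Sketch: reuse a small ring of shared-memory stages for TMA bulk stores.
// Assumes the tma_store_fence / tma_store_1d / tma_store_wait helpers from this patch
// are in scope; `num_int4` is the number of int4 elements to write to `gmem_dst`.
template <int kNumStages>
__device__ void staged_store(int4* smem_stages, int4* gmem_dst, int num_int4) {
    const int lane_id = threadIdx.x % 32;
    for (int base = 0; base < num_int4; base += 32) {
        const int stage = (base / 32) % kNumStages;
        // Ensure the store that last used this stage is no longer reading it.
        if (lane_id == 0)
            tma_store_wait<kNumStages - 1>();
        __syncwarp();
        if (base + lane_id < num_int4)
            smem_stages[stage * 32 + lane_id] = make_int4(0, 0, 0, 0);  // placeholder value
        tma_store_fence();  // make generic-proxy writes visible to the async proxy
        __syncwarp();
        if (lane_id == 0)
            tma_store_1d(smem_stages + stage * 32, gmem_dst + base,
                         min(32, num_int4 - base) * static_cast<int>(sizeof(int4)), false);
        __syncwarp();
    }
    if (lane_id == 0)
        tma_store_wait();  // drain all groups before the destination may be consumed
}
```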
@@ -799,15 +860,20 @@ void combine(cudaDataType_t type,
              cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
     constexpr int kNumThreads = 768;
+    constexpr int kNumTMABytesPerWarp = 4096;
+    constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
 
-#define COMBINE_LAUNCH_CASE(dtype, ranks) \
-    LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
+#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
+    auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
+    EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+    cfg.dynamicSmemBytes = smem_size; \
+    LAUNCH_KERNEL(&cfg, kernel, \
         reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
         reinterpret_cast<const dtype*>(x), topk_weights, \
         src_idx, rank_prefix_matrix, channel_prefix_matrix, \
         send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
         buffer_ptrs, rank, \
-        num_max_send_tokens, num_recv_buffer_tokens); \
+        num_max_send_tokens, num_recv_buffer_tokens); } \
     break
 
 #define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh
index 53f89c1..98296dc 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
                  ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
 }
 
+__device__ __forceinline__ void fence_view_async_shared() {
+    asm volatile("fence.proxy.async.shared::cta; \n" :: );
+}
+
+__device__ __forceinline__ void fence_barrier_init() {
+    asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
+}
+
+__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
+    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+    asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
+}
+
+__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
+    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+    asm volatile("{\n\t"
+                 ".reg .pred P1; \n\t"
+                 "LAB_WAIT: \n\t"
+                 "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
+                 "@P1 bra DONE; \n\t"
+                 "bra LAB_WAIT; \n\t"
+                 "DONE: \n\t"
+                 "}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
+    phase ^= 1;
+}
+
+__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
+    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
+}
+
+__device__ __forceinline__ void tma_store_fence() {
+    asm volatile ("fence.proxy.async.shared::cta;");
+}
+
+constexpr uint64_t kEvictFirst = 0x12f0000000000000;
+constexpr uint64_t kEvictNormal = 0x1000000000000000;
+
+__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
+                                            bool evict_first = true) {
+    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
+    auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+    const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
+    asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
+                 :: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
+}
+
+__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
+                                             bool evict_first = true) {
+    auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+    const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
+    asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
+                 :: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
+    asm volatile("cp.async.bulk.commit_group;");
+}
+
+template <int N = 0>
+__device__ __forceinline__ void tma_store_wait() {
+    asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
+}
+
 template <typename dtype_t>
 __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
     return (a + b - 1) / b;
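Taken together, the new helpers implement a single-producer load/store round trip through shared memory, exactly the sequence the dispatch receiver uses. A toy single-warp kernel wiring them up (sketch only, assuming it is compiled in a translation unit that includes utils.cuh, with `num_bytes` a multiple of 16 and `num_bytes + 8` bytes of dynamic shared memory at launch):

```cuda
#include <cstdint>

// Toy round trip with the helpers above: bulk-load `num_bytes` from gmem_src into
// shared memory, wait on the mbarrier for completion, then bulk-store to gmem_dst.
__global__ void tma_roundtrip(const void* gmem_src, void* gmem_dst, int num_bytes) {
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    auto tma_mbarrier = reinterpret_cast<uint64_t*>(smem_buffer + num_bytes);
    uint32_t tma_phase = 0;
    const bool is_leader = (threadIdx.x % 32 == 0);

    if (is_leader) {
        mbarrier_init(tma_mbarrier, 1);   // one arriving thread
        fence_view_async_shared();        // make the init visible to the async proxy
        fence_barrier_init();
    }
    __syncwarp();

    if (is_leader) {
        tma_load_1d(smem_buffer, gmem_src, tma_mbarrier, num_bytes);
        mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes);  // arrive + expected byte count
        mbarrier_wait(tma_mbarrier, tma_phase);                  // spin until the load lands
        tma_store_1d(smem_buffer, gmem_dst, num_bytes, false);
        tma_store_wait();                                        // drain before exiting
    }
}
```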
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 3473b56..3fa069d 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -171,8 +171,8 @@ class Buffer:
         """
         config_map = {
-            2: Config(Buffer.num_sms, 16, 256, 6, 128),
-            4: Config(Buffer.num_sms, 16, 256, 6, 128),
+            2: Config(Buffer.num_sms, 24, 256, 6, 128),
+            4: Config(Buffer.num_sms, 6, 256, 6, 128),
             8: Config(Buffer.num_sms, 6, 256, 6, 128),
             16: Config(Buffer.num_sms, 16, 288, 20, 128),
             24: Config(Buffer.num_sms, 8, 288, 32, 128),
@@ -198,9 +198,9 @@ class Buffer:
         """
         config_map = {
-            2: Config(Buffer.num_sms, 6, 256, 6, 128),
-            4: Config(Buffer.num_sms, 6, 256, 6, 128),
-            8: Config(Buffer.num_sms, 6, 256, 6, 128),
+            2: Config(Buffer.num_sms, 10, 256, 6, 128),
+            4: Config(Buffer.num_sms, 9, 256, 6, 128),
+            8: Config(Buffer.num_sms, 4, 256, 6, 128),
             16: Config(Buffer.num_sms, 2, 288, 28, 128),
             24: Config(Buffer.num_sms, 1, 288, 20, 128),
             32: Config(Buffer.num_sms, 1, 288, 20, 128),
diff --git a/tests/test_intranode.py b/tests/test_intranode.py
index 68f16f7..c069c6d 100644
--- a/tests/test_intranode.py
+++ b/tests/test_intranode.py
@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
     for current_x in (x_e4m3, x):
         best_time, best_results = 1e10, None
         nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
-        for nvl_chunk_size in range(4, 33, 4):
-            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+        for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
+            if nvl_chunk_size > 0:
+                config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+            else:
+                # Test default config as well
+                deep_ep.Buffer.set_num_sms(num_sms)
+                config = deep_ep.Buffer.get_dispatch_config(num_ranks)
             tune_args = {'x': current_x, 'handle': handle, 'config': config}
             t = bench(lambda: buffer.dispatch(**tune_args))[0]
-            if t < best_time:
+            if t < best_time and nvl_chunk_size > 0:
                 best_time, best_results = t, (num_sms, nvl_chunk_size)
             if local_rank == 0:
-                print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+                print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
+                      f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
         if local_rank == 0:
             print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
             print('', flush=True)
@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
 
     # Tune combine performance
     best_time, best_results = 1e10, None
-    for nvl_chunk_size in range(1, 7, 1):
-        config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+    for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
+        if nvl_chunk_size > 0:
+            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
+        else:
+            # Test default config as well
+            deep_ep.Buffer.set_num_sms(num_sms)
+            config = deep_ep.Buffer.get_combine_config(num_ranks)
         tune_args = {'x': recv_x, 'handle': handle, 'config': config}
         t = bench(lambda: buffer.combine(**tune_args))[0]
         if local_rank == 0:
-            print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
-        if t < best_time:
+            print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
+                  f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+        if t < best_time and nvl_chunk_size > 0:
             best_time, best_results = t, (num_sms, nvl_chunk_size)
 
     if local_rank == 0:
@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
         num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
 
-    buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
+    buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
                             num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
 
     torch.manual_seed(rank)
@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
         buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
         test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
 
+    # Destroy the communication group
+    dist.barrier()
+    dist.destroy_process_group()
+
 
 if __name__ == '__main__':
     num_processes = 8