From 8997d137456e7d22e9540d8b0ea44f08e2398a3c Mon Sep 17 00:00:00 2001 From: zhouzihan30 Date: Fri, 21 Mar 2025 16:17:08 +0800 Subject: [PATCH] Performance: Async copy for q in advance, roughly 0.5% performance gain Hi, I think we can optimize the process that load q from gmem to smem. In function compute_attn_1rowblock_splitkv_mla, after the last calculation of q @ k_T, we no longer need q. So we can use async copy before the next function compute_attn_1rowblock_splitkv_mla run, This means that load the next q to the smem in advance. In order to prevent the valid values from being overwritten in smem, I adjusted the layout of SharedStorageMLA, and use test_flash_mla.py to test. The test can pass normally without any calculation errors. I tested it on H800, and I use the average of 10 tests as the final result, each test interval is 3 seconds to stabilize the GPU frequency. The number of times to load q is very small, so this does not bring much performance improvement. Under some parameters, there is a slight decrease in performance, but it is gratifying that there is a roughly 0.5% performance improvement overall. batch,seqlen,head,bw_orig,bw_opt,bw_diff_percentage 64,1087,128,1384,1407,1.66% 64,2111,128,1744,1761,0.97% 64,4159,128,2188,2197,0.41% 64,8255,128,2341,2345,0.17% 64,16447,128,2330,2338,0.34% 64,32831,128,2374,2374,0.0% 128,1151,128,1756,1763,0.4% 128,2175,128,2066,2072,0.29% 128,4223,128,2284,2290,0.26% 128,8319,128,2343,2349,0.26% 128,16511,128,2375,2373,-0.08% 128,32895,128,2351,2358,0.3% 256,1279,128,2033,2035,0.1% 256,2303,128,2232,2228,-0.18% 256,4351,128,2322,2340,0.78% 256,8447,128,2371,2367,-0.17% 256,16639,128,2359,2394,1.48% 256,33023,128,2381,2392,0.46% Thanks! --- csrc/flash_fwd_mla_kernel.h | 57 ++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index d96acd8..0e6f951 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -13,7 +13,6 @@ using namespace cute; #include "static_switch.h" #include "flash_mla.h" - template constexpr auto getSmemLayoutK() { constexpr int headSizeBytes = sizeof(PrecType) * DIM; @@ -133,6 +132,7 @@ struct SharedStorageMLA { cute::array_aligned> smem_scale; }; struct { + cute::array_aligned> smem_q_2; cute::array_aligned> smem_max; cute::array_aligned> smem_sum; cute::array_aligned> smem_o; @@ -233,6 +233,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, + const bool if_first_batch, const bool if_last_batch, SharedStorage &shared_storage) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; @@ -347,22 +348,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int cur_block_table = __ldg(&block_table[n_block]); - const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); + if(if_first_batch) { + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + } const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, @@ -396,6 +398,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f #pragma unroll 1 for (; n_block >= n_block_min; --n_block) { + bool need_copy_in_advance = (n_block == n_block_min && !if_last_batch); flash::cp_async_wait<0>(); __syncthreads(); @@ -413,6 +416,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + if(__builtin_expect(need_copy_in_advance, 0)) { + const index_t row_offset_q = (bidb + 1) * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + } + if (n_block - 2 >= n_block_min) { cur_block_table = __ldg(&block_table[n_block - 2]); } @@ -467,6 +488,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const bool if_first_batch = batch_id == begin_idx; + const bool if_last_batch = batch_id == end_idx; const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; @@ -475,7 +498,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. } - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, if_first_batch, if_last_batch, shared_storage); } }