Performance: Async copy for q in advance, roughly 0.5% performance gain

Hi, I think we can optimize how q is loaded from gmem to smem. In
compute_attn_1rowblock_splitkv_mla, q is no longer needed after the last
q @ k_T computation, so we can issue an async copy there, before the next
call to compute_attn_1rowblock_splitkv_mla runs. In other words, the next
q is loaded into smem in advance.
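
To make the idea concrete, here is a stripped-down sketch of the pattern (not the FlashMLA kernel itself; it uses the generic cp.async primitives from <cuda_pipeline.h> instead of CUTE copies, and the tile size, element type, and per-iteration work are placeholders):

```cuda
// Minimal sketch of the prefetch pattern. The point is only the ordering:
// once the current q tile has been read for the last time, the cp.async for
// the next tile is issued so the gmem->smem load overlaps the remaining work.
#include <cuda_pipeline.h>

constexpr int TILE = 128;                       // q elements staged in smem (placeholder)

__global__ void prefetch_next_q(const float* __restrict__ q,
                                float* __restrict__ out, int num_tiles) {
    __shared__ float sq[TILE];
    float acc = 0.f;

    // Stage the first q tile.
    for (int i = threadIdx.x; i < TILE; i += blockDim.x)
        __pipeline_memcpy_async(&sq[i], &q[i], sizeof(float));
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();

    for (int t = 0; t < num_tiles; ++t) {
        acc += sq[threadIdx.x % TILE];          // stands in for the last q @ k_T
        __syncthreads();                        // every thread is done reading sq

        if (t + 1 < num_tiles) {
            // q for this tile is dead: fetch the next tile into the same buffer
            // while the rest of the iteration (softmax rescale, output) runs.
            const float* next = q + (size_t)(t + 1) * TILE;
            for (int i = threadIdx.x; i < TILE; i += blockDim.x)
                __pipeline_memcpy_async(&sq[i], &next[i], sizeof(float));
            __pipeline_commit();
        }

        // ... remaining per-iteration work that no longer touches sq ...

        __pipeline_wait_prior(0);               // next q is resident before it is read
        __syncthreads();
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}
```

In the actual change the prefetch targets the q of the next batch handled by the same CTA, but the ordering is the same: issue the copy right after the last read of the current q, overlap it with the remaining work, and wait before the data is needed.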

To prevent valid values in smem from being overwritten, I adjusted the
layout of SharedStorageMLA and verified the change with test_flash_mla.py.
The test passes without any numerical errors.
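
Roughly, the layout change reserves a q-sized slot at the front of the half of the shared-storage union that the epilogue uses, so the epilogue buffers and the prefetched q no longer alias each other. A simplified sketch with placeholder sizes (plain byte arrays instead of cute::array_aligned, not the real SharedStorageMLA):

```cuda
// The main-loop buffers and the epilogue buffers share one union; adding a
// q-sized slot (smem_q_2) at the front of the epilogue half pushes
// smem_max/smem_sum/smem_o past the q region, so the q prefetched for the
// next batch and the buffers still in use by the epilogue stay disjoint.
constexpr int Q_BYTES   = 64 * 576 * 2;   // e.g. kBlockM x kHeadDim x sizeof(bf16), illustrative
constexpr int K_BYTES   = 2 * 64 * 576 * 2;
constexpr int ROW_BYTES = 64 * 4;
constexpr int O_BYTES   = 64 * 512 * 4;

struct SharedStorageSketch {
    union {
        struct {                              // used by the main loop
            char smem_q[Q_BYTES];
            char smem_k[K_BYTES];
        };
        struct {                              // used by the epilogue
            char smem_q_2[Q_BYTES];           // same address as smem_q; reserves the prefetch target
            char smem_max[ROW_BYTES];
            char smem_sum[ROW_BYTES];
            char smem_o[O_BYTES];
        };
    };
};
```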

I tested on an H800 and use the average of 10 runs as the final result,
with a 3-second interval between runs to stabilize the GPU frequency.
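
For reference, a rough host-side sketch of that measurement protocol (not test_flash_mla.py itself; run_once is a placeholder for one timed invocation returning the achieved bandwidth):

```cuda
// 10 timed runs, a 3-second pause between runs so the GPU clocks settle,
// and the mean of the 10 results is the number reported per configuration.
#include <chrono>
#include <functional>
#include <thread>

double average_bandwidth(const std::function<double()>& run_once) {
    double sum = 0.0;
    for (int i = 0; i < 10; ++i) {
        if (i > 0) std::this_thread::sleep_for(std::chrono::seconds(3));
        sum += run_once();      // one timed invocation
    }
    return sum / 10.0;          // reported value (bw_orig / bw_opt below)
}
```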

q is loaded only a small number of times, so this does not bring a large
improvement. Under some parameter settings there is a slight regression,
but overall there is still a roughly 0.5% performance gain.

batch,seqlen,head,bw_orig (GB/s),bw_opt (GB/s),bw_diff
64,1087,128,1384,1407,1.66%
64,2111,128,1744,1761,0.97%
64,4159,128,2188,2197,0.41%
64,8255,128,2341,2345,0.17%
64,16447,128,2330,2338,0.34%
64,32831,128,2374,2374,0.0%
128,1151,128,1756,1763,0.4%
128,2175,128,2066,2072,0.29%
128,4223,128,2284,2290,0.26%
128,8319,128,2343,2349,0.26%
128,16511,128,2375,2373,-0.08%
128,32895,128,2351,2358,0.3%
256,1279,128,2033,2035,0.1%
256,2303,128,2232,2228,-0.18%
256,4351,128,2322,2340,0.78%
256,8447,128,2371,2367,-0.17%
256,16639,128,2359,2394,1.48%
256,33023,128,2381,2392,0.46%

Thanks!
zhouzihan30 2025-03-21 16:17:08 +08:00
parent b31bfe72a8
commit 8997d13745

@@ -13,7 +13,6 @@ using namespace cute;
#include "static_switch.h"
#include "flash_mla.h"
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
@@ -133,6 +132,7 @@ struct SharedStorageMLA {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q_2;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
@@ -233,6 +233,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
const bool if_first_batch, const bool if_last_batch,
SharedStorage &shared_storage) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -347,22 +348,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
if(if_first_batch) {
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
}
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
@@ -396,6 +398,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
#pragma unroll 1
for (; n_block >= n_block_min; --n_block) {
bool need_copy_in_advance = (n_block == n_block_min && !if_last_batch);
flash::cp_async_wait<0>();
__syncthreads();
@@ -413,6 +416,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
if(__builtin_expect(need_copy_in_advance, 0)) {
const index_t row_offset_q = (bidb + 1) * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
}
if (n_block - 2 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 2]);
}
@@ -467,6 +488,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
const bool if_first_batch = batch_id == begin_idx;
const bool if_last_batch = batch_id == end_idx;
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
@@ -475,7 +498,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
if (batch_id > begin_idx) {
__syncthreads(); // Barrier between two tiles.
}
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, if_first_batch, if_last_batch, shared_storage);
}
}