mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Performance: Async copy for q in advance, roughly 0.5% performance gain
Hi, I think we can optimize the way q is loaded from gmem to smem. In compute_attn_1rowblock_splitkv_mla, after the last q @ k^T computation, q is no longer needed, so we can issue an async copy before the next call to compute_attn_1rowblock_splitkv_mla runs, i.e. load the next q into smem in advance. To prevent valid values in smem from being overwritten, I adjusted the layout of SharedStorageMLA. I verified the change with test_flash_mla.py; the test passes with no numerical errors.

I benchmarked on an H800 and used the average of 10 runs as the final result, with a 3-second interval between runs to stabilize the GPU frequency. Since q is loaded only a small number of times, this does not bring a large improvement; under some parameters there is a slight regression, but overall there is roughly a 0.5% performance gain.

batch,seqlen,head,bw_orig,bw_opt,bw_diff_percentage
64,1087,128,1384,1407,1.66%
64,2111,128,1744,1761,0.97%
64,4159,128,2188,2197,0.41%
64,8255,128,2341,2345,0.17%
64,16447,128,2330,2338,0.34%
64,32831,128,2374,2374,0.0%
128,1151,128,1756,1763,0.4%
128,2175,128,2066,2072,0.29%
128,4223,128,2284,2290,0.26%
128,8319,128,2343,2349,0.26%
128,16511,128,2375,2373,-0.08%
128,32895,128,2351,2358,0.3%
256,1279,128,2033,2035,0.1%
256,2303,128,2232,2228,-0.18%
256,4351,128,2322,2340,0.78%
256,8447,128,2371,2367,-0.17%
256,16639,128,2359,2394,1.48%
256,33023,128,2381,2392,0.46%

Thanks!
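To make the prefetch idea concrete, here is a minimal sketch of the pattern in plain CUDA. This is not the FlashMLA/CUTE code in the diff below; the function, buffer, and size names are illustrative, and it assumes a contiguous, 16-byte-aligned q tile. The point is only that once the current batch has issued its last q @ k^T, the copy threads can kick off cp.async transfers for the next batch's q tile and wait on them only at the start of the next batch.

#include <cuda_pipeline.h>  // __pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior

// Sketch only: BLOCK_M, HEAD_DIM and the buffer names are placeholders, not FlashMLA's.
template<int BLOCK_M, int HEAD_DIM, typename T>
__device__ void prefetch_next_q(const T* __restrict__ gmem_q_next,  // next batch's q tile in gmem
                                T* __restrict__ smem_q_next) {      // smem slot reserved for it
    constexpr int BYTES = BLOCK_M * HEAD_DIM * sizeof(T);
    constexpr int CHUNK = 16;  // cp.async moves 4/8/16-byte chunks
    // Each thread issues asynchronous 16-byte copies; the SM keeps computing on the
    // current tile while the loads are in flight.
    for (int off = threadIdx.x * CHUNK; off < BYTES; off += blockDim.x * CHUNK) {
        __pipeline_memcpy_async(reinterpret_cast<char*>(smem_q_next) + off,
                                reinterpret_cast<const char*>(gmem_q_next) + off, CHUNK);
    }
    __pipeline_commit();
    // The next batch calls __pipeline_wait_prior(0) (followed by __syncthreads())
    // before it reads smem_q_next, instead of doing a fresh synchronous load.
}

In the actual patch the same thing is expressed with the existing CUTE tiled-copy machinery (gmem_tiled_copy_Q / tQgQ / tQsQ), gated by need_copy_in_advance so it only fires on the last k/v block of a batch that has a successor.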
This commit is contained in:
parent
b31bfe72a8
commit
8997d13745
@@ -13,7 +13,6 @@ using namespace cute;
 #include "static_switch.h"
 #include "flash_mla.h"
 
-
 template<typename PrecType, int DIM, int DIM2 = DIM>
 constexpr auto getSmemLayoutK() {
     constexpr int headSizeBytes = sizeof(PrecType) * DIM;
@@ -133,6 +132,7 @@ struct SharedStorageMLA {
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
        };
        struct {
+           cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q_2;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
@@ -233,6 +233,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         const int bidb, const int bidh, const int m_block,
         const int n_split_idx, const int seqlen_k,
         const int n_block_min, const int n_block_max, const bool NoSplit,
+        const bool if_first_batch, const bool if_last_batch,
         SharedStorage &shared_storage) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -347,22 +348,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
     int cur_block_table = __ldg(&block_table[n_block]);
 
-    const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
-    typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
-    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
-    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
-    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
-    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
-
-    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
-                                                          params.seqlen_q - m_block * kBlockM);
-
+    if(if_first_batch) {
+        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.q_row_stride, _1{}));
+        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
+        auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
+        Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+        Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+
+        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
+                                                              params.seqlen_q - m_block * kBlockM);
+    }
     const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
@@ -396,6 +398,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
     #pragma unroll 1
     for (; n_block >= n_block_min; --n_block) {
+        bool need_copy_in_advance = (n_block == n_block_min && !if_last_batch);
         flash::cp_async_wait<0>();
         __syncthreads();
 
@@ -413,6 +416,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
         cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
 
+        if(__builtin_expect(need_copy_in_advance, 0)) {
+            const index_t row_offset_q = (bidb + 1) * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+            Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
+                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                    make_stride(params.q_row_stride, _1{}));
+            typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
+            auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
+            Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+            Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+            Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+            Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+            Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+
+            // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
+                                                                  params.seqlen_q - m_block * kBlockM);
+        }
+
         if (n_block - 2 >= n_block_min) {
             cur_block_table = __ldg(&block_table[n_block - 2]);
         }
@@ -467,6 +488,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
 
     #pragma unroll 1
     for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
+        const bool if_first_batch = batch_id == begin_idx;
+        const bool if_last_batch = batch_id == end_idx;
         const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
         const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
         const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
@@ -475,7 +498,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, if_first_batch, if_last_batch, shared_storage);
     }
 }
 
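A note on why the SharedStorageMLA hunk is needed at all (my reading of the change, not text from the patch): the `}; struct {` fragment suggests the main-loop buffers and the softmax/output epilogue buffers (smem_max, smem_sum, smem_o) share one union, so before this patch the epilogue buffers aliased the q region. Reserving a q-sized member (smem_q_2) at the head of the epilogue struct pushes smem_max/smem_sum/smem_o past that region, so a q tile prefetched across the batch boundary can no longer overwrite epilogue values that are still live. A rough sketch with placeholder element types and sizes, not the real Kernel_traits layouts:

#include <cuda_fp16.h>

struct SharedStorageSketch {
    union {
        struct {                      // main-loop buffers (only partly visible in the hunk above)
            __half q[64 * 576];       // smem_q
            __half k[2 * 64 * 576];   // smem_k (double-buffered), smem_scale, ...
        };
        struct {                      // epilogue buffers
            __half q_2[64 * 576];     // smem_q_2: same bytes as smem_q, reserved so the
                                      // members below start past the q region
            float  row_max[64];       // smem_max
            float  row_sum[64];       // smem_sum
            float  o[64 * 512];       // smem_o
        };
    };
};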