diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index d96acd8..0e6f951 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -13,7 +13,6 @@ using namespace cute;
 #include "static_switch.h"
 #include "flash_mla.h"
 
-
 template<typename PrecType, int DIM>
 constexpr auto getSmemLayoutK() {
     constexpr int headSizeBytes = sizeof(PrecType) * DIM;
@@ -133,6 +132,7 @@ struct SharedStorageMLA {
         cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
     };
     struct {
+        cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q_2;
         cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
         cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
         cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
@@ -233,6 +233,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                                    const int bidb, const int bidh, const int m_block,
                                                                    const int n_split_idx, const int seqlen_k,
                                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
+                                                                   const bool if_first_batch, const bool if_last_batch,
                                                                    SharedStorage &shared_storage) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -347,22 +348,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
     int cur_block_table = __ldg(&block_table[n_block]);
 
-    const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
-    typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
-    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
-    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
-    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
-    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
-
-    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
-                                                          params.seqlen_q - m_block * kBlockM);
+    if(if_first_batch) {
+        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.q_row_stride, _1{}));
+        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
+        auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
+        Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+        Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
+                                                              params.seqlen_q - m_block * kBlockM);
+    }
 
     const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
@@ -396,6 +398,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
 #pragma unroll 1
     for (; n_block >= n_block_min; --n_block) {
+        bool need_copy_in_advance = (n_block == n_block_min && !if_last_batch);
         flash::cp_async_wait<0>();
         __syncthreads();
 
@@ -413,6 +416,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
 
+        if(__builtin_expect(need_copy_in_advance, 0)) {
+            const index_t row_offset_q = (bidb + 1) * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+            Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
+                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                    make_stride(params.q_row_stride, _1{}));
+            typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
+            auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
+            Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+            Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+            Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+            Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+            Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+
+            // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
+                                                                  params.seqlen_q - m_block * kBlockM);
+        }
+
         if (n_block - 2 >= n_block_min) {
             cur_block_table = __ldg(&block_table[n_block - 2]);
         }
 
@@ -467,6 +488,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
 
 #pragma unroll 1
     for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
+        const bool if_first_batch = batch_id == begin_idx;
+        const bool if_last_batch = batch_id == end_idx;
         const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
         const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
         const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
@@ -475,7 +498,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, if_first_batch, if_last_batch, shared_storage);
     }
 }
 
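
Note on the schedule this patch sets up: only the first batch assigned to a CTA loads its own Q tile (the `if(if_first_batch)` block); every batch except the CTA's last raises `need_copy_in_advance` on its final KV block and, right after the SReady barrier, the Q-copy threads (note the `tidx - kNThreadsS` slice) issue the next batch's Q load (`bidb + 1`) with the same gmem-to-smem tiled copy, so it overlaps the remaining mainloop and epilogue work. A host-side sketch of that control flow, with made-up batch and block counts (the real kernel derives them from `begin_idx`/`end_idx` and the sequence lengths, and the "prefetch" is an asynchronous copy rather than a print):

#include <cstdio>

int main() {
    const int begin_idx = 0, end_idx = 2;  // hypothetical batch range for one CTA
    const int n_blocks[] = {3, 2, 4};      // hypothetical KV block count per batch

    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const bool if_first_batch = batch_id == begin_idx;
        const bool if_last_batch  = batch_id == end_idx;

        // Only the first batch loads its own Q; later batches find theirs
        // already in shared memory, prefetched by the previous iteration.
        if (if_first_batch)
            std::printf("batch %d: load own Q\n", batch_id);

        for (int n_block = n_blocks[batch_id] - 1; n_block >= 0; --n_block) {
            // ... attention mainloop on KV block n_block ...
            if (n_block == 0 && !if_last_batch)
                std::printf("batch %d: prefetch Q of batch %d behind its last KV block\n",
                            batch_id, batch_id + 1);
        }
    }
    return 0;
}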
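For the prefetched Q to survive, it must not be clobbered by the epilogue, which reuses the shared-memory union for `smem_max`/`smem_sum`/`smem_o`. That is what the new `smem_q_2` member does: it mirrors the Q tile at the front of the epilogue view of the union, pushing the epilogue buffers past the Q region instead of letting them alias it. A minimal sketch of the layout trick, with stand-in byte sizes instead of the real `cute` layouts from `Kernel_traits`:

#include <cstddef>
#include <cstdio>

constexpr size_t kQBytes   = 64 * 576 * 2;  // hypothetical Q tile size
constexpr size_t kRowBytes = 64 * 4;        // hypothetical per-row statistics
constexpr size_t kOBytes   = 64 * 512 * 4;  // hypothetical O tile size

struct SharedStorageSketch {
    union {  // anonymous structs, as in the kernel, are a common compiler extension
        struct {                       // mainloop view
            char smem_q[kQBytes];
            char smem_scale[kRowBytes];
        };
        struct {                       // epilogue view
            char smem_q_2[kQBytes];    // aliases smem_q byte-for-byte, so the
            char smem_max[kRowBytes];  // buffers below start *after* the Q
            char smem_sum[kRowBytes];  // region and a prefetched Q tile
            char smem_o[kOBytes];      // survives the O write-out
        };
    };
};

int main() {
    // The invariant the patch relies on: the padding member and the Q tile
    // occupy identical bytes, so only the epilogue buffers move.
    static_assert(offsetof(SharedStorageSketch, smem_q) ==
                  offsetof(SharedStorageSketch, smem_q_2), "Q regions must alias");
    std::printf("epilogue buffers start at byte %zu (Q region is %zu bytes)\n",
                offsetof(SharedStorageSketch, smem_max), kQBytes);
    return 0;
}

One consequence worth noting: reserving the Q region in both union views enlarges the epilogue view by the Q tile size, so the kernel's shared-memory footprint grows only if that view then exceeds the mainloop view.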