From 33e110bb66ddda2185f2efaa71f3037b615120a1 Mon Sep 17 00:00:00 2001 From: Gareth Jones Date: Sun, 23 Feb 2025 20:08:19 -0800 Subject: [PATCH] implement the index --- csrc/flash_fwd_mla_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index b78247d..63f05f8 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -187,7 +187,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const // Stage accumulator fragment to shared memory using tiled copy cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_o = bidb * o_batch_stride + m_block * kBlockM * o_row_stride + bidh * o_head_stride; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;