implement the index

This commit is contained in:
Gareth Jones 2025-02-23 20:08:19 -08:00
parent 46bafd9e03
commit 33e110bb66

View File

@ -187,7 +187,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
// Stage accumulator fragment to shared memory using tiled copy
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_o = bidb * o_batch_stride + m_block * kBlockM * o_row_stride + bidh * o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;