Cache output stride parameters in registers to reduce global loads

This commit is contained in:
Gareth Jones 2025-02-23 18:44:25 -08:00
parent 5fb94d668f
commit ccb208bcac

View File

@ -154,6 +154,11 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
const int tidx = threadIdx.x;
// Cache frequently used parameters into registers for optimization
const index_t o_batch_stride = __ldg(&params.o_batch_stride);
const index_t o_row_stride = __ldg(&params.o_row_stride);
const index_t o_head_stride = __ldg(&params.o_head_stride);
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);