From ccb208bcac49cf8fcc4ccefa2930271898abb69d Mon Sep 17 00:00:00 2001 From: Gareth Jones Date: Sun, 23 Feb 2025 18:44:25 -0800 Subject: [PATCH] Cache output stride parameters in registers to reduce global loads --- csrc/flash_fwd_mla_kernel.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index e3e46fd..9265a1a 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -154,6 +154,11 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const const int tidx = threadIdx.x; + // Cache frequently used parameters into registers for optimization + const index_t o_batch_stride = __ldg(&params.o_batch_stride); + const index_t o_row_stride = __ldg(&params.o_row_stride); + const index_t o_head_stride = __ldg(&params.o_head_stride); + typename Kernel_traits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);