mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Cache output stride parameters in registers to reduce global loads
This commit is contained in:
parent
5fb94d668f
commit
ccb208bcac
@@ -154,6 +154,11 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// Cache frequently used parameters into registers for optimization
|
||||
const index_t o_batch_stride = __ldg(&params.o_batch_stride);
|
||||
const index_t o_row_stride = __ldg(&params.o_row_stride);
|
||||
const index_t o_head_stride = __ldg(&params.o_head_stride);
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user