Cache output stride parameters in registers to reduce global loads

This commit is contained in:
Gareth Jones 2025-02-23 18:45:40 -08:00
parent ccb208bcac
commit 46bafd9e03

View File

@ -28,7 +28,7 @@ constexpr auto getSmemLayoutK() {
} }
} }
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::bfloat16_t, int kHeadDimV_ = 0> template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla { struct Flash_fwd_kernel_traits_mla {
using Element = elem_type; using Element = elem_type;
using ElementAccum = float; using ElementAccum = float;