mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Cache output stride parameters in registers to reduce global loads
This commit is contained in:
parent
ccb208bcac
commit
46bafd9e03
@ -28,7 +28,7 @@ constexpr auto getSmemLayoutK() {
|
||||
}
|
||||
}
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
|
Loading…
Reference in New Issue
Block a user