mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Cache output stride parameters in registers to reduce global loads
This commit is contained in:
parent
ccb208bcac
commit
46bafd9e03
@ -28,7 +28,7 @@ constexpr auto getSmemLayoutK() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||||
struct Flash_fwd_kernel_traits_mla {
|
struct Flash_fwd_kernel_traits_mla {
|
||||||
using Element = elem_type;
|
using Element = elem_type;
|
||||||
using ElementAccum = float;
|
using ElementAccum = float;
|
||||||
|
Loading…
Reference in New Issue
Block a user