diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 9265a1a..b78247d 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -28,7 +28,7 @@ constexpr auto getSmemLayoutK() { } } -template +template struct Flash_fwd_kernel_traits_mla { using Element = elem_type; using ElementAccum = float;