diff --git a/csrc/flash_fwd_mla_fp8_sm90.cu b/csrc/flash_fwd_mla_fp8_sm90.cu new file mode 100644 index 0000000..2384a30 --- /dev/null +++ b/csrc/flash_fwd_mla_fp8_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/setup.py b/setup.py index 0a3bd17..c622b7c 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ ext_modules.append( sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args,