chenhongmin.will 2025-03-01 14:44:25 +08:00
parent 6199b0b4b5
commit 7fafcd217d
2 changed files with 9 additions and 2 deletions

csrc/flash_api.cpp

@@ -218,9 +218,12 @@ mha_fwd_kvcache_mla(
         run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
     }
 #endif
+#ifndef FLASH_MLA_DISABLE_FP8
     else if (q_dtype == torch::kFloat8_e4m3fn) {
         run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
-    } else {
+    }
+#endif
+    else {
         TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }
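This guard mirrors the existing FLASH_MLA_DISABLE_FP16 block above it: when FLASH_MLA_DISABLE_FP8 is defined at build time, the fp8 dispatch branch is compiled out and an e4m3 query falls through to the TORCH_CHECK. A minimal Python sketch of the resulting dtype support, assuming bf16 is always compiled in (its .cu file is unconditional in setup.py below); the helper name is hypothetical, not part of this commit:

import torch

# Hypothetical mirror of the C++ dispatch chain in mha_fwd_kvcache_mla:
# bf16 is always available; fp16 and fp8 are gated by the DISABLE_* build flags.
def supported_query_dtypes(disable_fp16: bool, disable_fp8: bool) -> list:
    dtypes = [torch.bfloat16]
    if not disable_fp16:                    # #ifndef FLASH_MLA_DISABLE_FP16
        dtypes.append(torch.float16)
    if not disable_fp8:                     # #ifndef FLASH_MLA_DISABLE_FP8 (this commit)
        dtypes.append(torch.float8_e4m3fn)  # fp8 dtypes require torch >= 2.1
    return dtypes

# Any other query dtype now hits TORCH_CHECK(false, "Unsupported tensor dtype for query").
print(supported_query_dtypes(disable_fp16=False, disable_fp8=True))
# -> [torch.bfloat16, torch.float16]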

setup.py

@@ -12,6 +12,7 @@ from torch.utils.cpp_extension import (
 )
 
 DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
 
 
 def append_nvcc_threads(nvcc_extra_args):
@@ -23,12 +24,13 @@ def get_sources():
     sources = [
         "csrc/flash_api.cpp",
         "csrc/flash_fwd_mla_bf16_sm90.cu",
-        "csrc/flash_fwd_mla_fp8_sm90.cu",
         "csrc/flash_fwd_mla_metadata.cu",
     ]
     if not DISABLE_FP16:
         sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+    if not DISABLE_FP8:
+        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
     return sources
@@ -37,6 +39,8 @@ def get_features_args():
     features_args = []
     if DISABLE_FP16:
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    if DISABLE_FP8:
+        features_args.append("-DFLASH_MLA_DISABLE_FP8")
     return features_args
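For reference, a minimal usage sketch of the new flag. The install command is an assumption (use whatever build command this repo documents); the flag semantics follow directly from the setup.py change above:

import os
import subprocess

# Build with fp8 kernels disabled: setup.py drops
# csrc/flash_fwd_mla_fp8_sm90.cu from the source list and passes
# -DFLASH_MLA_DISABLE_FP8, so flash_api.cpp compiles the e4m3 branch out.
env = dict(os.environ, FLASH_MLA_DISABLE_FP8="TRUE")
subprocess.run(["python", "setup.py", "install"], env=env, check=True)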