diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index a865fc5..d6f8108 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -218,9 +218,12 @@ mha_fwd_kvcache_mla(
         run_mha_fwd_splitkv_mla(params, stream);
     }
 #endif
+#ifndef FLASH_MLA_DISABLE_FP8
     else if (q_dtype == torch::kFloat8_e4m3fn) {
         run_mha_fwd_splitkv_mla(params, stream);
-    } else {
+    }
+#endif
+    else {
         TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }
 
diff --git a/setup.py b/setup.py
index ef1a8a7..0b971c4 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,7 @@ from torch.utils.cpp_extension import (
 )
 
 DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
 
 
 def append_nvcc_threads(nvcc_extra_args):
@@ -23,12 +24,13 @@ def get_sources():
     sources = [
         "csrc/flash_api.cpp",
         "csrc/flash_fwd_mla_bf16_sm90.cu",
-        "csrc/flash_fwd_mla_fp8_sm90.cu",
         "csrc/flash_fwd_mla_metadata.cu",
     ]
 
     if not DISABLE_FP16:
         sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+    if not DISABLE_FP8:
+        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
     return sources
 
 
@@ -37,6 +39,8 @@ def get_features_args():
     features_args = []
     if DISABLE_FP16:
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    if DISABLE_FP8:
+        features_args.append("-DFLASH_MLA_DISABLE_FP8")
     return features_args
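
For reference, a minimal sketch (not part of the diff) of how the new switch behaves, mirroring the existing FLASH_MLA_DISABLE_FP16 handling in setup.py: only the literal string "TRUE" disables FP8; any other value, or leaving the variable unset, keeps the FP8 kernel in the build and omits the `-DFLASH_MLA_DISABLE_FP8` define. The helper name below is illustrative, not something in the repository.

```python
# Sketch only: reproduces the flag parsing and its two effects as written in
# setup.py above. `fp8_build_config` is a hypothetical helper for illustration.
def fp8_build_config(env):
    # Matches setup.py: only the exact string "TRUE" disables FP8.
    disable_fp8 = env.get("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
    sources = [
        "csrc/flash_api.cpp",
        "csrc/flash_fwd_mla_bf16_sm90.cu",
        "csrc/flash_fwd_mla_metadata.cu",
    ]
    features_args = []
    if not disable_fp8:
        # FP8 kernel is compiled only when the flag is absent or not "TRUE".
        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
    else:
        # The macro compiles out the kFloat8_e4m3fn branch in flash_api.cpp.
        features_args.append("-DFLASH_MLA_DISABLE_FP8")
    return sources, features_args

print(fp8_build_config({}))                                 # FP8 kernel included
print(fp8_build_config({"FLASH_MLA_DISABLE_FP8": "TRUE"}))  # FP8 compiled out
print(fp8_build_config({"FLASH_MLA_DISABLE_FP8": "true"}))  # still included (case-sensitive)
```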