diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index b7b11f1..d2567fe 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -77,7 +77,6 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -189,9 +188,15 @@ mha_fwd_kvcache_mla( if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, stream); - } else { + } + #ifndef FLASH_MLA_DISABLE_FP16 + else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, stream); } + #endif + else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/setup.py b/setup.py index 662a301..6377b1e 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,29 @@ from torch.utils.cpp_extension import ( IS_WINDOWS, ) +DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] +def get_sources(): + sources = [ + "csrc/flash_api.cpp", + "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_metadata.cu", + ] + + if not DISABLE_FP16: + sources.append("csrc/flash_fwd_mla_fp16_sm90.cu") + + return sources + +def get_features_args(): + features_args = [] + if DISABLE_FP16: + features_args.append("-DFLASH_MLA_DISABLE_FP16") + return features_args subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) @@ -34,14 +52,9 @@ ext_modules = [] ext_modules.append( CUDAExtension( name="flash_mla_cuda", - sources=[ - "csrc/flash_api.cpp", - "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp16_sm90.cu", - "csrc/flash_fwd_mla_metadata.cu", - ], + sources=get_sources(), extra_compile_args={ - "cxx": cxx_args, + "cxx": cxx_args + get_features_args(), "nvcc": append_nvcc_threads( [ "-O3", @@ -59,7 +72,7 @@ ext_modules.append( "--ptxas-options=-v,--register-usage-level=10" ] + cc_flag - ), + ) + get_features_args(), }, include_dirs=[ Path(this_dir) / "csrc",