mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00

add flag to disable FP16 compile

This commit is contained in:
parent 65fb7732fc
commit a3b74b8574
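Context for the change: after this commit the FP16 kernel is optional. Setting FLASH_MLA_DISABLE_FP16=TRUE at build time drops csrc/flash_fwd_mla_fp16_sm90.cu from the source list and defines FLASH_MLA_DISABLE_FP16 for compilation, which compiles the FP16 dispatch branch out of flash_api.cpp, presumably to save compile time and binary size when only bf16 is needed. A minimal build sketch, assuming a standard setup.py install (the command is illustrative; only the variable name comes from this diff):

# Sketch: building FlashMLA without the FP16 kernel. The variable must be
# set in the environment of the build process, since setup.py reads it at
# import time via os.getenv.
import os
import subprocess

env = dict(os.environ, FLASH_MLA_DISABLE_FP16="TRUE")
subprocess.run(["python", "setup.py", "install"], env=env, check=True)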
csrc/flash_api.cpp

@@ -77,7 +77,6 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16);
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

@@ -189,9 +188,15 @@ mha_fwd_kvcache_mla(
 
     if (q_dtype == torch::kBFloat16) {
         run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
-    } else {
+    }
+#ifndef FLASH_MLA_DISABLE_FP16
+    else if (q_dtype == torch::kHalf) {
         run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
     }
+#endif
+    else {
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
+    }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
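Two things change in the dispatch above: the up-front TORCH_CHECK on the query dtype is removed, and the bf16/fp16 branches are split so that the FP16 path can be compiled out under #ifndef FLASH_MLA_DISABLE_FP16; any dtype reaching the final else (including fp16 in an FP16-disabled build) now fails with an explicit error. A Python mirror of the resulting control flow, for illustration only (the real dispatch is the C++ above):

import torch

def dispatch_sketch(q_dtype: torch.dtype, fp16_compiled: bool) -> str:
    # Mirrors the #ifndef-guarded if/else chain in mha_fwd_kvcache_mla.
    if q_dtype == torch.bfloat16:
        return "run_mha_fwd_splitkv_mla<bfloat16_t, 576>"
    elif fp16_compiled and q_dtype == torch.float16:
        # This branch exists only when FLASH_MLA_DISABLE_FP16 is not defined.
        return "run_mha_fwd_splitkv_mla<half_t, 576>"
    else:
        raise TypeError("Unsupported tensor dtype for query")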
setup.py (29 lines changed)

@@ -11,11 +11,29 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )
 
+DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
 
 
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
 
+def get_sources():
+    sources = [
+        "csrc/flash_api.cpp",
+        "csrc/flash_fwd_mla_bf16_sm90.cu",
+        "csrc/flash_fwd_mla_metadata.cu",
+    ]
+
+    if not DISABLE_FP16:
+        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+
+    return sources
+
+
+def get_features_args():
+    features_args = []
+    if DISABLE_FP16:
+        features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    return features_args
+
+
 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
 

@@ -34,14 +52,9 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=[
-            "csrc/flash_api.cpp",
-            "csrc/flash_fwd_mla_bf16_sm90.cu",
-            "csrc/flash_fwd_mla_fp16_sm90.cu",
-            "csrc/flash_fwd_mla_metadata.cu",
-        ],
+        sources=get_sources(),
         extra_compile_args={
-            "cxx": cxx_args,
+            "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",

@@ -59,7 +72,7 @@ ext_modules.append(
                     "--ptxas-options=-v,--register-usage-level=10"
                 ]
                 + cc_flag
-            ),
+            ) + get_features_args(),
         },
         include_dirs=[
            Path(this_dir) / "csrc",
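Note that the flag has to act on both sides of the build: get_sources() keeps flash_fwd_mla_fp16_sm90.cu out of the compilation, and get_features_args() adds -DFLASH_MLA_DISABLE_FP16 to both the cxx and nvcc flags so every translation unit, including flash_api.cpp, agrees that the FP16 kernel does not exist; gating only one side would leave an undefined reference to run_mha_fwd_splitkv_mla<cutlass::half_t, 576> at link time. Also note that the check is an exact string comparison, a detail worth pinning down in a sketch:

import os

def fp16_disabled(environ: dict) -> bool:
    # Same comparison as setup.py: only the literal string "TRUE" disables
    # the FP16 build; "true", "1", or "yes" silently leave it enabled.
    return environ.get("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"

assert fp16_disabled({"FLASH_MLA_DISABLE_FP16": "TRUE"})
assert not fp16_disabled({"FLASH_MLA_DISABLE_FP16": "true"})
assert not fp16_disabled({})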