mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add env
This commit is contained in:
parent
6199b0b4b5
commit
7fafcd217d
@ -218,9 +218,12 @@ mha_fwd_kvcache_mla(
|
|||||||
run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
|
run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#ifndef FLASH_MLA_DISABLE_FP8
|
||||||
else if (q_dtype == torch::kFloat8_e4m3fn) {
|
else if (q_dtype == torch::kFloat8_e4m3fn) {
|
||||||
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
|
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
|
||||||
} else {
|
}
|
||||||
|
#endif
|
||||||
|
else {
|
||||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
6
setup.py
6
setup.py
@ -12,6 +12,7 @@ from torch.utils.cpp_extension import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
|
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
|
||||||
|
DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
|
||||||
|
|
||||||
|
|
||||||
def append_nvcc_threads(nvcc_extra_args):
|
def append_nvcc_threads(nvcc_extra_args):
|
||||||
@ -23,12 +24,13 @@ def get_sources():
|
|||||||
sources = [
|
sources = [
|
||||||
"csrc/flash_api.cpp",
|
"csrc/flash_api.cpp",
|
||||||
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
||||||
"csrc/flash_fwd_mla_fp8_sm90.cu",
|
|
||||||
"csrc/flash_fwd_mla_metadata.cu",
|
"csrc/flash_fwd_mla_metadata.cu",
|
||||||
]
|
]
|
||||||
|
|
||||||
if not DISABLE_FP16:
|
if not DISABLE_FP16:
|
||||||
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
|
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
|
||||||
|
if not DISABLE_FP8:
|
||||||
|
sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
@ -37,6 +39,8 @@ def get_features_args():
|
|||||||
features_args = []
|
features_args = []
|
||||||
if DISABLE_FP16:
|
if DISABLE_FP16:
|
||||||
features_args.append("-DFLASH_MLA_DISABLE_FP16")
|
features_args.append("-DFLASH_MLA_DISABLE_FP16")
|
||||||
|
if DISABLE_FP8:
|
||||||
|
features_args.append("-DFLASH_MLA_DISABLE_FP8")
|
||||||
return features_args
|
return features_args
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user