mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Performance optimization for compute-bound cases
This commit is contained in:
25
setup.py
25
setup.py
@@ -11,29 +11,13 @@ from torch.utils.cpp_extension import (
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
# When FLASH_MLA_DISABLE_FP16 is "TRUE" or "1", the FP16 kernel source is left
# out of the build. The accepted spellings match get_features_args() so the
# source list and the -DFLASH_MLA_DISABLE_FP16 define cannot disagree.
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ("TRUE", "1")
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
    """Return ``nvcc_extra_args`` with a ``--threads`` option appended.

    The thread count is taken from the ``NVCC_THREADS`` environment variable.
    When the variable is unset or empty, ``"32"`` is used. The value is passed
    through as a string and is not validated here.

    Args:
        nvcc_extra_args: List of extra nvcc arguments. It is not modified.

    Returns:
        A new list: the given arguments followed by ``"--threads"`` and the
        thread count.
    """
    # `or` (rather than a getenv default) also covers an empty-string value.
    nvcc_threads = os.getenv("NVCC_THREADS") or "32"
    return nvcc_extra_args + ["--threads", nvcc_threads]
|
||||
|
||||
|
||||
def get_sources():
    """Return the list of source files to compile for the extension.

    The API binding, the BF16 forward kernel and the metadata kernel are
    always included. The FP16 forward kernel is added unless the module-level
    ``DISABLE_FP16`` flag is true.

    Returns:
        A new list of source file paths (relative ``csrc/...`` strings).
    """
    sources = [
        "csrc/flash_api.cpp",
        "csrc/flash_fwd_mla_bf16_sm90.cu",
        "csrc/flash_fwd_mla_metadata.cu",
    ]

    if not DISABLE_FP16:
        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")

    return sources
|
||||
|
||||
|
||||
def get_features_args():
    """Return compiler flags for build features selected via the environment.

    Currently one feature is recognised: when ``FLASH_MLA_DISABLE_FP16`` is
    set to ``"TRUE"`` or ``"1"``, the define ``-DFLASH_MLA_DISABLE_FP16`` is
    emitted. Any other value, or an unset variable, yields no flags.

    Returns:
        A new list of compiler flag strings (possibly empty).
    """
    features_args = []
    # Lower-case local so it does not look like (or shadow) a module constant.
    disable_fp16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ("TRUE", "1")
    if disable_fp16:
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args
|
||||
@@ -56,7 +40,12 @@ ext_modules = []
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_mla_cuda",
|
||||
sources=get_sources(),
|
||||
sources=[
|
||||
"csrc/flash_api.cpp",
|
||||
"csrc/kernels/get_mla_metadata.cu",
|
||||
"csrc/kernels/mla_combine.cu",
|
||||
"csrc/kernels/splitkv_mla.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": cxx_args + get_features_args(),
|
||||
"nvcc": append_nvcc_threads(
|
||||
|
||||
Reference in New Issue
Block a user