Performance optimization for compute-bound cases

This commit is contained in:
Shengyu Liu
2025-04-21 17:22:59 +08:00
parent 063ffa8ec1
commit 287061ec34
20 changed files with 1799 additions and 1217 deletions

View File

@@ -11,29 +11,13 @@ from torch.utils.cpp_extension import (
IS_WINDOWS,
)
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return nvcc_extra_args + ["--threads", nvcc_threads]
def get_sources():
sources = [
"csrc/flash_api.cpp",
"csrc/flash_fwd_mla_bf16_sm90.cu",
"csrc/flash_fwd_mla_metadata.cu",
]
if not DISABLE_FP16:
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
return sources
def get_features_args():
features_args = []
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
if DISABLE_FP16:
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args
@@ -56,7 +40,12 @@ ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_cuda",
sources=get_sources(),
sources=[
"csrc/flash_api.cpp",
"csrc/kernels/get_mla_metadata.cu",
"csrc/kernels/mla_combine.cu",
"csrc/kernels/splitkv_mla.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": append_nvcc_threads(