diff --git a/README.md b/README.md
index 4027334..6d0bcb6 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ python setup.py install
 python tests/test_flash_mla.py
 ```
 
-Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6.
+Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
 
 ### Usage
 
@@ -42,6 +42,7 @@ for i in range(num_layers):
 
 - Hopper GPUs
 - CUDA 12.3 and above
+  - **But we highly recommend 12.8 or above for the best performance**
 - PyTorch 2.0 and above
 
 ## Acknowledgement
@@ -52,7 +53,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-
 
 ```bibtex
 @misc{flashmla2025,
-      title={FlashMLA: Efficient MLA decoding kernel},
+      title={FlashMLA: Efficient MLA decoding kernels},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},
diff --git a/setup.py b/setup.py
index 6377b1e..cd311f2 100644
--- a/setup.py
+++ b/setup.py
@@ -13,10 +13,12 @@ from torch.utils.cpp_extension import (
 
 DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
 
+
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
+
 def get_sources():
     sources = [
         "csrc/flash_api.cpp",
@@ -29,12 +31,14 @@ def get_sources():
 
     return sources
 
+
 def get_features_args():
     features_args = []
     if DISABLE_FP16:
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
     return features_args
 
+
 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
 
 cc_flag = []
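
The setup.py hunks above are driven by two environment variables: `FLASH_MLA_DISABLE_FP16`, which makes `get_features_args()` add `-DFLASH_MLA_DISABLE_FP16` to the build, and `NVCC_THREADS`, which `append_nvcc_threads()` forwards to nvcc as `--threads` (defaulting to 32). A minimal sketch of setting both at install time follows; the chosen values are illustrative assumptions, not project defaults.

```python
# Illustrative sketch: export the two knobs read by setup.py (names taken from
# the diff above; the values here are arbitrary), then run the install step.
import os
import subprocess
import sys

env = os.environ.copy()
env["FLASH_MLA_DISABLE_FP16"] = "TRUE"  # get_features_args() adds -DFLASH_MLA_DISABLE_FP16
env["NVCC_THREADS"] = "16"              # append_nvcc_threads() passes `--threads 16` to nvcc

subprocess.run([sys.executable, "setup.py", "install"], env=env, check=True)
```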