diff --git a/.gitignore b/.gitignore index 9b500a0..4535280 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.png /.vscode compile_commands.json +.cache diff --git a/setup.py b/setup.py index 131ceff..217f540 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,12 @@ from torch.utils.cpp_extension import ( IS_WINDOWS, ) + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_features_args(): features_args = [] DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]