mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-04-08 06:24:00 +00:00
commit
bcb90f2afd
9
setup.py
9
setup.py
@@ -8,6 +8,7 @@ from setuptools import setup, find_packages
|
||||
from torch.utils.cpp_extension import (
|
||||
BuildExtension,
|
||||
CUDAExtension,
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +25,11 @@ cc_flag.append("arch=compute_90a,code=sm_90a")
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if IS_WINDOWS:
|
||||
cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
|
||||
else:
|
||||
cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
|
||||
|
||||
ext_modules = []
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
@@ -33,12 +39,13 @@ ext_modules.append(
|
||||
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
|
||||
"cxx": cxx_args,
|
||||
"nvcc": append_nvcc_threads(
|
||||
[
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-DNDEBUG",
|
||||
"-D_USE_MATH_DEFINES",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
|
Loading…
Reference in New Issue
Block a user