diff --git a/setup.py b/setup.py
index bab4792..0a3bd17 100644
--- a/setup.py
+++ b/setup.py
@@ -8,6 +8,7 @@ from setuptools import setup, find_packages
 from torch.utils.cpp_extension import (
     BuildExtension,
     CUDAExtension,
+    IS_WINDOWS,
 )
 
 
@@ -24,6 +25,11 @@ cc_flag.append("arch=compute_90a,code=sm_90a")
 
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+if IS_WINDOWS:
+    cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
+else:
+    cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
+
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
@@ -33,12 +39,13 @@ ext_modules.append(
             "csrc/flash_fwd_mla_bf16_sm90.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
+            "cxx": cxx_args,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
                     "-std=c++17",
                     "-DNDEBUG",
+                    "-D_USE_MATH_DEFINES",
                     "-Wno-deprecated-declarations",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",