From e62bdb4d3fc49d7324ff3dce41b180984bc67cbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=85=83?= Date: Mon, 24 Feb 2025 11:29:36 +0800 Subject: [PATCH] support Windows build --- setup.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bab4792..0a3bd17 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ from setuptools import setup, find_packages from torch.utils.cpp_extension import ( BuildExtension, CUDAExtension, + IS_WINDOWS, ) @@ -24,6 +25,11 @@ cc_flag.append("arch=compute_90a,code=sm_90a") this_dir = os.path.dirname(os.path.abspath(__file__)) +if IS_WINDOWS: + cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"] +else: + cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"] + ext_modules = [] ext_modules.append( CUDAExtension( @@ -33,12 +39,13 @@ ext_modules.append( "csrc/flash_fwd_mla_bf16_sm90.cu", ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], + "cxx": cxx_args, "nvcc": append_nvcc_threads( [ "-O3", "-std=c++17", "-DNDEBUG", + "-D_USE_MATH_DEFINES", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__",