diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 80910b4..559a2f6 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -1,7 +1,6 @@ import functools import hashlib import os -import platform import re import subprocess import time @@ -52,8 +51,7 @@ def get_nvcc_compiler() -> Tuple[str, str]: if os.getenv('DG_JIT_NVCC_COMPILER'): paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) - nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc' - paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin)) + paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) # Try to find the first available NVCC compiler least_version_required = '12.3' @@ -187,11 +185,7 @@ class NVCCCompiler(Compiler): @classmethod def flags(cls) -> List[str]: - if platform.system() != 'Windows': - cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] - else: - cxx_flags = ['/O2', '/std:c++20'] - + cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], '-gencode=arch=compute_90a,code=sm_90a', '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',