diff --git a/setup.py b/setup.py index f9df917..af8e4e6 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,8 @@ if __name__ == '__main__': print(f'NVSHMEM directory: {nvshmem_dir}') # TODO: currently, we only support Hopper architecture, we may add Ampere support later - os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0' + if os.getenv('TORCH_CUDA_ARCH_LIST', None) is None: + os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0' cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes'] nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']