diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index d267147..f818fd5 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -26,22 +26,22 @@ def hash_to_hex(s: str) -> str: @functools.lru_cache(maxsize=None) def get_jit_include_dir() -> str: - return f'{os.path.dirname(os.path.abspath(__file__))}/../include' + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') @functools.lru_cache(maxsize=None) def get_deep_gemm_version() -> str: # Update include directories - include_dir = f'{get_jit_include_dir()}/deep_gemm' + include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') assert os.path.exists( include_dir), f'Cannot find GEMM include directory {include_dir}' md5 = hashlib.md5() for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): - with open(f'{include_dir}/{filename}', 'rb') as f: + with open(os.path.join(include_dir, filename), 'rb') as f: md5.update(f.read()) # Update `interleave_ffma.py` - with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f: + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: md5.update(f.read()) return md5.hexdigest()[0:12] @@ -51,7 +51,7 @@ def get_nvcc_compiler() -> Tuple[str, str]: paths = [] if os.getenv('DG_NVCC_COMPILER'): paths.append(os.getenv('DG_NVCC_COMPILER')) - paths.append(f'{CUDA_HOME}/bin/nvcc') + paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) # Try to find the first available NVCC compiler least_version_required = '12.3' @@ -73,17 +73,17 @@ def get_default_user_dir(): path = os.getenv('DG_CACHE_DIR') os.makedirs(path, exist_ok=True) return path - return os.path.expanduser('~') + '/.deep_gemm' + return os.path.join(os.path.expanduser('~'), '.deep_gemm') @functools.lru_cache(maxsize=None) def get_tmp_dir(): - return f'{get_default_user_dir()}/tmp' + return os.path.join(get_default_user_dir(), 'tmp') @functools.lru_cache(maxsize=None) def get_cache_dir(): - return f'{get_default_user_dir()}/cache' + return os.path.join(get_default_user_dir(), 'cache') def make_tmp_dir(): @@ -96,7 +96,7 @@ def put(path, data): is_binary = isinstance(data, bytes) # Write and do POSIX atomic replace - tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' + tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') with open(tmp_file_path, 'wb' if is_binary else 'w') as f: f.write(data) os.replace(tmp_file_path, path) @@ -137,7 +137,7 @@ class Compiler(abc.ABC): os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' name = f'kernel.{name}.{hash_to_hex(signature)}' - path = f'{get_cache_dir()}/{name}' + path = os.path.join(get_cache_dir(), name) # Check runtime cache or file system hit global runtime_cache @@ -148,14 +148,14 @@ class Compiler(abc.ABC): # Compile into a temporary CU file os.makedirs(path, exist_ok=True) - cubin_path = f'{path}/kernel.cubin' - tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' + cubin_path = os.path.join(path, 'kernel.cubin') + tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') start_time = time.time() kernel_name = cls.compile(name, code, tmp_cubin_path) end_time = time.time() elapsed_time = end_time - start_time - if os.getenv('DG_JIT_DEBUG', None) or True: + if os.getenv('DG_JIT_DEBUG', None): print( f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') @@ -194,8 +194,8 @@ class NvccCompiler(Compiler): @classmethod def compile(cls, name: str, code: str, target_path: str) -> str: # Write the code - path = f'{get_cache_dir()}/{name}' - src_path = f'{path}/kernel.cu' + path = os.path.join(get_cache_dir(), name) + src_path = os.path.join(path, 'kernel.cu') put(src_path, code) command = [get_nvcc_compiler()[0], src_path, '-o', target_path, diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 0bbc0ca..1e33a32 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -10,7 +10,9 @@ from .utils import run_gemm def get_symbol(file_path: str, pattern: str) -> Optional[str]: - command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path] + if CUDA_HOME is None: + raise Exception("CUDA_HOME is not set") + command = [os.path.join(CUDA_HOME, 'bin', 'cuobjdump'), '-symbols', file_path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) assert result.returncode == 0