import functools
import hashlib
import os
import re
import subprocess
import time
import uuid
from typing import List, Tuple, Type

import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME

from . import interleave_ffma
from .runtime import Runtime, RuntimeCache

# Process-wide cache of already-built runtimes, keyed by their cache directory path
runtime_cache = RuntimeCache()


def hash_to_hex(s: str) -> str:
    """Return the first 12 hex characters of the MD5 digest of `s` (UTF-8 encoded)."""
    md5 = hashlib.md5()
    md5.update(s.encode('utf-8'))
    return md5.hexdigest()[0:12]


def _version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a dotted version string such as `'12.3'` into `(12, 3)` for numeric comparison."""
    return tuple(int(part) for part in version.split('.'))


@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
    """Return the path of the `include` directory located next to this module's parent directory."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')


@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
    """
    Return a 12-character hex digest identifying the current JIT sources.

    The digest covers the contents of every `.cuh` file directly inside
    `<include>/deep_gemm` (in sorted filename order) plus `interleave_ffma.py`.
    """
    md5 = hashlib.md5()

    # Update include directories
    include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
    assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
    for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
        with open(os.path.join(include_dir, filename), 'rb') as f:
            md5.update(f.read())

    # Update `interleave_ffma.py`
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
        md5.update(f.read())
    return md5.hexdigest()[0:12]


@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
    """
    Locate an NVCC executable and return `(path, version)`.

    Candidates are tried in order: the path given by the `DG_JIT_NVCC_COMPILER`
    environment variable (if set and non-empty), then `<CUDA_HOME>/bin/nvcc`
    (if `CUDA_HOME` is known). The first candidate that exists is used.

    Raises:
        AssertionError: the version cannot be parsed from `nvcc --version`,
            or it is lower than the minimum required version.
        RuntimeError: no candidate path exists.
    """
    paths = []
    if os.getenv('DG_JIT_NVCC_COMPILER'):
        paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
    # `CUDA_HOME` may be `None`; skip the default location instead of crashing in `os.path.join`
    if CUDA_HOME is not None:
        paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))

    # Try to find the first available NVCC compiler
    least_version_required = '12.3'
    version_pattern = re.compile(r'release (\d+\.\d+)')
    for path in paths:
        if os.path.exists(path):
            command = [path, '--version']
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            match = version_pattern.search(result.stdout)
            # Validate the match before touching it, so a parse failure gives the intended message
            assert match, f'Cannot get the version of NVCC compiler {path}'
            version = match.group(1)
            # Compare numerically: as plain strings, '12.10' would sort below '12.3'
            assert _version_tuple(version) >= _version_tuple(least_version_required), \
                f'NVCC {path} version {version} is lower than {least_version_required}'
            return path, version
    raise RuntimeError('Cannot find any available NVCC compiler')


@functools.lru_cache(maxsize=None)
def get_default_user_dir():
    """
    Return the base directory for JIT artifacts.

    Uses `DG_JIT_CACHE_DIR` when that environment variable is present (creating
    the directory if needed); otherwise returns `~/.deep_gemm` without creating it.
    """
    if 'DG_JIT_CACHE_DIR' in os.environ:
        path = os.getenv('DG_JIT_CACHE_DIR')
        os.makedirs(path, exist_ok=True)
        return path
    return os.path.join(os.path.expanduser('~'), '.deep_gemm')


@functools.lru_cache(maxsize=None)
def get_tmp_dir():
    """Return the `tmp` subdirectory path of the base directory (not created here)."""
    return os.path.join(get_default_user_dir(), 'tmp')


@functools.lru_cache(maxsize=None)
def get_cache_dir():
    """Return the `cache` subdirectory path of the base directory (not created here)."""
    return os.path.join(get_default_user_dir(), 'cache')


def make_tmp_dir():
    """Create the temporary directory if it does not exist and return its path."""
    tmp_dir = get_tmp_dir()
    os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir


def put(path, data):
    """
    Write `data` to `path` via a uniquely named temporary file and `os.replace`.

    `data` is written in binary mode when it is `bytes`, in text mode otherwise.
    """
    # Write and do POSIX atomic replace
    tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
    with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
        f.write(data)
    os.replace(tmp_file_path, path)


class Compiler:
    """
    Base class for the JIT compiler backends.

    Subclasses provide `signature`, `__version__` and `compile`; `build` drives
    the caching, compilation and optional post-processing of the CUBIN.
    """

    @classmethod
    def signature(cls) -> str:
        """Return a string identifying the compiler; overridden by subclasses."""
        pass

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the compiler version as `(major, minor)`; overridden by subclasses."""
        pass

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """Compile `code` into a CUBIN at `target_path`; overridden by subclasses."""
        pass

    @staticmethod
    def flags() -> List[str]:
        """
        Return the compiler flags shared by all backends.

        The C++ standard defaults to 20 and can be overridden with
        `DG_JIT_OVERRIDE_CPP_STANDARD`; PTXAS verbose output is enabled when
        `DG_JIT_PTXAS_VERBOSE` is present in the environment.
        """
        cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20))
        return [f'-std=c++{cpp_standard}',
                '--ptxas-options=--register-usage-level=10' +
                (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
                # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                '--diag-suppress=39,161,174,177,940']

    @staticmethod
    def include_dirs() -> List[str]:
        """Return the include directories passed to the compiler."""
        return [get_jit_include_dir()]

    @classmethod
    def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
        """
        Build (or fetch from cache) the runtime for kernel `name` with source `code`.

        The cache key is derived from the kernel name, the source digest from
        `get_deep_gemm_version`, the compiler signature, the flags, whether FFMA
        interleaving is enabled, and the code itself.

        Returns:
            An instance of `runtime_cls` constructed from the kernel's cache directory.
        """
        # Compiler flags
        flags = cls.flags()

        # Build signature
        # FFMA interleaving applies only for compiler versions up to 12.8 and unless disabled by env
        enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0))
        signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
        name = f'kernel.{name}.{hash_to_hex(signature)}'
        path = os.path.join(get_cache_dir(), name)

        # Check runtime cache or file system hit
        global runtime_cache
        cached_runtime = runtime_cache.get(path, runtime_cls)
        if cached_runtime is not None:
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                print(f'Using cached JIT runtime {name} during build')
            return cached_runtime

        # Compile into a temporary CUBIN file
        os.makedirs(path, exist_ok=True)
        cubin_path = os.path.join(path, 'kernel.cubin')
        tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

        start_time = time.time()
        cls.compile(name, code, tmp_cubin_path)
        end_time = time.time()
        elapsed_time = end_time - start_time
        if int(os.getenv('DG_JIT_DEBUG', 0)):
            print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')

        # Interleave FFMA reuse
        if enable_sass_opt:
            interleave_ffma.process(tmp_cubin_path)

        # Atomic replace files
        os.replace(tmp_cubin_path, cubin_path)

        # Put cache and return
        runtime = runtime_cls(path)
        runtime_cache[path] = runtime
        return runtime


class NVCCCompiler(Compiler):
    """Compiler backend that invokes the NVCC executable as a subprocess."""

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the `(major, minor)` version of the located NVCC."""
        _, version = get_nvcc_compiler()
        major, minor = map(int, version.split('.'))
        return major, minor

    @classmethod
    def signature(cls) -> str:
        """Return `<nvcc path>+<version tuple>`."""
        return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'

    @classmethod
    def flags(cls) -> List[str]:
        """Return the shared flags plus include paths and the NVCC-specific options."""
        cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
        return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                '-gencode=arch=compute_90a,code=sm_90a',
                '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                f'--compiler-options={",".join(cxx_flags)}']

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """
        Write `code` to `<cache dir>/<name>/kernel.cu` and compile it to `target_path` with NVCC.

        Raises:
            AssertionError: NVCC exits with a non-zero return code.
        """
        # Write the code
        path = os.path.join(get_cache_dir(), name)
        src_path = os.path.join(path, 'kernel.cu')
        put(src_path, code)
        command = [get_nvcc_compiler()[0],
                   src_path, '-o', target_path,
                   *cls.flags()]
        if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
            print(f'Compiling JIT runtime {name} with command {command}')

        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode != 0:
            print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
            # Raise explicitly: `assert False` would be stripped under `python -O`
            raise AssertionError(f'Failed to compile {src_path}')


class NVRTCCompiler(Compiler):
    """Compiler backend that compiles in-process through the NVRTC bindings."""

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the NVRTC `(major, minor)` version, falling back to the `cuda.bindings` version."""
        res, major, minor = nvrtc.nvrtcVersion()
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            # Failed to get the actual NVRTC version, use cuda-bindings version instead
            major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
        return major, minor

    @classmethod
    def signature(cls) -> str:
        """Return `nvrtc+<version tuple>`."""
        return f'nvrtc+{cls.__version__()}'

    @staticmethod
    def include_dirs() -> List[str]:
        """
        Return the JIT include directory plus `<CUDA_HOME>/include`.

        Raises:
            RuntimeError: `CUDA_HOME` is `None`.
        """
        if CUDA_HOME is None:
            raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]

    @classmethod
    def flags(cls) -> List[str]:
        """Return the shared flags plus include paths and the NVRTC-specific options."""
        flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                 '--gpu-architecture=sm_90a', '-default-device']
        # NOTES: PCH is vital for compilation speed
        if cls.__version__() >= (12, 8):
            flags += ['--pch']
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                flags += ['--pch-verbose=true']
        return flags

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """
        Compile `code` with NVRTC and write the resulting CUBIN to `target_path`.

        The compiler log is printed when `DG_JIT_DEBUG` is enabled or when
        compilation fails.

        Raises:
            AssertionError: any NVRTC call does not return `NVRTC_SUCCESS`.
        """
        # Create program
        code_bytes = bytes(code, 'utf-8')
        result, program = nvrtc.nvrtcCreateProgram(
            code_bytes, bytes(name, 'utf-8'), 0, [], [])
        assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'

        # Compile
        options = [bytes(flag, 'utf-8') for flag in cls.flags()]
        if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
            print(f'Compiling JIT runtime {name} with options: {options}')
        compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]

        # Print compiler log
        if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
            assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'

            log_bytes = bytes(log_size)
            result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
            assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
            print(f'Compiler log: {log_bytes.decode("utf-8")}')

        # Exit if failed
        assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'

        # Create CUBIN
        result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
        assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
        cubin_bytes = bytes(cubin_size)
        result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
        assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'

        # Write into the file system
        put(target_path, cubin_bytes)

        # Destroy handler
        # Capture the destroy status so a failure reports it rather than a stale earlier result
        destroy_result = nvrtc.nvrtcDestroyProgram(program)[0]
        assert destroy_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {destroy_result}'


def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
    """
    Build the runtime for kernel `name` using the selected backend.

    NVRTC is used when `DG_JIT_USE_NVRTC` is set to a non-zero integer;
    otherwise NVCC is used.
    """
    compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
    return compiler_cls.build(name, code, runtime_cls=runtime_cls)