From 767793bf95c1c91b38509f107defebf42a824ddd Mon Sep 17 00:00:00 2001 From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Date: Tue, 22 Apr 2025 20:42:59 -0700 Subject: [PATCH] feat: compat for old drivers Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> --- deep_gemm/jit/compiler.py | 45 +++++++++++++++++++++++++++------------ deep_gemm/jit/runtime.py | 44 ++++++++++++++++++++++---------------- deep_gemm/jit/template.py | 6 ++++-- 3 files changed, 61 insertions(+), 34 deletions(-) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index f1e4bc0..ab67d4d 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -13,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc from torch.utils.cpp_extension import CUDA_HOME from . import interleave_ffma -from .runtime import Runtime, RuntimeCache +from .runtime import Runtime, RuntimeCache, get_symbol runtime_cache = RuntimeCache() @@ -108,7 +108,7 @@ class Compiler(abc.ABC): @classmethod @abc.abstractmethod - def compile(cls, name: str, src_path: str, target_path: str): + def compile(cls, name: str, code: str, target_path: str) -> str: pass @staticmethod @@ -118,7 +118,7 @@ class Compiler(abc.ABC): '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=39,174,177,940'] + '--diag-suppress=39,161,174,177,940'] @staticmethod def include_dirs() -> List[str]: @@ -144,17 +144,13 @@ class Compiler(abc.ABC): print(f'Using cached JIT runtime {name} during build') return runtime_cache[path] - # Write the code - os.makedirs(path, exist_ok=True) - src_path = f'{path}/kernel.cu' - put(src_path, code) - # Compile into a temporary CU file + os.makedirs(path, exist_ok=True) cubin_path = f'{path}/kernel.cubin' tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' start_time = time.time() - cls.compile(name, src_path, tmp_cubin_path) + kernel_name = cls.compile(name, code, tmp_cubin_path) end_time = time.time() elapsed_time = end_time - start_time if os.getenv('DG_JIT_DEBUG', None): @@ -169,7 +165,7 @@ class Compiler(abc.ABC): os.replace(tmp_cubin_path, cubin_path) # Put cache and return - runtime_cache[path] = Runtime(path) + runtime_cache[path] = Runtime(path, kernel_name) return runtime_cache[path] @@ -190,7 +186,11 @@ class NvccCompiler(Compiler): f'--compiler-options={",".join(cxx_flags)}'] @classmethod - def compile(cls, name: str, src_path: str, target_path: str): + def compile(cls, name: str, code: str, target_path: str) -> str: + # Write the code + path = f'{get_cache_dir()}/{name}' + src_path = f'{path}/kernel.cu' + put(src_path, code) command = [get_nvcc_compiler()[0], src_path, '-o', target_path, *cls.flags()] @@ -200,6 +200,8 @@ class NvccCompiler(Compiler): return_code = subprocess.check_call(command) assert return_code == 0, f'Failed to compile {src_path}' + return get_symbol(target_path, 'fp8_gemm_kernel') + class NvrtcCompiler(Compiler): @staticmethod @@ -218,19 +220,27 @@ class NvrtcCompiler(Compiler): base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], '--gpu-architecture=sm_90a', '-default-device'] if cls.__version__() >= (12, 8): - base_flags += ['--pch', f'--pch-dir={get_cache_dir()}'] + base_flags += ['--pch'] if os.getenv('DG_JIT_DEBUG', None): base_flags += ['--pch-verbose=true'] return base_flags @classmethod - def compile(cls, name: str, src_path: str, target_path: str): - code_bytes = open(src_path, 'rb').read() + def compile(cls, name: str, code: str, target_path: str) -> str: + code_bytes = bytes(code, 'utf-8') res, program = nvrtc.nvrtcCreateProgram( code_bytes, bytes(name, 'utf-8'), 0, [], []) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to create program: {res}") + kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE) + kernel_name = kernel_regex.search(code).group( + 0).replace('\n', '').replace(' ', '') + res = nvrtc.nvrtcAddNameExpression( + program, bytes(kernel_name, 'utf-8'))[0] + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f"Failed to add name expression: {res}") + options = [bytes(flag, 'utf-8') for flag in cls.flags()] compile_res = nvrtc.nvrtcCompileProgram( program, len(options), options)[0] @@ -249,6 +259,11 @@ class NvrtcCompiler(Compiler): if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to compile program: {compile_res}") + res, lowered_name = nvrtc.nvrtcGetLoweredName( + program, bytes(kernel_name, 'utf-8')) + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f"Failed to get lowered name: {res}") + res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to get CUBIN size: {res}") @@ -264,6 +279,8 @@ class NvrtcCompiler(Compiler): if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to destroy program: {res}") + return lowered_name.decode('utf-8') + def build(name: str, code: str) -> Runtime: if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 2a171b9..5c4e0ff 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,16 +1,31 @@ import os import time +import subprocess from typing import Any, Dict, Optional import cuda.bindings.driver as cuda +from torch.utils.cpp_extension import CUDA_HOME from .utils import run_gemm + +def get_symbol(file_path: str, pattern: str) -> Optional[str]: + command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path] + result = subprocess.run(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True) + assert result.returncode == 0 + for line in result.stdout.splitlines(): + if pattern in line: + return line.split()[-1] + return None + + class Runtime: - def __init__(self, path: str) -> None: + def __init__(self, path: str, kernel_name: str) -> None: self.path = path self.lib = None self.kernel = None + self.kernel_name = kernel_name assert self.is_path_valid(self.path) @@ -21,34 +36,25 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cu', 'kernel.cubin'] + files = ['kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: # Load CUBIN - if self.lib is None: + if self.kernel is None: start_time = time.time_ns() res, lib = cuda.cuLibraryLoadFromFile( bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) if res != cuda.CUresult.CUDA_SUCCESS: raise Exception(f"Failed to load library: {res}") - res, kernel_count = cuda.cuLibraryGetKernelCount(lib) + print(f"Kernel name: {self.kernel_name}") + res, kernel = cuda.cuLibraryGetKernel( + lib, bytes(self.kernel_name, encoding='utf-8')) if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to get kernel count: {res}") - - res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to enumerate kernels: {res}") - - for kernel in kernels: - res, kernel_name = cuda.cuKernelGetName(kernel) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to get kernel name: {res}") - if b"fp8" in kernel_name: - self.kernel = kernel - break + raise Exception(f"Failed to get kernel: {res}") + self.kernel = kernel if self.kernel is not None: self.lib = lib else: @@ -95,7 +101,9 @@ class RuntimeCache: # Already compiled if os.path.exists(path) and Runtime.is_path_valid(path): - runtime = Runtime(path) + kernel_name = get_symbol(os.path.join( + path, 'kernel.cubin'), 'fp8_gemm_kernel') + runtime = Runtime(path, kernel_name) self.cache[path] = runtime return runtime return None diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index 2a99322..6fd1322 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -22,7 +22,9 @@ def generate(**kwargs: Dict[str, Any]) -> str: #include #include -namespace deep_gemm {{ +using namespace deep_gemm; + +#ifndef NVRTC_JIT_COMPILATION __global__ void dummy_kernel() {{ void *ptr = (void *)&fp8_gemm_kernel< {kwargs['N']}, @@ -41,7 +43,7 @@ __global__ void dummy_kernel() {{ GemmType::{kwargs['GEMM_TYPE']} >; }} -}} +#endif ''' # Debug print