feat: compat for old drivers

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu 2025-04-22 20:42:59 -07:00
parent 78c7fa347e
commit 767793bf95
3 changed files with 61 additions and 34 deletions

View File

@ -13,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma from . import interleave_ffma
from .runtime import Runtime, RuntimeCache from .runtime import Runtime, RuntimeCache, get_symbol
runtime_cache = RuntimeCache() runtime_cache = RuntimeCache()
@ -108,7 +108,7 @@ class Compiler(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
pass pass
@staticmethod @staticmethod
@ -118,7 +118,7 @@ class Compiler(abc.ABC):
'--ptxas-options=--register-usage-level=10' + '--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940'] '--diag-suppress=39,161,174,177,940']
@staticmethod @staticmethod
def include_dirs() -> List[str]: def include_dirs() -> List[str]:
@ -144,17 +144,13 @@ class Compiler(abc.ABC):
print(f'Using cached JIT runtime {name} during build') print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path] return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary CU file # Compile into a temporary CU file
os.makedirs(path, exist_ok=True)
cubin_path = f'{path}/kernel.cubin' cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
start_time = time.time() start_time = time.time()
cls.compile(name, src_path, tmp_cubin_path) kernel_name = cls.compile(name, code, tmp_cubin_path)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None): if os.getenv('DG_JIT_DEBUG', None):
@ -169,7 +165,7 @@ class Compiler(abc.ABC):
os.replace(tmp_cubin_path, cubin_path) os.replace(tmp_cubin_path, cubin_path)
# Put cache and return # Put cache and return
runtime_cache[path] = Runtime(path) runtime_cache[path] = Runtime(path, kernel_name)
return runtime_cache[path] return runtime_cache[path]
@ -190,7 +186,11 @@ class NvccCompiler(Compiler):
f'--compiler-options={",".join(cxx_flags)}'] f'--compiler-options={",".join(cxx_flags)}']
@classmethod @classmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
# Write the code
path = f'{get_cache_dir()}/{name}'
src_path = f'{path}/kernel.cu'
put(src_path, code)
command = [get_nvcc_compiler()[0], command = [get_nvcc_compiler()[0],
src_path, '-o', target_path, src_path, '-o', target_path,
*cls.flags()] *cls.flags()]
@ -200,6 +200,8 @@ class NvccCompiler(Compiler):
return_code = subprocess.check_call(command) return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}' assert return_code == 0, f'Failed to compile {src_path}'
return get_symbol(target_path, 'fp8_gemm_kernel')
class NvrtcCompiler(Compiler): class NvrtcCompiler(Compiler):
@staticmethod @staticmethod
@ -218,19 +220,27 @@ class NvrtcCompiler(Compiler):
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device'] '--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8): if cls.__version__() >= (12, 8):
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}'] base_flags += ['--pch']
if os.getenv('DG_JIT_DEBUG', None): if os.getenv('DG_JIT_DEBUG', None):
base_flags += ['--pch-verbose=true'] base_flags += ['--pch-verbose=true']
return base_flags return base_flags
@classmethod @classmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
code_bytes = open(src_path, 'rb').read() code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram( res, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], []) code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to create program: {res}") raise Exception(f"Failed to create program: {res}")
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
kernel_name = kernel_regex.search(code).group(
0).replace('\n', '').replace(' ', '')
res = nvrtc.nvrtcAddNameExpression(
program, bytes(kernel_name, 'utf-8'))[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to add name expression: {res}")
options = [bytes(flag, 'utf-8') for flag in cls.flags()] options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram( compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0] program, len(options), options)[0]
@ -249,6 +259,11 @@ class NvrtcCompiler(Compiler):
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to compile program: {compile_res}") raise Exception(f"Failed to compile program: {compile_res}")
res, lowered_name = nvrtc.nvrtcGetLoweredName(
program, bytes(kernel_name, 'utf-8'))
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get lowered name: {res}")
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN size: {res}") raise Exception(f"Failed to get CUBIN size: {res}")
@ -264,6 +279,8 @@ class NvrtcCompiler(Compiler):
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to destroy program: {res}") raise Exception(f"Failed to destroy program: {res}")
return lowered_name.decode('utf-8')
def build(name: str, code: str) -> Runtime: def build(name: str, code: str) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:

View File

@ -1,16 +1,31 @@
import os import os
import time import time
import subprocess
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME
from .utils import run_gemm from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
    """Look up a symbol in a compiled binary via ``cuobjdump -symbols``.

    Runs ``{CUDA_HOME}/bin/cuobjdump -symbols <file_path>`` and scans its
    standard output line by line.

    Args:
        file_path: Path of the file handed to ``cuobjdump``.
        pattern: Plain substring (not a regex) to search for in each line.

    Returns:
        The last whitespace-separated token of the first output line that
        contains ``pattern`` (presumably the mangled symbol name -- confirm
        against the ``cuobjdump`` output format), or ``None`` when no line
        matches.

    Raises:
        AssertionError: If ``cuobjdump`` exits with a non-zero status.
    """
    command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path]
    result = subprocess.run(command, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, text=True)

    # Raised explicitly instead of using an `assert` statement so the check
    # is not stripped under `python -O`; the exception type is kept as
    # `AssertionError` so existing callers see the same failure mode, now
    # with the tool's stderr attached for diagnosis.
    if result.returncode != 0:
        raise AssertionError(
            f'`{" ".join(command)}` failed with exit code '
            f'{result.returncode}: {result.stderr.strip()}')

    for line in result.stdout.splitlines():
        if pattern not in line:
            continue
        fields = line.split()
        # A blank line can only "match" an empty/whitespace pattern; skip it
        # rather than crash with an IndexError on `fields[-1]`.
        if fields:
            return fields[-1]
    return None
class Runtime: class Runtime:
def __init__(self, path: str) -> None: def __init__(self, path: str, kernel_name: str) -> None:
self.path = path self.path = path
self.lib = None self.lib = None
self.kernel = None self.kernel = None
self.kernel_name = kernel_name
assert self.is_path_valid(self.path) assert self.is_path_valid(self.path)
@ -21,34 +36,25 @@ class Runtime:
return False return False
# Contains all necessary files # Contains all necessary files
files = ['kernel.cu', 'kernel.cubin'] files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files) return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN # Load CUBIN
if self.lib is None: if self.kernel is None:
start_time = time.time_ns() start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile( res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS: if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}") raise Exception(f"Failed to load library: {res}")
res, kernel_count = cuda.cuLibraryGetKernelCount(lib) print(f"Kernel name: {self.kernel_name}")
res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
if res != cuda.CUresult.CUDA_SUCCESS: if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}") raise Exception(f"Failed to get kernel: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
self.kernel = kernel
if self.kernel is not None: if self.kernel is not None:
self.lib = lib self.lib = lib
else: else:
@ -95,7 +101,9 @@ class RuntimeCache:
# Already compiled # Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path): if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path) kernel_name = get_symbol(os.path.join(
path, 'kernel.cubin'), 'fp8_gemm_kernel')
runtime = Runtime(path, kernel_name)
self.cache[path] = runtime self.cache[path] = runtime
return runtime return runtime
return None return None

View File

@ -22,7 +22,9 @@ def generate(**kwargs: Dict[str, Any]) -> str:
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh> #include <deep_gemm/fp8_gemm.cuh>
namespace deep_gemm {{ using namespace deep_gemm;
#ifndef NVRTC_JIT_COMPILATION
__global__ void dummy_kernel() {{ __global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel< void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']}, {kwargs['N']},
@ -41,7 +43,7 @@ __global__ void dummy_kernel() {{
GemmType::{kwargs['GEMM_TYPE']} GemmType::{kwargs['GEMM_TYPE']}
>; >;
}} }}
}} #endif
''' '''
# Debug print # Debug print