Compatible with CUDA 12.3

This commit is contained in:
Chenggang Zhao
2025-05-07 11:15:19 +08:00
parent 5373da7b28
commit ba349d9cf8
3 changed files with 30 additions and 47 deletions

View File

@@ -1,18 +1,17 @@
import os
import subprocess
import time
import cuda.bindings.driver as cbd
from typing import List, Optional, Type
import cuda.bindings.driver as cuda
from typing import List, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME
class Runtime:
def __init__(self, path: str, kernel_name: str = None,
args: List[str] = None) -> None:
def __init__(self, path: str, args: List[str] = None) -> None:
self.path = path
self.lib = None
self.kernel = None
self.kernel_name = kernel_name
self.args = args
assert self.is_path_valid(self.path)
@@ -34,51 +33,39 @@ class Runtime:
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
raise NotImplemented
def __call__(self, **kwargs) -> cuda.CUresult:
def __call__(self, **kwargs) -> cbd.CUresult:
# Load CUBIN
if self.kernel is None:
start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to load library: {res}')
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to get kernel count: {res}')
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to enumerate kernels: {res}')
# Load CUBIN
path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to get kernel name: {res}')
if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
self.kernel = kernel
break
# Extract the kernel name
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
if self.kernel is not None:
self.lib = lib
else:
raise Exception('Failed to find required kernel')
# Load kernel from the library
result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
elapsed_time = (end_time - start_time) / 1e6
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
# noinspection PyArgumentList
return self.launch(
self.kernel,
*[kwargs[arg] for arg in self.args]
)
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
def __del__(self) -> None:
if self.lib is not None:
res = cuda.cuLibraryUnload(self.lib)[0]
if res != cuda.CUresult.CUDA_SUCCESS:
res = cbd.cuLibraryUnload(self.lib)[0]
if res != cbd.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to unload library {self.path}: {res}')