feat: drop support for CUDA<12.3

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-25 18:56:40 -07:00
parent 46762b6903
commit f6198492cb
5 changed files with 46 additions and 85 deletions

View File

@@ -1,31 +1,12 @@
import os
import platform
import time
import subprocess
from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME
from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
if CUDA_HOME is None:
raise Exception("CUDA_HOME is not set")
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
'-symbols', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
for line in result.stdout.splitlines():
if pattern in line:
return line.split()[-1]
return None
class Runtime:
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
self.path = path
@@ -43,7 +24,7 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cubin', 'kernel.cubin.name']
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
@@ -53,18 +34,28 @@ class Runtime:
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
raise Exception(f'Failed to load library: {res}')
res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel: {res}")
raise Exception(f'Failed to get kernel count: {res}')
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to enumerate kernels: {res}')
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to get kernel name: {res}')
if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
self.kernel = kernel
break
self.kernel = kernel
if self.kernel is not None:
self.lib = lib
else:
raise Exception("Failed to find kernel")
raise Exception('Failed to find required kernel')
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
@@ -81,12 +72,12 @@ class Runtime:
if self.lib is not None:
res = cuda.cuLibraryUnload(self.lib)[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to unload library {self.path}: {res}")
raise Exception(f'Failed to unload library {self.path}: {res}')
class Fp8GemmRuntime(Runtime):
def __init__(self, path: str, kernel_name: str) -> None:
super().__init__(path, kernel_name, run_gemm, [
def __init__(self, path: str) -> None:
super().__init__(path, 'fp8_gemm', run_gemm, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
@@ -117,9 +108,7 @@ class RuntimeCache:
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
kernel_name = open(os.path.join(
path, 'kernel.cubin.name'), 'r').read()
runtime = runtime_cls(path, kernel_name)
runtime = runtime_cls(path)
self.cache[path] = runtime
return runtime
return None