feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-23 02:34:23 -07:00
parent 6c982791eb
commit 46762b6903
4 changed files with 156 additions and 90 deletions

View File

@@ -1,3 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
from .template import generate
from .runtime import Runtime

View File

@@ -7,14 +7,14 @@ import re
import subprocess
import time
import uuid
from typing import List, Tuple
from typing import List, Tuple, Type
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache, get_symbol
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
runtime_cache = RuntimeCache()
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
@classmethod
@abc.abstractmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
pass
@staticmethod
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str) -> Runtime:
def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
# Compiler flags
flags = cls.flags()
include_dirs = cls.include_dirs()
@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
cached_runtime = runtime_cache.get(path, runtime_cls)
if cached_runtime is not None:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
return cached_runtime
# Compile into a temporary CU file
os.makedirs(path, exist_ok=True)
@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
start_time = time.time()
kernel_name = cls.compile(name, code, tmp_cubin_path)
kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
# Put cache and return
runtime_cache[path] = Runtime(path, kernel_name)
return runtime_cache[path]
runtime = runtime_cls(path, kernel_name)
runtime_cache[path] = runtime
return runtime
class NvccCompiler(Compiler):
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
f'--compiler-options={",".join(cxx_flags)}']
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
assert result.returncode == 0, f'Failed to compile {src_path}'
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
return get_symbol(target_path, 'fp8_gemm_kernel')
return get_symbol(target_path, kernel_name_pattern)
class NvrtcCompiler(Compiler):
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
return base_flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to create program: {res}")
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
kernel_name = kernel_regex.search(code).group(
0).replace('\n', '').replace(' ', '')
res = nvrtc.nvrtcAddNameExpression(
@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):
def build(name: str, code: str) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
return NvrtcCompiler.build(name, code)
return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
else:
return NvccCompiler.build(name, code)
return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')

View File

@@ -2,7 +2,7 @@ import os
import platform
import time
import subprocess
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME
@@ -13,9 +13,10 @@ from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
if CUDA_HOME is None:
raise Exception("CUDA_HOME is not set")
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
'-symbols', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
class Runtime:
def __init__(self, path: str, kernel_name: str) -> None:
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
self.path = path
self.lib = None
self.kernel = None
self.kernel_name = kernel_name
self.caller = caller
self.args = args
assert self.is_path_valid(self.path)
@staticmethod
@@ -62,7 +64,7 @@ class Runtime:
if self.kernel is not None:
self.lib = lib
else:
raise Exception("Failed to find fp8 gemm kernel")
raise Exception("Failed to find kernel")
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
@@ -70,21 +72,9 @@ class Runtime:
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
return self.caller(
self.kernel,
kwargs['NUM_TMA_MULTICAST'],
kwargs['M'],
kwargs['BLOCK_M'],
kwargs['GMEM_D'],
kwargs['SCALES_B'],
kwargs['GROUPED_LAYOUT'],
kwargs['NUM_SMS'],
kwargs['SMEM_SIZE'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
kwargs['STREAM'],
*[kwargs[arg] for arg in self.args]
)
def __del__(self) -> None:
@@ -94,22 +84,42 @@ class Runtime:
raise Exception(f"Failed to unload library {self.path}: {res}")
class Fp8GemmRuntime(Runtime):
def __init__(self, path: str, kernel_name: str) -> None:
super().__init__(path, kernel_name, run_gemm, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
class RuntimeCache:
def __init__(self) -> None:
self.cache = {}
def __getitem__(self, path: str) -> Optional[Runtime]:
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
runtime = Runtime(path, kernel_name)
kernel_name = open(os.path.join(
path, 'kernel.cubin.name'), 'r').read()
runtime = runtime_cls(path, kernel_name)
self.cache[path] = runtime
return runtime
return None
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime
return None