mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: make API more general
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .compiler import get_nvcc_compiler, build
|
||||
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
|
||||
from .template import generate
|
||||
from .runtime import Runtime
|
||||
|
||||
@@ -7,14 +7,14 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache, get_symbol
|
||||
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str) -> Runtime:
|
||||
def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
include_dirs = cls.include_dirs()
|
||||
@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
||||
if cached_runtime is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
return cached_runtime
|
||||
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
|
||||
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path, kernel_name)
|
||||
return runtime_cache[path]
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
runtime_cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
|
||||
class NvccCompiler(Compiler):
|
||||
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
|
||||
assert result.returncode == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
return get_symbol(target_path, kernel_name_pattern)
|
||||
|
||||
|
||||
class NvrtcCompiler(Compiler):
|
||||
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to create program: {res}")
|
||||
|
||||
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
|
||||
kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
|
||||
kernel_name = kernel_regex.search(code).group(
|
||||
0).replace('\n', '').replace(' ', '')
|
||||
res = nvrtc.nvrtcAddNameExpression(
|
||||
@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
return NvrtcCompiler.build(name, code)
|
||||
return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
|
||||
else:
|
||||
return NvccCompiler.build(name, code)
|
||||
return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import platform
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
@@ -13,9 +13,10 @@ from .utils import run_gemm
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise Exception("CUDA_HOME is not set")
|
||||
|
||||
|
||||
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
|
||||
'-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
|
||||
self.caller = caller
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
@@ -62,7 +64,7 @@ class Runtime:
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
raise Exception("Failed to find fp8 gemm kernel")
|
||||
raise Exception("Failed to find kernel")
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
@@ -70,21 +72,9 @@ class Runtime:
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
return self.caller(
|
||||
self.kernel,
|
||||
kwargs['NUM_TMA_MULTICAST'],
|
||||
kwargs['M'],
|
||||
kwargs['BLOCK_M'],
|
||||
kwargs['GMEM_D'],
|
||||
kwargs['SCALES_B'],
|
||||
kwargs['GROUPED_LAYOUT'],
|
||||
kwargs['NUM_SMS'],
|
||||
kwargs['SMEM_SIZE'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
kwargs['STREAM'],
|
||||
*[kwargs[arg] for arg in self.args]
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
@@ -94,22 +84,42 @@ class Runtime:
|
||||
raise Exception(f"Failed to unload library {self.path}: {res}")
|
||||
|
||||
|
||||
class Fp8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
super().__init__(path, kernel_name, run_gemm, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = Runtime(path, kernel_name)
|
||||
kernel_name = open(os.path.join(
|
||||
path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
return None
|
||||
Reference in New Issue
Block a user