mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: make API more general
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -7,14 +7,14 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache, get_symbol
|
||||
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str) -> Runtime:
|
||||
def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
include_dirs = cls.include_dirs()
|
||||
@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
||||
if cached_runtime is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
return cached_runtime
|
||||
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
|
||||
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path, kernel_name)
|
||||
return runtime_cache[path]
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
runtime_cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
|
||||
class NvccCompiler(Compiler):
|
||||
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
|
||||
assert result.returncode == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
return get_symbol(target_path, kernel_name_pattern)
|
||||
|
||||
|
||||
class NvrtcCompiler(Compiler):
|
||||
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to create program: {res}")
|
||||
|
||||
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
|
||||
kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
|
||||
kernel_name = kernel_regex.search(code).group(
|
||||
0).replace('\n', '').replace(' ', '')
|
||||
res = nvrtc.nvrtcAddNameExpression(
|
||||
@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
return NvrtcCompiler.build(name, code)
|
||||
return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
|
||||
else:
|
||||
return NvccCompiler.build(name, code)
|
||||
return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
|
||||
|
||||
Reference in New Issue
Block a user