mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor JIT compilation (+NVRTC support) (#94)
* [wip] refactor: compile to .cubin Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * refactor: compile to .cubin and add NVRTC option Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: compiler version Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: compat for old drivers Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: save kernel name to file Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: fix win compat Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: windows compat Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> * feat: make API more general Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: drop support for CUDA<12.3 Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * doc: update README Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * Some lints and refactor * Refactor runtime * Several fixes * Refactor environment variables * Code format * Add a TODO * Compatible with CUDA 12.3 * Fix indent * Fix typing * Drop support for Windows * Add a TODO --------- Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
import hashlib
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -22,21 +25,22 @@ def hash_to_hex(s: str) -> str:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
|
||||
# Update include directories
|
||||
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
with open(os.path.join(include_dir, filename), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
# Update `interleave_ffma.py`
|
||||
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
|
||||
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
@@ -44,16 +48,20 @@ def get_deep_gemm_version() -> str:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
||||
if os.getenv('DG_JIT_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
|
||||
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
||||
command = [path, '--version']
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
match = version_pattern.search(result.stdout)
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
@@ -63,21 +71,21 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
if 'DG_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_CACHE_DIR')
|
||||
if 'DG_JIT_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_JIT_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser('~') + '/.deep_gemm'
|
||||
return os.path.join(os.path.expanduser('~'), '.deep_gemm')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return f'{get_default_user_dir()}/tmp'
|
||||
return os.path.join(get_default_user_dir(), 'tmp')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return f'{get_default_user_dir()}/cache'
|
||||
return os.path.join(get_default_user_dir(), 'cache')
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
@@ -86,67 +94,192 @@ def make_tmp_dir():
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
def put(path, data):
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
|
||||
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
class Compiler:
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
pass
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
pass
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f'{path}/kernel.args'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
||||
put(src_path, code)
|
||||
@staticmethod
|
||||
def flags() -> List[str]:
|
||||
cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20))
|
||||
return [f'-std=c++{cpp_standard}',
|
||||
'--ptxas-options=--register-usage-level=10' +
|
||||
(',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,161,174,177,940']
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f'{path}/kernel.so'
|
||||
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_so_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_so_path)
|
||||
# Build signature
|
||||
enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0))
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
||||
if cached_runtime is not None:
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return cached_runtime
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
cubin_path = os.path.join(path, 'kernel.cubin')
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Atomic replace files
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime = runtime_cls(path)
|
||||
runtime_cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
|
||||
class NVCCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
_, version = get_nvcc_compiler()
|
||||
major, minor = map(int, version.split('.'))
|
||||
return major, minor
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
return f'nvcc+{cls.__version__()}'
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
put(src_path, code)
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
|
||||
assert False, f'Failed to compile {src_path}'
|
||||
|
||||
|
||||
class NVRTCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
res, major, minor = nvrtc.nvrtcVersion()
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
# Failed to get the actual NVRTC version, use cuda-bindings version instead
|
||||
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
|
||||
return major, minor
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
return f'nvrtc+{cls.__version__()}'
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
|
||||
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
# NOTES: PCH is vital for compilation speed
|
||||
if cls.__version__() >= (12, 8):
|
||||
flags += ['--pch']
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
flags += ['--pch-verbose=true']
|
||||
return flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Create program
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
result, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
|
||||
|
||||
# Compile
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with options: {options}')
|
||||
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
|
||||
|
||||
# Print compiler log
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
|
||||
|
||||
log_bytes = bytes(log_size)
|
||||
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
|
||||
print(f'Compiler log: {log_bytes.decode("utf-8")}')
|
||||
|
||||
# Exit if failed
|
||||
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
|
||||
|
||||
# Create CUBIN
|
||||
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
|
||||
cubin_bytes = bytes(cubin_size)
|
||||
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
|
||||
|
||||
# Write into the file system
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
# Destroy handler
|
||||
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
||||
|
||||
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
||||
return compiler_cls.build(name, code, runtime_cls=runtime_cls)
|
||||
|
||||
Reference in New Issue
Block a user