mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
refactor: compile to .cubin and add NVRTC option
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
import hashlib
|
||||
import abc
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
@@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
assert os.path.exists(
|
||||
include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
@@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
||||
match = version_pattern.search(
|
||||
os.popen(f'{path} --version').read())
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
@@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
class Compiler(abc.ABC):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
pass
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
@staticmethod
|
||||
def flags() -> List[str]:
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
return [f'-std=c++{cpp_standard}',
|
||||
'--ptxas-options=--register-usage-level=10' +
|
||||
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
include_dirs = cls.include_dirs()
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
|
||||
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary CU file
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
|
||||
start_time = time.time()
|
||||
cls.compile(name, src_path, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Atomic replace CU file
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary CU file
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
class NvccCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
|
||||
return (major, minor)
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_cubin_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
start_time = time.time()
|
||||
return_code = subprocess.check_call(command)
|
||||
end_time = time.time()
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
cxx_flags = ['-fPIC', '-O3',
|
||||
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
# Print elapsed time if debug is enabled
|
||||
elapsed_time = end_time - start_time
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
@classmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Atomic replace CU file
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
class NvrtcCompiler(Compiler):
    """JIT backend that compiles kernel source to a CUBIN in-process via NVRTC.

    Unlike the NVCC backend, no external compiler process is spawned: the
    source is handed to the NVRTC library through the ``cuda.bindings.nvrtc``
    Python bindings and the resulting CUBIN is written to disk with ``put``.
    """

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the ``(major, minor)`` version of the NVRTC library.

        The version is queried from NVRTC itself rather than from the ``nvcc``
        binary, so this backend does not need NVCC to decide which NVRTC
        options (e.g. ``--pch``) may be passed.
        """
        res, major, minor = nvrtc.nvrtcVersion()
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            # Could not query NVRTC; fall back to the version of the Python
            # bindings. Only the first two components are parsed so suffixes
            # beyond `major.minor` cannot break the conversion.
            major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
        return (major, minor)

    @staticmethod
    def include_dirs() -> List[str]:
        """Return the include search paths passed to NVRTC.

        Raises:
            RuntimeError: if ``CUDA_HOME`` is not set, since the CUDA toolkit
                headers are located relative to it.
        """
        if CUDA_HOME is None:
            raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]

    @classmethod
    def flags(cls) -> List[str]:
        """Return the NVRTC option list: base flags, include paths and target arch.

        Precompiled-header options are appended only when the reported
        version is at least 12.8; PCH verbosity follows ``DG_JIT_DEBUG``.
        """
        base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                      '--gpu-architecture=sm_90a', '-default-device']
        if cls.__version__() >= (12, 8):
            base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
            if os.getenv('DG_JIT_DEBUG', None):
                base_flags += ['--pch-verbose=true']
        return base_flags

    @classmethod
    def compile(cls, name: str, src_path: str, target_path: str):
        """Compile ``src_path`` with NVRTC and write the CUBIN to ``target_path``.

        Args:
            name: program name handed to NVRTC.
            src_path: path of the CUDA source file to compile.
            target_path: path the resulting CUBIN is written to.

        Raises:
            Exception: if any NVRTC call reports a non-success result.
        """
        # Read through a context manager so the file handle is always closed
        with open(src_path, 'rb') as f:
            code_bytes = f.read()

        res, program = nvrtc.nvrtcCreateProgram(
            code_bytes, bytes(name, 'utf-8'), 0, [], [])
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to create program: {res}")

        # From here on the program handle exists; destroy it on every path,
        # including failures, so that it is not leaked.
        try:
            options = [bytes(flag, 'utf-8') for flag in cls.flags()]
            compile_res = nvrtc.nvrtcCompileProgram(
                program, len(options), options)[0]

            # Print the compiler log before checking the compile result, so
            # diagnostics are visible even when compilation failed.
            if os.getenv('DG_JIT_DEBUG', None):
                res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
                if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise Exception(f"Failed to get program log size: {res}")
                log_bytes = bytes(log_size)
                res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
                if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise Exception(f"Failed to get program log: {res}")
                log_str = log_bytes.decode('utf-8')
                print(log_str)

            if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f"Failed to compile program: {compile_res}")

            res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f"Failed to get CUBIN size: {res}")

            cubin_bytes = bytes(cubin_size)
            res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f"Failed to get CUBIN: {res}")

            put(target_path, cubin_bytes, is_binary=True)
        finally:
            destroy_res = nvrtc.nvrtcDestroyProgram(program)[0]

        # Only reached when no earlier error is propagating, so a destroy
        # failure never masks the original exception.
        if destroy_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to destroy program: {destroy_res}")
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
    """Build (or fetch from cache) the JIT runtime for ``code``.

    The backend is chosen by the ``DG_JIT_USE_NVRTC`` environment variable:
    the values ``1``, ``true`` and ``True`` select NVRTC, anything else
    (including an unset variable) selects NVCC.
    """
    use_nvrtc = os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']
    # Only the selected backend class is looked up and invoked
    compiler = NvrtcCompiler if use_nvrtc else NvccCompiler
    return compiler.build(name, code)
|
||||
|
||||
@@ -3,8 +3,6 @@ import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
import torch
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
@@ -58,8 +56,9 @@ class Runtime:
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
self.kernel,
|
||||
|
||||
Reference in New Issue
Block a user