mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 20:14:21 +00:00
feat: drop support for CUDA<12.3
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
parent
46762b6903
commit
f6198492cb
@ -14,7 +14,7 @@ import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
|
||||
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@ -115,7 +115,7 @@ class Compiler(abc.ABC):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -132,10 +132,9 @@ class Compiler(abc.ABC):
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
include_dirs = cls.include_dirs()
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
|
||||
@ -158,7 +157,7 @@ class Compiler(abc.ABC):
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
|
||||
cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
@ -169,15 +168,11 @@ class Compiler(abc.ABC):
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Store kernel name
|
||||
put(f'{tmp_cubin_path}.name', kernel_name)
|
||||
|
||||
# Atomic replace files
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
|
||||
|
||||
# Put cache and return
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
runtime = runtime_cls(path)
|
||||
runtime_cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
@ -202,7 +197,7 @@ class NvccCompiler(Compiler):
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str):
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
@ -221,9 +216,6 @@ class NvccCompiler(Compiler):
|
||||
|
||||
assert result.returncode == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, kernel_name_pattern)
|
||||
|
||||
|
||||
class NvrtcCompiler(Compiler):
|
||||
@staticmethod
|
||||
@ -238,7 +230,7 @@ class NvrtcCompiler(Compiler):
|
||||
def include_dirs() -> List[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
|
||||
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]
|
||||
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
@ -251,20 +243,12 @@ class NvrtcCompiler(Compiler):
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to create program: {res}")
|
||||
|
||||
kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
|
||||
kernel_name = kernel_regex.search(code).group(
|
||||
0).replace('\n', '').replace(' ', '')
|
||||
res = nvrtc.nvrtcAddNameExpression(
|
||||
program, bytes(kernel_name, 'utf-8'))[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to add name expression: {res}")
|
||||
raise Exception(f'Failed to create program: {res}')
|
||||
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
compile_res = nvrtc.nvrtcCompileProgram(
|
||||
@ -273,43 +257,35 @@ class NvrtcCompiler(Compiler):
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get program log size: {res}")
|
||||
raise Exception(f'Failed to get program log size: {res}')
|
||||
log_bytes = bytes(log_size)
|
||||
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get program log: {res}")
|
||||
raise Exception(f'Failed to get program log: {res}')
|
||||
log_str = log_bytes.decode('utf-8')
|
||||
print(log_str)
|
||||
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to compile program: {compile_res}")
|
||||
|
||||
# NVRTC can directly get the lowered name
|
||||
res, lowered_name = nvrtc.nvrtcGetLoweredName(
|
||||
program, bytes(kernel_name, 'utf-8'))
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get lowered name: {res}")
|
||||
raise Exception(f'Failed to compile program: {compile_res}')
|
||||
|
||||
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN size: {res}")
|
||||
raise Exception(f'Failed to get CUBIN size: {res}')
|
||||
|
||||
cubin_bytes = bytes(cubin_size)
|
||||
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN: {res}")
|
||||
raise Exception(f'Failed to get CUBIN: {res}')
|
||||
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to destroy program: {res}")
|
||||
|
||||
return lowered_name.decode('utf-8')
|
||||
raise Exception(f'Failed to destroy program: {res}')
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
|
||||
return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
|
||||
else:
|
||||
return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
|
||||
return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
|
||||
|
@ -1,31 +1,12 @@
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise Exception("CUDA_HOME is not set")
|
||||
|
||||
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
|
||||
'-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
for line in result.stdout.splitlines():
|
||||
if pattern in line:
|
||||
return line.split()[-1]
|
||||
return None
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
|
||||
self.path = path
|
||||
@ -43,7 +24,7 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cubin', 'kernel.cubin.name']
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
@ -53,18 +34,28 @@ class Runtime:
|
||||
res, lib = cuda.cuLibraryLoadFromFile(
|
||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
raise Exception(f'Failed to load library: {res}')
|
||||
|
||||
res, kernel = cuda.cuLibraryGetKernel(
|
||||
lib, bytes(self.kernel_name, encoding='utf-8'))
|
||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel: {res}")
|
||||
raise Exception(f'Failed to get kernel count: {res}')
|
||||
|
||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to enumerate kernels: {res}')
|
||||
|
||||
for kernel in kernels:
|
||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to get kernel name: {res}')
|
||||
if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
|
||||
self.kernel = kernel
|
||||
break
|
||||
|
||||
self.kernel = kernel
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
raise Exception("Failed to find kernel")
|
||||
raise Exception('Failed to find required kernel')
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
@ -81,12 +72,12 @@ class Runtime:
|
||||
if self.lib is not None:
|
||||
res = cuda.cuLibraryUnload(self.lib)[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to unload library {self.path}: {res}")
|
||||
raise Exception(f'Failed to unload library {self.path}: {res}')
|
||||
|
||||
|
||||
class Fp8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
super().__init__(path, kernel_name, run_gemm, [
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', run_gemm, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
@ -117,9 +108,7 @@ class RuntimeCache:
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
kernel_name = open(os.path.join(
|
||||
path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
runtime = runtime_cls(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
@ -24,7 +24,6 @@ def generate(**kwargs: Dict[str, Any]) -> str:
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
@ -43,7 +42,6 @@ __global__ void dummy_kernel() {{
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
#endif
|
||||
'''
|
||||
|
||||
# Debug print
|
||||
|
@ -53,7 +53,7 @@ def get_num_math_warpgroups(block_m: int) -> int:
|
||||
return 1 if block_m == 64 else 2
|
||||
|
||||
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
|
||||
assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
|
||||
assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuu
|
||||
)
|
||||
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to encode tensor map: {res}")
|
||||
raise Exception(f'Failed to encode tensor map: {res}')
|
||||
|
||||
return tensor_map
|
||||
|
||||
@ -118,7 +118,7 @@ def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_
|
||||
|
||||
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to set max dynamic shared memory size: {res}")
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
|
||||
attr_val = cuda.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
|
@ -62,17 +62,15 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
|
||||
}}
|
||||
}}
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&vector_add<{kwargs['T']}>;
|
||||
}}
|
||||
#endif
|
||||
"""
|
||||
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
super().__init__(path, kernel_name, run_vector_add, [
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', run_vector_add, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
@ -87,7 +85,7 @@ if __name__ == '__main__':
|
||||
code = generate_vector_add(T='float')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
|
||||
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
@ -105,7 +103,7 @@ if __name__ == '__main__':
|
||||
code = generate_vector_add(T='__nv_bfloat16')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
|
||||
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
|
Loading…
Reference in New Issue
Block a user