mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-06 19:54:24 +00:00
feat: compat for old drivers
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
parent
78c7fa347e
commit
767793bf95
@ -13,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
|
|||||||
from torch.utils.cpp_extension import CUDA_HOME
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
|
|
||||||
from . import interleave_ffma
|
from . import interleave_ffma
|
||||||
from .runtime import Runtime, RuntimeCache
|
from .runtime import Runtime, RuntimeCache, get_symbol
|
||||||
|
|
||||||
runtime_cache = RuntimeCache()
|
runtime_cache = RuntimeCache()
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ class Compiler(abc.ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(cls, name: str, src_path: str, target_path: str):
|
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -118,7 +118,7 @@ class Compiler(abc.ABC):
|
|||||||
'--ptxas-options=--register-usage-level=10' +
|
'--ptxas-options=--register-usage-level=10' +
|
||||||
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||||
'--diag-suppress=39,174,177,940']
|
'--diag-suppress=39,161,174,177,940']
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def include_dirs() -> List[str]:
|
def include_dirs() -> List[str]:
|
||||||
@ -144,17 +144,13 @@ class Compiler(abc.ABC):
|
|||||||
print(f'Using cached JIT runtime {name} during build')
|
print(f'Using cached JIT runtime {name} during build')
|
||||||
return runtime_cache[path]
|
return runtime_cache[path]
|
||||||
|
|
||||||
# Write the code
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
src_path = f'{path}/kernel.cu'
|
|
||||||
put(src_path, code)
|
|
||||||
|
|
||||||
# Compile into a temporary CU file
|
# Compile into a temporary CU file
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
cubin_path = f'{path}/kernel.cubin'
|
cubin_path = f'{path}/kernel.cubin'
|
||||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
cls.compile(name, src_path, tmp_cubin_path)
|
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
if os.getenv('DG_JIT_DEBUG', None):
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
@ -169,7 +165,7 @@ class Compiler(abc.ABC):
|
|||||||
os.replace(tmp_cubin_path, cubin_path)
|
os.replace(tmp_cubin_path, cubin_path)
|
||||||
|
|
||||||
# Put cache and return
|
# Put cache and return
|
||||||
runtime_cache[path] = Runtime(path)
|
runtime_cache[path] = Runtime(path, kernel_name)
|
||||||
return runtime_cache[path]
|
return runtime_cache[path]
|
||||||
|
|
||||||
|
|
||||||
@ -190,7 +186,11 @@ class NvccCompiler(Compiler):
|
|||||||
f'--compiler-options={",".join(cxx_flags)}']
|
f'--compiler-options={",".join(cxx_flags)}']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compile(cls, name: str, src_path: str, target_path: str):
|
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||||
|
# Write the code
|
||||||
|
path = f'{get_cache_dir()}/{name}'
|
||||||
|
src_path = f'{path}/kernel.cu'
|
||||||
|
put(src_path, code)
|
||||||
command = [get_nvcc_compiler()[0],
|
command = [get_nvcc_compiler()[0],
|
||||||
src_path, '-o', target_path,
|
src_path, '-o', target_path,
|
||||||
*cls.flags()]
|
*cls.flags()]
|
||||||
@ -200,6 +200,8 @@ class NvccCompiler(Compiler):
|
|||||||
return_code = subprocess.check_call(command)
|
return_code = subprocess.check_call(command)
|
||||||
assert return_code == 0, f'Failed to compile {src_path}'
|
assert return_code == 0, f'Failed to compile {src_path}'
|
||||||
|
|
||||||
|
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||||
|
|
||||||
|
|
||||||
class NvrtcCompiler(Compiler):
|
class NvrtcCompiler(Compiler):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -218,19 +220,27 @@ class NvrtcCompiler(Compiler):
|
|||||||
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||||
'--gpu-architecture=sm_90a', '-default-device']
|
'--gpu-architecture=sm_90a', '-default-device']
|
||||||
if cls.__version__() >= (12, 8):
|
if cls.__version__() >= (12, 8):
|
||||||
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
|
base_flags += ['--pch']
|
||||||
if os.getenv('DG_JIT_DEBUG', None):
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
base_flags += ['--pch-verbose=true']
|
base_flags += ['--pch-verbose=true']
|
||||||
return base_flags
|
return base_flags
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compile(cls, name: str, src_path: str, target_path: str):
|
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||||
code_bytes = open(src_path, 'rb').read()
|
code_bytes = bytes(code, 'utf-8')
|
||||||
res, program = nvrtc.nvrtcCreateProgram(
|
res, program = nvrtc.nvrtcCreateProgram(
|
||||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
raise Exception(f"Failed to create program: {res}")
|
raise Exception(f"Failed to create program: {res}")
|
||||||
|
|
||||||
|
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
|
||||||
|
kernel_name = kernel_regex.search(code).group(
|
||||||
|
0).replace('\n', '').replace(' ', '')
|
||||||
|
res = nvrtc.nvrtcAddNameExpression(
|
||||||
|
program, bytes(kernel_name, 'utf-8'))[0]
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to add name expression: {res}")
|
||||||
|
|
||||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||||
compile_res = nvrtc.nvrtcCompileProgram(
|
compile_res = nvrtc.nvrtcCompileProgram(
|
||||||
program, len(options), options)[0]
|
program, len(options), options)[0]
|
||||||
@ -249,6 +259,11 @@ class NvrtcCompiler(Compiler):
|
|||||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
raise Exception(f"Failed to compile program: {compile_res}")
|
raise Exception(f"Failed to compile program: {compile_res}")
|
||||||
|
|
||||||
|
res, lowered_name = nvrtc.nvrtcGetLoweredName(
|
||||||
|
program, bytes(kernel_name, 'utf-8'))
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to get lowered name: {res}")
|
||||||
|
|
||||||
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
raise Exception(f"Failed to get CUBIN size: {res}")
|
raise Exception(f"Failed to get CUBIN size: {res}")
|
||||||
@ -264,6 +279,8 @@ class NvrtcCompiler(Compiler):
|
|||||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
raise Exception(f"Failed to destroy program: {res}")
|
raise Exception(f"Failed to destroy program: {res}")
|
||||||
|
|
||||||
|
return lowered_name.decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
def build(name: str, code: str) -> Runtime:
|
def build(name: str, code: str) -> Runtime:
|
||||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||||
|
@ -1,16 +1,31 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import subprocess
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import cuda.bindings.driver as cuda
|
import cuda.bindings.driver as cuda
|
||||||
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
|
|
||||||
from .utils import run_gemm
|
from .utils import run_gemm
|
||||||
|
|
||||||
|
|
||||||
|
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||||
|
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path]
|
||||||
|
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE, text=True)
|
||||||
|
assert result.returncode == 0
|
||||||
|
for line in result.stdout.splitlines():
|
||||||
|
if pattern in line:
|
||||||
|
return line.split()[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(self, path: str) -> None:
|
def __init__(self, path: str, kernel_name: str) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.lib = None
|
self.lib = None
|
||||||
self.kernel = None
|
self.kernel = None
|
||||||
|
self.kernel_name = kernel_name
|
||||||
|
|
||||||
assert self.is_path_valid(self.path)
|
assert self.is_path_valid(self.path)
|
||||||
|
|
||||||
@ -21,34 +36,25 @@ class Runtime:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Contains all necessary files
|
# Contains all necessary files
|
||||||
files = ['kernel.cu', 'kernel.cubin']
|
files = ['kernel.cubin']
|
||||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||||
|
|
||||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||||
# Load CUBIN
|
# Load CUBIN
|
||||||
if self.lib is None:
|
if self.kernel is None:
|
||||||
start_time = time.time_ns()
|
start_time = time.time_ns()
|
||||||
res, lib = cuda.cuLibraryLoadFromFile(
|
res, lib = cuda.cuLibraryLoadFromFile(
|
||||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||||
raise Exception(f"Failed to load library: {res}")
|
raise Exception(f"Failed to load library: {res}")
|
||||||
|
|
||||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
print(f"Kernel name: {self.kernel_name}")
|
||||||
|
res, kernel = cuda.cuLibraryGetKernel(
|
||||||
|
lib, bytes(self.kernel_name, encoding='utf-8'))
|
||||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||||
raise Exception(f"Failed to get kernel count: {res}")
|
raise Exception(f"Failed to get kernel: {res}")
|
||||||
|
|
||||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
|
||||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
|
||||||
raise Exception(f"Failed to enumerate kernels: {res}")
|
|
||||||
|
|
||||||
for kernel in kernels:
|
|
||||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
|
||||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
|
||||||
raise Exception(f"Failed to get kernel name: {res}")
|
|
||||||
if b"fp8" in kernel_name:
|
|
||||||
self.kernel = kernel
|
self.kernel = kernel
|
||||||
break
|
|
||||||
|
|
||||||
if self.kernel is not None:
|
if self.kernel is not None:
|
||||||
self.lib = lib
|
self.lib = lib
|
||||||
else:
|
else:
|
||||||
@ -95,7 +101,9 @@ class RuntimeCache:
|
|||||||
|
|
||||||
# Already compiled
|
# Already compiled
|
||||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||||
runtime = Runtime(path)
|
kernel_name = get_symbol(os.path.join(
|
||||||
|
path, 'kernel.cubin'), 'fp8_gemm_kernel')
|
||||||
|
runtime = Runtime(path, kernel_name)
|
||||||
self.cache[path] = runtime
|
self.cache[path] = runtime
|
||||||
return runtime
|
return runtime
|
||||||
return None
|
return None
|
||||||
|
@ -22,7 +22,9 @@ def generate(**kwargs: Dict[str, Any]) -> str:
|
|||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include <deep_gemm/fp8_gemm.cuh>
|
#include <deep_gemm/fp8_gemm.cuh>
|
||||||
|
|
||||||
namespace deep_gemm {{
|
using namespace deep_gemm;
|
||||||
|
|
||||||
|
#ifndef NVRTC_JIT_COMPILATION
|
||||||
__global__ void dummy_kernel() {{
|
__global__ void dummy_kernel() {{
|
||||||
void *ptr = (void *)&fp8_gemm_kernel<
|
void *ptr = (void *)&fp8_gemm_kernel<
|
||||||
{kwargs['N']},
|
{kwargs['N']},
|
||||||
@ -41,7 +43,7 @@ __global__ void dummy_kernel() {{
|
|||||||
GemmType::{kwargs['GEMM_TYPE']}
|
GemmType::{kwargs['GEMM_TYPE']}
|
||||||
>;
|
>;
|
||||||
}}
|
}}
|
||||||
}}
|
#endif
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# Debug print
|
# Debug print
|
||||||
|
Loading…
Reference in New Issue
Block a user