mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: compat for old drivers
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .runtime import Runtime, RuntimeCache, get_symbol
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -108,7 +108,7 @@ class Compiler(abc.ABC):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -118,7 +118,7 @@ class Compiler(abc.ABC):
|
||||
'--ptxas-options=--register-usage-level=10' +
|
||||
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
'--diag-suppress=39,161,174,177,940']
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
@@ -144,17 +144,13 @@ class Compiler(abc.ABC):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
|
||||
start_time = time.time()
|
||||
cls.compile(name, src_path, tmp_cubin_path)
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
@@ -169,7 +165,7 @@ class Compiler(abc.ABC):
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
runtime_cache[path] = Runtime(path, kernel_name)
|
||||
return runtime_cache[path]
|
||||
|
||||
|
||||
@@ -190,7 +186,11 @@ class NvccCompiler(Compiler):
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
# Write the code
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
@@ -200,6 +200,8 @@ class NvccCompiler(Compiler):
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
|
||||
|
||||
class NvrtcCompiler(Compiler):
|
||||
@staticmethod
|
||||
@@ -218,19 +220,27 @@ class NvrtcCompiler(Compiler):
|
||||
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
if cls.__version__() >= (12, 8):
|
||||
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
|
||||
base_flags += ['--pch']
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
base_flags += ['--pch-verbose=true']
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
code_bytes = open(src_path, 'rb').read()
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to create program: {res}")
|
||||
|
||||
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
|
||||
kernel_name = kernel_regex.search(code).group(
|
||||
0).replace('\n', '').replace(' ', '')
|
||||
res = nvrtc.nvrtcAddNameExpression(
|
||||
program, bytes(kernel_name, 'utf-8'))[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to add name expression: {res}")
|
||||
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
compile_res = nvrtc.nvrtcCompileProgram(
|
||||
program, len(options), options)[0]
|
||||
@@ -249,6 +259,11 @@ class NvrtcCompiler(Compiler):
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to compile program: {compile_res}")
|
||||
|
||||
res, lowered_name = nvrtc.nvrtcGetLoweredName(
|
||||
program, bytes(kernel_name, 'utf-8'))
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get lowered name: {res}")
|
||||
|
||||
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN size: {res}")
|
||||
@@ -264,6 +279,8 @@ class NvrtcCompiler(Compiler):
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to destroy program: {res}")
|
||||
|
||||
return lowered_name.decode('utf-8')
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
|
||||
Reference in New Issue
Block a user