mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: save kernel name to file
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -92,7 +92,9 @@ def make_tmp_dir():
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
def put(path, data):
|
||||
is_binary = isinstance(data, bytes)
|
||||
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
@@ -153,16 +155,20 @@ class Compiler(abc.ABC):
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if os.getenv('DG_JIT_DEBUG', None) or True:
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Store kernel name
|
||||
put(f'{tmp_cubin_path}.name', kernel_name)
|
||||
|
||||
# Atomic replace CU file
|
||||
# Atomic replace files
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path, kernel_name)
|
||||
@@ -200,6 +206,7 @@ class NvccCompiler(Compiler):
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
|
||||
|
||||
@@ -259,6 +266,7 @@ class NvrtcCompiler(Compiler):
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to compile program: {compile_res}")
|
||||
|
||||
# NVRTC can directly get the lowered name
|
||||
res, lowered_name = nvrtc.nvrtcGetLoweredName(
|
||||
program, bytes(kernel_name, 'utf-8'))
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
@@ -273,7 +281,7 @@ class NvrtcCompiler(Compiler):
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN: {res}")
|
||||
|
||||
put(target_path, cubin_bytes, is_binary=True)
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
|
||||
@@ -36,7 +36,7 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cubin']
|
||||
files = ['kernel.cubin', 'kernel.cubin.name']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
@@ -48,7 +48,6 @@ class Runtime:
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
|
||||
print(f"Kernel name: {self.kernel_name}")
|
||||
res, kernel = cuda.cuLibraryGetKernel(
|
||||
lib, bytes(self.kernel_name, encoding='utf-8'))
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
@@ -101,8 +100,7 @@ class RuntimeCache:
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
kernel_name = get_symbol(os.path.join(
|
||||
path, 'kernel.cubin'), 'fp8_gemm_kernel')
|
||||
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = Runtime(path, kernel_name)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
Reference in New Issue
Block a user