Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)
feat: save kernel name to file
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -92,7 +92,9 @@ def make_tmp_dir():
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
def put(path, data):
|
||||
is_binary = isinstance(data, bytes)
|
||||
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
@@ -153,16 +155,20 @@ class Compiler(abc.ABC):
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if os.getenv('DG_JIT_DEBUG', None) or True:
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Store kernel name
|
||||
put(f'{tmp_cubin_path}.name', kernel_name)
|
||||
|
||||
# Atomic replace CU file
|
||||
# Atomic replace files
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path, kernel_name)
|
||||
@@ -200,6 +206,7 @@ class NvccCompiler(Compiler):
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
|
||||
|
||||
@@ -259,6 +266,7 @@ class NvrtcCompiler(Compiler):
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to compile program: {compile_res}")
|
||||
|
||||
# NVRTC can directly get the lowered name
|
||||
res, lowered_name = nvrtc.nvrtcGetLoweredName(
|
||||
program, bytes(kernel_name, 'utf-8'))
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
@@ -273,7 +281,7 @@ class NvrtcCompiler(Compiler):
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN: {res}")
|
||||
|
||||
put(target_path, cubin_bytes, is_binary=True)
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
|
||||
Reference in New Issue
Block a user