feat: save kernel name to file

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 22:23:47 -07:00
parent 767793bf95
commit a3210ac850
2 changed files with 14 additions and 8 deletions

View File

@@ -92,7 +92,9 @@ def make_tmp_dir():
return tmp_dir
def put(path, data, is_binary=False):
def put(path, data):
is_binary = isinstance(data, bytes)
# Write and do POSIX atomic replace
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
@@ -153,16 +155,20 @@ class Compiler(abc.ABC):
kernel_name = cls.compile(name, code, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
if os.getenv('DG_JIT_DEBUG', None) or True:
print(
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
# Store kernel name
put(f'{tmp_cubin_path}.name', kernel_name)
# Atomic replace CU file
# Atomic replace files
os.replace(tmp_cubin_path, cubin_path)
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
# Put cache and return
runtime_cache[path] = Runtime(path, kernel_name)
@@ -200,6 +206,7 @@ class NvccCompiler(Compiler):
return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}'
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
return get_symbol(target_path, 'fp8_gemm_kernel')
@@ -259,6 +266,7 @@ class NvrtcCompiler(Compiler):
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to compile program: {compile_res}")
# NVRTC can directly get the lowered name
res, lowered_name = nvrtc.nvrtcGetLoweredName(
program, bytes(kernel_name, 'utf-8'))
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
@@ -273,7 +281,7 @@ class NvrtcCompiler(Compiler):
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN: {res}")
put(target_path, cubin_bytes, is_binary=True)
put(target_path, cubin_bytes)
res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:

View File

@@ -36,7 +36,7 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cubin']
files = ['kernel.cubin', 'kernel.cubin.name']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
@@ -48,7 +48,6 @@ class Runtime:
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
print(f"Kernel name: {self.kernel_name}")
res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
if res != cuda.CUresult.CUDA_SUCCESS:
@@ -101,8 +100,7 @@ class RuntimeCache:
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
kernel_name = get_symbol(os.path.join(
path, 'kernel.cubin'), 'fp8_gemm_kernel')
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
runtime = Runtime(path, kernel_name)
self.cache[path] = runtime
return runtime