diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index ab67d4d..d267147 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -92,7 +92,9 @@ def make_tmp_dir(): return tmp_dir -def put(path, data, is_binary=False): +def put(path, data): + is_binary = isinstance(data, bytes) + # Write and do POSIX atomic replace tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' with open(tmp_file_path, 'wb' if is_binary else 'w') as f: @@ -153,16 +155,20 @@ class Compiler(abc.ABC): kernel_name = cls.compile(name, code, tmp_cubin_path) end_time = time.time() elapsed_time = end_time - start_time - if os.getenv('DG_JIT_DEBUG', None): + if os.getenv('DG_JIT_DEBUG', None) or True: print( f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') # Interleave FFMA reuse if enable_sass_opt: interleave_ffma.process(tmp_cubin_path) + + # Store kernel name + put(f'{tmp_cubin_path}.name', kernel_name) - # Atomic replace CU file + # Atomic replace files os.replace(tmp_cubin_path, cubin_path) + os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name') # Put cache and return runtime_cache[path] = Runtime(path, kernel_name) @@ -200,6 +206,7 @@ class NvccCompiler(Compiler): return_code = subprocess.check_call(command) assert return_code == 0, f'Failed to compile {src_path}' + # NVCC needs to get the symbol name from the cubin file using `cuobjdump` return get_symbol(target_path, 'fp8_gemm_kernel') @@ -259,6 +266,7 @@ class NvrtcCompiler(Compiler): if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to compile program: {compile_res}") + # NVRTC can directly get the lowered name res, lowered_name = nvrtc.nvrtcGetLoweredName( program, bytes(kernel_name, 'utf-8')) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: @@ -273,7 +281,7 @@ class NvrtcCompiler(Compiler): if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise Exception(f"Failed to get CUBIN: {res}") - put(target_path, cubin_bytes, is_binary=True) + put(target_path, cubin_bytes) res = nvrtc.nvrtcDestroyProgram(program)[0] if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 5c4e0ff..0bbc0ca 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -36,7 +36,7 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cubin'] + files = ['kernel.cubin', 'kernel.cubin.name'] return all(os.path.exists(os.path.join(path, file)) for file in files) def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: @@ -48,7 +48,6 @@ class Runtime: if res != cuda.CUresult.CUDA_SUCCESS: raise Exception(f"Failed to load library: {res}") - print(f"Kernel name: {self.kernel_name}") res, kernel = cuda.cuLibraryGetKernel( lib, bytes(self.kernel_name, encoding='utf-8')) if res != cuda.CUresult.CUDA_SUCCESS: @@ -101,8 +100,7 @@ class RuntimeCache: # Already compiled if os.path.exists(path) and Runtime.is_path_valid(path): - kernel_name = get_symbol(os.path.join( - path, 'kernel.cubin'), 'fp8_gemm_kernel') + kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read() runtime = Runtime(path, kernel_name) self.cache[path] = runtime return runtime