DeepGEMM/deep_gemm/jit/runtime.py

import os
import time
import subprocess
from typing import Any, Dict, Optional

import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME

from .utils import run_gemm


def get_symbol(file_path: str, pattern: str) -> Optional[str]:
    # Locate the (possibly mangled) kernel symbol in a CUBIN by scanning `cuobjdump -symbols` output
    if CUDA_HOME is None:
        raise Exception("CUDA_HOME is not set")

    command = [os.path.join(CUDA_HOME, 'bin', 'cuobjdump'), '-symbols', file_path]
    result = subprocess.run(command, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, text=True)
    assert result.returncode == 0
    for line in result.stdout.splitlines():
        if pattern in line:
            return line.split()[-1]
    return None
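

# Illustrative note (not part of the runtime logic): `cuobjdump -symbols` prints one
# symbol-table entry per line, ending with the symbol name, so taking the last
# whitespace-separated token of a matching line yields the kernel symbol. The exact
# column layout is an assumption here and varies across CUDA versions; a matching
# line looks roughly like:
#
#     STT_FUNC  STB_GLOBAL  STO_ENTRY  _Z14fp8_gemm_kernel...
#
# so a hypothetical get_symbol(path, 'fp8_gemm') call would return that final
# '_Z14fp8_gemm_kernel...' token.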


class Runtime:
    """Wraps a compiled kernel directory and lazily loads/launches its CUBIN."""

    def __init__(self, path: str, kernel_name: str) -> None:
        self.path = path
        self.lib = None
        self.kernel = None
        self.kernel_name = kernel_name
        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        # Exists and is a directory
        if not os.path.exists(path) or not os.path.isdir(path):
            return False

        # Contains all necessary files
        files = ['kernel.cubin', 'kernel.cubin.name']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
        # Load CUBIN
        if self.kernel is None:
            start_time = time.time_ns()
            res, lib = cuda.cuLibraryLoadFromFile(
                bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f"Failed to load library: {res}")

            res, kernel = cuda.cuLibraryGetKernel(
                lib, bytes(self.kernel_name, encoding='utf-8'))
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f"Failed to get kernel: {res}")

            self.kernel = kernel
            if self.kernel is not None:
                self.lib = lib
            else:
                raise Exception("Failed to find fp8 gemm kernel")

            end_time = time.time_ns()
            elapsed_time = (end_time - start_time) / 1000
            if os.getenv('DG_JIT_DEBUG', None):
                print(
                    f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

        return run_gemm(
            self.kernel,
            kwargs['NUM_TMA_MULTICAST'],
            kwargs['M'],
            kwargs['BLOCK_M'],
            kwargs['GMEM_D'],
            kwargs['SCALES_B'],
            kwargs['GROUPED_LAYOUT'],
            kwargs['NUM_SMS'],
            kwargs['SMEM_SIZE'],
            kwargs['TENSOR_MAP_A'],
            kwargs['TENSOR_MAP_B'],
            kwargs['TENSOR_MAP_SCALES_A'],
            kwargs['TENSOR_MAP_D'],
            kwargs['STREAM'],
        )

    def __del__(self) -> None:
        if self.lib is not None:
            res = cuda.cuLibraryUnload(self.lib)[0]
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f"Failed to unload library {self.path}: {res}")


class RuntimeCache:
    def __init__(self) -> None:
        self.cache = {}

    def __getitem__(self, path: str) -> Optional[Runtime]:
        # Already loaded in this Python process
        if path in self.cache:
            return self.cache[path]

        # Already compiled on disk
        if os.path.exists(path) and Runtime.is_path_valid(path):
            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
            runtime = Runtime(path, kernel_name)
            self.cache[path] = runtime
            return runtime
        return None

    def __setitem__(self, path: str, runtime: Runtime) -> None:
        self.cache[path] = runtime
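

# Usage sketch (illustrative only; the real call sites live in the JIT compiler,
# and `build_kernel` below is a hypothetical stand-in for whatever step writes
# kernel.cubin and kernel.cubin.name into `path`):
#
#     runtime_cache = RuntimeCache()
#     runtime = runtime_cache[path]            # None on a cache/disk miss
#     if runtime is None:
#         runtime = build_kernel(path, ...)    # hypothetical compile step
#         runtime_cache[path] = runtime
#     # Launch with the keyword arguments consumed by run_gemm, e.g.
#     # NUM_TMA_MULTICAST, M, BLOCK_M, GMEM_D, ..., STREAM.
#     runtime(M=m, BLOCK_M=block_m, NUM_SMS=num_sms, ..., STREAM=stream)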