mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Initial commit
This commit is contained in:
3
deep_gemm/jit/__init__.py
Normal file
3
deep_gemm/jit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .compiler import get_nvcc_compiler, build
|
||||
from .template import cpp_format, generate
|
||||
from .runtime import Runtime
|
||||
150
deep_gemm/jit/compiler.py
Normal file
150
deep_gemm/jit/compiler.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import hashlib
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
|
||||
def hash_to_hex(s: str) -> str:
|
||||
md5 = hashlib.md5()
|
||||
md5.update(s.encode('utf-8'))
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
# Update `interleave_ffma.py`
|
||||
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
|
||||
md5.update(f.read())
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
return path, version
|
||||
raise RuntimeError('Cannot find any available NVCC compiler')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
if 'DG_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser('~') + '/.deep_gemm'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return f'{get_default_user_dir()}/tmp'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return f'{get_default_user_dir()}/cache'
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
tmp_dir = get_tmp_dir()
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=177,174,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f'{path}/kernel.args'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f'{path}/kernel.so'
|
||||
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_so_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_so_path)
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
138
deep_gemm/jit/interleave_ffma.py
Normal file
138
deep_gemm/jit/interleave_ffma.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import mmap
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
return result.stdout
|
||||
|
||||
|
||||
def extract_ffma(sass):
|
||||
lines = sass.splitlines()
|
||||
collected = []
|
||||
current = []
|
||||
|
||||
arch_name, func_name = 'N/A', 'N/A'
|
||||
skip_next_line = False
|
||||
for line in lines:
|
||||
if 'code for' in line:
|
||||
arch_name = line.lstrip().lstrip('code for ').rstrip()
|
||||
elif 'Function :' in line:
|
||||
func_name = line.lstrip().lstrip('Function :').rstrip()
|
||||
elif 'FFMA' in line:
|
||||
current.append(line)
|
||||
skip_next_line = True
|
||||
elif skip_next_line:
|
||||
current.append(line)
|
||||
skip_next_line = False
|
||||
else:
|
||||
if len(current) >= 16:
|
||||
assert len(current) % 2 == 0
|
||||
collected.append((f'{arch_name}::{func_name}', current))
|
||||
current = []
|
||||
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
print(f"Found {len(collected)} FFMA segments")
|
||||
return collected
|
||||
|
||||
|
||||
def extract_hex_from_line(line):
|
||||
match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
|
||||
assert match
|
||||
return int(match.group(1), 16)
|
||||
|
||||
|
||||
def validate(m, offset, le_bytes, num_lines):
|
||||
assert len(le_bytes) == num_lines // 2
|
||||
assert m[offset:offset + 16] == le_bytes[0]
|
||||
for i in range(1, num_lines // 2):
|
||||
if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_registers(line):
|
||||
import re
|
||||
line = re.sub(r'/\*.*?\*/', '', line)
|
||||
line = line.replace(';', '')
|
||||
tokens = line.strip().split(',')
|
||||
registers = []
|
||||
for token in tokens:
|
||||
token = token.strip()
|
||||
words = token.split()
|
||||
for word in words:
|
||||
if word.startswith('R'):
|
||||
reg = word.split('.')[0]
|
||||
registers.append(reg)
|
||||
return registers
|
||||
|
||||
|
||||
def modify_segment(m, name, ffma_lines):
|
||||
num_lines = len(ffma_lines)
|
||||
assert num_lines % 2 == 0
|
||||
|
||||
le_bytes, new_le_bytes = [], []
|
||||
reused_list = []
|
||||
dst_reg_set = set()
|
||||
last_reused, last_dst_reg = False, ''
|
||||
num_changed = 0
|
||||
for i in range(num_lines // 2):
|
||||
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
|
||||
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
|
||||
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
|
||||
le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
||||
reused = (high_hex & 0x0800000000000000) != 0
|
||||
if reused:
|
||||
is_first_occurred = dst_reg not in dst_reg_set
|
||||
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
|
||||
# Modify the `reuse` and `yield` bits
|
||||
assert high_hex & 0x0800200000000000, f"{hex(high_hex)}"
|
||||
high_hex ^= 0x0800200000000000
|
||||
reused = False
|
||||
num_changed += 1
|
||||
else:
|
||||
reused_list.append(i)
|
||||
dst_reg_set.add(dst_reg)
|
||||
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
||||
last_reused, last_dst_reg = reused, dst_reg
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}")
|
||||
|
||||
# Find the offset
|
||||
offsets = []
|
||||
offset = m.find(le_bytes[0])
|
||||
while offset != -1:
|
||||
offsets.append(offset)
|
||||
offset = m.find(le_bytes[0], offset + 1)
|
||||
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
|
||||
|
||||
# Replace with `new_le_bytes`
|
||||
for offset in offsets:
|
||||
for i in range(num_lines // 2):
|
||||
m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
|
||||
|
||||
|
||||
def process(path):
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
print(f'Processing {path}')
|
||||
output = run_cuobjdump(path)
|
||||
segments = extract_ffma(output)
|
||||
with open(path, 'r+b') as f:
|
||||
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
|
||||
for segment in segments:
|
||||
modify_segment(mm, *segment)
|
||||
mm.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
|
||||
parser.add_argument('--so', help='Path to the SO file')
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args.so)
|
||||
66
deep_gemm/jit/runtime.py
Normal file
66
deep_gemm/jit/runtime.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
def is_path_valid(path: str) -> bool:
|
||||
# Exists and is a directory
|
||||
if not os.path.exists(path) or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
||||
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
||||
self.args = eval(f.read())
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
cargs.append(map_ctype(arg))
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = Runtime(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
93
deep_gemm/jit/template.py
Normal file
93
deep_gemm/jit/template.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
|
||||
from typing import Any, Iterable, Dict, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
torch.int: 'torch.int',
|
||||
torch.float: 'torch.float',
|
||||
torch.bfloat16: 'torch.bfloat16',
|
||||
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
}
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
torch.int: ('void*', 'int*'),
|
||||
torch.float: ('void*', 'float*'),
|
||||
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
||||
if isinstance(value, torch.Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, torch.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
new_template = new_template.replace(f'{{{key}}}', f'{value}')
|
||||
return new_template
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
# Common prefix
|
||||
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
||||
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
||||
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
||||
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
||||
|
||||
# Function signature
|
||||
raw = '__raw_'
|
||||
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
||||
code += f'extern "C" void launch('
|
||||
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
||||
code += ') {\n'
|
||||
|
||||
# Cast raw types
|
||||
code += ' // Cast raw types (if needed)\n'
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
||||
|
||||
# Function body
|
||||
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
||||
|
||||
# End the function
|
||||
code += '}\n\n'
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Generated code:\n{code}')
|
||||
|
||||
return code
|
||||
Reference in New Issue
Block a user