[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Zihua Wu 2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@ -504,63 +504,73 @@ public:
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
}
template <typename T>
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
return make_2d_tma_desc(global_address, Layout::ColMajor,
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
swizzle_mode);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor,
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,
uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
};
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_k, block_m, block_k);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
block_k, block_n);
}
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if constexpr (kSwizzleDMode == 32)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
// Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode` bytes,
// so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_n, block_m,
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
return make_2d_tma_desc(
global_address, Layout::ColMajor, shape_m,
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type =
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@ -1,6 +1,8 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cuda.h>
#endif
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>

View File

@ -0,0 +1,69 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef NVRTC_JIT_COMPILATION
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;
namespace std
{
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
__device__ constexpr operator value_type() const noexcept
{
return value;
}
__device__ constexpr value_type operator()() const noexcept
{
return value;
} // since c++14
};
using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;
template <class T, class U>
struct is_same : false_type
{
};
template <class T>
struct is_same<T, T> : true_type
{
};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
} // namespace std
#endif
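The shim above only matters when the generated source is compiled through NVRTC, where host headers such as <cuda.h> and the C++ standard library are unavailable; NVRTC defines __CUDACC_RTC__ itself, which makes the generated source turn on NVRTC_JIT_COMPILATION and include nvrtc_std.cuh. The compiler changes below still go through nvcc -cubin, so the following is only a hedged sketch of how the NVRTC route could look with cuda.bindings.nvrtc (the include directories, architecture flag and source string are placeholders):

import cuda.bindings.nvrtc as nvrtc

def nvrtc_compile_to_cubin(source: str, include_dirs: list) -> bytes:
    # NVRTC defines __CUDACC_RTC__ on its own, so no extra -D flag is strictly needed.
    err, prog = nvrtc.nvrtcCreateProgram(source.encode(), b'kernel.cu', 0, [], [])
    assert err == nvrtc.nvrtcResult.NVRTC_SUCCESS, err
    opts = [b'--std=c++17', b'--gpu-architecture=sm_90a'] + \
           [f'-I{d}'.encode() for d in include_dirs]
    err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)
    if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        # Surface the compile log on failure
        _, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
        log = b' ' * log_size
        nvrtc.nvrtcGetProgramLog(prog, log)
        raise RuntimeError(log.decode())
    err, cubin_size = nvrtc.nvrtcGetCUBINSize(prog)
    cubin = b' ' * cubin_size
    err, = nvrtc.nvrtcGetCUBIN(prog, cubin)
    return cubin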

View File

@ -1,6 +1,9 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cassert>
#endif
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>

View File

@ -1,5 +1,6 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <exception>
#ifdef __CLION_IDE__
@ -16,8 +17,12 @@ public:
const char *what() const noexcept override { return message.c_str(); }
};
#endif
#ifndef DG_HOST_ASSERT
#ifdef NVRTC_JIT_COMPILATION
#define DG_HOST_ASSERT(cond) ((void)0)
#else
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
@ -27,6 +32,7 @@ do { \
} \
} while (0)
#endif
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \

View File

@ -1,3 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .template import generate
from .runtime import Runtime

View File

@ -3,13 +3,13 @@ import functools
import os
import re
import subprocess
import time
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
@ -94,11 +94,11 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
# Compile into a temporary SO file
# Compile into a temporary CUBIN file
so_path = f'{path}/kernel.so'
cubin_path = f'{path}/kernel.cubin'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
interleave_ffma.process(tmp_cubin_path)
# Atomic replace SO file
# Atomic replace CUBIN file
os.replace(tmp_so_path, so_path)
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
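With the flag changes above, the effective compile step is roughly the sketch below: the same nvcc invocation as before, except that -cubin makes the output a raw kernel.cubin instead of a loadable kernel.so (the paths and the include directory are placeholders):

import subprocess

# Illustrative reconstruction of the new compile step; real paths come from `build()`.
command = [
    'nvcc', 'kernel.cu', '-o', 'kernel.cubin',
    '-std=c++20', '-O3', '-cubin',
    '--expt-relaxed-constexpr', '--expt-extended-lambda',
    '-gencode=arch=compute_90a,code=sm_90a',
    '--ptxas-options=--register-usage-level=10',
    '--diag-suppress=39,174,177,940',
    '-I/path/to/deep_gemm/include',  # placeholder include directory
]
subprocess.check_call(command)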

View File

@ -1,16 +1,18 @@
import ctypes
import os
import time
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from typing import Optional
from .template import map_ctype
from .utils import run_gemm
class Runtime:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.args = None
self.kernel = None
assert self.is_path_valid(self.path)
@ -21,29 +23,66 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.args', 'kernel.so']
files = ['kernel.cu', 'kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
self.args = eval(f.read())
# Check args and launch
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
cargs.append(map_ctype(arg))
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN
if self.lib is None:
start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
if self.kernel is not None:
self.lib = lib
else:
raise Exception("Failed to find fp8 gemm kernel")
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
self.kernel,
kwargs['NUM_TMA_MULTICAST'],
kwargs['M'],
kwargs['BLOCK_M'],
kwargs['GMEM_D'],
kwargs['SCALES_B'],
kwargs['GROUPED_LAYOUT'],
kwargs['NUM_SMS'],
kwargs['SMEM_SIZE'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
kwargs['STREAM'],
)
def __del__(self) -> None:
if self.lib is not None:
res = cuda.cuLibraryUnload(self.lib)[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to unload library {self.path}: {res}")
class RuntimeCache:

View File

@ -1,111 +1,48 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
from typing import Any, Dict
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <string>
#include <cuda.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
namespace deep_gemm {{
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
}}
'''
# Debug print
if os.getenv('DG_JIT_DEBUG', None):
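For illustration, a call like the one below (all key values are placeholders, not taken from this commit) returns a small translation unit whose only symbol is dummy_kernel; taking the address of fp8_gemm_kernel<...> inside it forces the desired specialization to be instantiated and compiled into the CUBIN:

code = generate(
    N=4096, K=7168, BLOCK_M=128, BLOCK_N=128, BLOCK_K=128,
    BLOCK_N_PADDING=0, SWIZZLE_D_MODE=128, NUM_GROUPS=1, NUM_STAGES=4,
    NUM_TMA_THREADS=128, NUM_MATH_THREADS_PER_GROUP=128,
    NUM_TMA_MULTICAST=1, IS_TMA_MULTICAST_ON_A=False, GEMM_TYPE='Normal',
)
print(code)  # the CUDA source that `build()` hands to nvcc -cubin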

deep_gemm/jit/utils.py (new file, 164 lines)
View File

@ -0,0 +1,164 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple
import cuda.bindings.driver as cuda
import torch
class Layout(Enum):
RowMajor = 0
ColMajor = 1
class GemmType(Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
def __str__(self) -> str:
return {
0: 'Normal',
1: 'GroupedContiguous',
2: 'GroupedMasked',
}[self.value]
typename_map: Dict[Any, str] = {
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
swizzle_map = {
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}
def get_num_math_warpgroups(block_m: int) -> int:
return 1 if block_m == 64 else 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
tensor_dtype = typename_map[global_address.dtype]
res, tensor_map = cuda.cuTensorMapEncodeTiled(
tensor_dtype,
2, # tensor rank
global_address.data_ptr(),
gmem_dim,
(stride_in_bytes,), # global strides
smem_dim,
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to encode tensor map: {res}")
return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Swizzling requires the inner box dim to be less than or equal to `swizzle_mode` bytes,
# so `block_n * element_size / swizzle_mode` TMA stores are required
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Make TMA aligned to 16 bytes
kAlignment = 16 // global_address.element_size()
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to set max dynamic shared memory size: {res}")
attr_val = cuda.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cuda.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
kernelValues = (
gmem_d.data_ptr(),
scales_b.data_ptr(),
grouped_layout.data_ptr(),
shape_m,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_d,
)
kernelTypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
None,
None,
None,
None,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)
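Putting the new helpers together, a plain (non-grouped) launch looks roughly like the sketch below; the shapes, block sizes, SM count and shared-memory size are placeholders (the real values come from get_best_configs/get_smem_config), and `kernel` stands for the cuda.CUkernel handle obtained from the loaded CUBIN:

import torch
from deep_gemm.jit.utils import (GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc,
                                 make_2d_tma_d_desc, make_2d_tma_scales_a_desc, run_gemm)

m, n, k = 4096, 4096, 7168                           # placeholder shapes
block_m, block_n, block_k = 128, 128, 128            # placeholder block sizes
num_sms, smem_size, swizzle_mode = 132, 199168, 128  # placeholder config

lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
lhs_scales = torch.randn(m, k // block_k, device='cuda', dtype=torch.float32)
rhs_scales = torch.randn(n // block_n, k // block_k, device='cuda', dtype=torch.float32)
out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, block_m, block_k)
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, k, n, block_k, block_n)
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, swizzle_mode, out, m, n, block_m, block_n)
tensor_map_scales_a = make_2d_tma_scales_a_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k)

run_gemm(kernel, 1, m, block_m, out, rhs_scales,            # `kernel`: loaded CUkernel handle
         torch.empty(0, dtype=torch.int32, device='cuda'),  # empty grouped layout for GemmType.Normal
         num_sms, smem_size,
         tensor_map_a, tensor_map_b, tensor_map_scales_a, tensor_map_d,
         torch.cuda.current_stream().cuda_stream)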

View File

@ -3,40 +3,11 @@ import torch
from functools import lru_cache
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = 1;
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
smem_d = block_m * (block_n + block_n_padding) * 2
smem_a_per_stage = block_m * block_k
@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += ceil_div(smem_scales_b * (1 if block_k %
block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
def fix_wave_saturate(x): return num_sms if x == 0 else x
def get_num_waves(bm, bn): return (ceil_div(
ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
def get_last_wave_util(bm, bn): return fix_wave_saturate(
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
# Decide block sizes by waves
best_block_m, best_block_n = None, None
@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# NOTES: the block sizes can not be too large, so at least one dim less than 128
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
success = False
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
num_waves, best_num_waves = get_num_waves(
block_m, block_n), get_num_waves(best_block_m, best_block_n)
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
success |= block_n == best_block_n and block_m < best_block_m
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
success |= block_m != best_block_m and block_n > best_block_n
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
best_block_m, best_block_n = (block_m, block_n) if success else (
best_block_m, best_block_n)
assert best_block_m is not None and best_block_n is not None
# Always pick the longest one
@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, 3)
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
best_smem_config = get_smem_config(
num_stages, k, best_block_m, best_block_n)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break
@ -158,8 +138,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
num_min_sms = ceil_div(ceil_div(m, best_block_m) *
ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(
num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
@ -210,11 +192,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
kwargs = {
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -223,14 +236,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
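From the caller's perspective the public API is unchanged; a typical invocation still looks like the hedged sketch below (the shapes are illustrative, and the LHS scales must be in the column-major, TMA-aligned layout the function already requires):

import torch
import deep_gemm

m, n, k = 4096, 4096, 7168   # illustrative shapes
lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
lhs_scales = torch.randn(m, k // 128, device='cuda', dtype=torch.float32)
rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
rhs_scales = torch.randn((n + 127) // 128, k // 128, device='cuda', dtype=torch.float32)
out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

# Reorder the LHS scales into the TMA-aligned layout expected by the kernel
lhs_scales = deep_gemm.get_col_major_tma_aligned_tensor(lhs_scales)

# The first call JIT-compiles the kernel to a CUBIN and caches it; later calls reuse the cache
deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)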

View File

@ -1,41 +1,12 @@
import torch
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = {NUM_GROUPS};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated grouped GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedContiguous'},
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedMasked'},
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
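The masked grouped path is exercised the same way; a hedged sketch (group count, shapes and the per-group row counts in masked_m are illustrative, and expected_m is only a hint for config selection):

import torch
import deep_gemm

num_groups, max_m, n, k = 4, 1024, 4096, 7168   # illustrative sizes
lhs = torch.randn(num_groups, max_m, k, device='cuda').to(torch.float8_e4m3fn)
lhs_scales = torch.randn(num_groups, max_m, k // 128, device='cuda', dtype=torch.float32)
rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
rhs_scales = torch.randn(num_groups, (n + 127) // 128, k // 128, device='cuda', dtype=torch.float32)
out = torch.empty(num_groups, max_m, n, device='cuda', dtype=torch.bfloat16)
masked_m = torch.full((num_groups,), 256, dtype=torch.int32, device='cuda')  # valid rows per group

lhs_scales = deep_gemm.get_col_major_tma_aligned_tensor(lhs_scales)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
    (lhs, lhs_scales), (rhs, rhs_scales), out, masked_m, expected_m=256)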

View File

@ -3,15 +3,16 @@ import os
import torch
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
import cuda.bindings.driver as cuda
from ..jit import build, generate, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
@ -26,7 +27,7 @@ class JITTuner:
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
assert args is not None
assert kwargs is not None
space = (dict(), ) if len(space) == 0 else space
kernels = []
@ -34,30 +35,31 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))
# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
code = generate(**kwargs, **full_keys)
kernels.append((build(name, code), full_keys))
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(*args)
if return_code != 0:
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cuda.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
print(
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
torch.empty(int(256e6 // 4), dtype=torch.int,
device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(*args) == 0
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
@ -68,14 +70,16 @@ class JITTuner:
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
print(
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = best_runtime
return best_runtime
print(
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)
return best_runtime, best_keys
jit_tuner = JITTuner()
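With this change compile_and_tune returns, and caches, both the runtime and the keys that won the search, so call sites unpack the pair and forward it to the launch; a minimal hedged sketch of the calling pattern (the keys and kwargs dictionaries stand in for the full ones built in gemm.py above):

runtime, best_keys = jit_tuner.compile_and_tune(
    name='gemm_fp8_fp8_bf16_nt',
    keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n},  # abridged key set
    space=(),        # empty space: a single candidate, so no timing loop is run
    kwargs=kwargs,   # launch-time arguments consumed by Runtime.__call__ / run_gemm
)
runtime(**best_keys, **kwargs)   # launches via cuLaunchKernelEx and returns its status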