From 27cd276e19515ceffa7231d3768cbfc01d5cc874 Mon Sep 17 00:00:00 2001 From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Date: Tue, 22 Apr 2025 08:08:40 +0000 Subject: [PATCH] [wip] refactor: compile to .cubin Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 120 ++++++++-------- deep_gemm/include/deep_gemm/mma_utils.cuh | 2 + deep_gemm/include/deep_gemm/nvrtc_std.cuh | 69 +++++++++ deep_gemm/include/deep_gemm/tma_utils.cuh | 3 + deep_gemm/include/deep_gemm/utils.cuh | 6 + deep_gemm/jit/__init__.py | 2 +- deep_gemm/jit/compiler.py | 29 ++-- deep_gemm/jit/runtime.py | 85 ++++++++--- deep_gemm/jit/template.py | 135 +++++------------- deep_gemm/jit/utils.py | 164 ++++++++++++++++++++++ deep_gemm/jit_kernels/gemm.py | 113 ++++++++------- deep_gemm/jit_kernels/m_grouped_gemm.py | 136 ++++++++++-------- deep_gemm/jit_kernels/tuner.py | 40 +++--- 13 files changed, 581 insertions(+), 323 deletions(-) create mode 100644 deep_gemm/include/deep_gemm/nvrtc_std.cuh create mode 100644 deep_gemm/jit/utils.py diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index c2934b8..3574495 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -504,63 +504,73 @@ public: tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); DG_HOST_ASSERT(status == cudaSuccess); } - - template - static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); - } - - template - static CUtensorMap make_2d_tma_b_desc(T* global_address) { - return make_2d_tma_desc(global_address, Layout::ColMajor, - SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); - } - - template - static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { - auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; - if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; - if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; - if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; - - // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes - // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, - BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T), - swizzle_mode); - } - - template - static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { - // Make TMA aligned to 16 bytes - constexpr uint32_t kAlignment = 16 / sizeof(T); - shape_m = ceil_div(shape_m, kAlignment) * kAlignment; - - return make_2d_tma_desc(global_address, Layout::ColMajor, - shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); - } - - template - static CUtensorMap make_2d_tma_desc( - T* global_address, Layout layout, - uint32_t gmem_rows, uint32_t gmem_cols, - uint32_t smem_rows, uint32_t smem_cols, - CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { - if (layout == Layout::RowMajor) { - uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; - uint32_t smem_dim[2] = {smem_cols, smem_rows}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); - } else { - uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; - uint32_t smem_dim[2] = {smem_rows, smem_cols}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); - } - } }; +template +static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) { + return make_2d_tma_desc( + global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1), + shape_k, block_m, block_k); +} + +template +static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) { + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, + shape_n * (kGemmType != GemmType::Normal ? num_groups : 1), + block_k, block_n); +} + +template +static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) { + auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + if constexpr (kSwizzleDMode == 32) + swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; + if constexpr (kSwizzleDMode == 64) + swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; + if constexpr (kSwizzleDMode == 128) + swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; + + // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` + // bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc( + global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1), + shape_n, block_m, + kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode); +} + +template +static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_m = ceil_div(shape_m, kAlignment) * kAlignment; + + return make_2d_tma_desc( + global_address, Layout::ColMajor, shape_m, + ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1), + block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +static CUtensorMap +make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows, + uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols, + CUtensorMapSwizzle swizzle_type = + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {smem_cols, smem_rows}; + return make_2d_tma_copy_desc(global_address, gmem_dim, + gmem_cols * sizeof(T), smem_dim, swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {smem_rows, smem_cols}; + return make_2d_tma_copy_desc(global_address, gmem_dim, + gmem_rows * sizeof(T), smem_dim, swizzle_type); + } +} + }; // namespace deep_gemm #pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index a442af7..f07e540 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -1,6 +1,8 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include +#endif #include #include diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh new file mode 100644 index 0000000..ca49771 --- /dev/null +++ b/deep_gemm/include/deep_gemm/nvrtc_std.cuh @@ -0,0 +1,69 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef NVRTC_JIT_COMPILATION + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using int32_t = signed int; +using uint32_t = unsigned int; +using int64_t = signed long long; +using uint64_t = unsigned long long; +using cuuint64_t = unsigned long long; + +namespace std +{ +template +struct integral_constant +{ + static constexpr T value = v; + using value_type = T; + using type = integral_constant; // using injected-class-name + + __device__ constexpr operator value_type() const noexcept + { + return value; + } + + __device__ constexpr value_type operator()() const noexcept + { + return value; + } // since c++14 +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +template +struct is_same : false_type +{ +}; + +template +struct is_same : true_type +{ +}; + +template +inline constexpr bool is_same_v = is_same::value; +} // namespace std + +#endif diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh index 22731a6..e7e0fcc 100644 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ b/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -1,6 +1,9 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include +#endif + #include #include #include diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 9b93af5..8edf35b 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -1,5 +1,6 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include #ifdef __CLION_IDE__ @@ -16,8 +17,12 @@ public: const char *what() const noexcept override { return message.c_str(); } }; +#endif #ifndef DG_HOST_ASSERT +#ifdef NVRTC_JIT_COMPILATION +#define DG_HOST_ASSERT(cond) ((void)0) +#else #define DG_HOST_ASSERT(cond) \ do { \ if (not (cond)) { \ @@ -27,6 +32,7 @@ do { \ } \ } while (0) #endif +#endif #ifndef DG_DEVICE_ASSERT #define DG_DEVICE_ASSERT(cond) \ diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py index eb08b14..999eafb 100644 --- a/deep_gemm/jit/__init__.py +++ b/deep_gemm/jit/__init__.py @@ -1,3 +1,3 @@ from .compiler import get_nvcc_compiler, build -from .template import cpp_format, generate +from .template import generate from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index c17d466..b3eb583 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -3,13 +3,13 @@ import functools import os import re import subprocess +import time import uuid from torch.utils.cpp_extension import CUDA_HOME from typing import Tuple from . import interleave_ffma from .runtime import Runtime, RuntimeCache -from .template import typename_map runtime_cache = RuntimeCache() @@ -94,11 +94,11 @@ def put(path, data, is_binary=False): os.replace(tmp_file_path, path) -def build(name: str, arg_defs: tuple, code: str) -> Runtime: +def build(name: str, code: str) -> Runtime: # Compiler flags cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', - '-gencode=arch=compute_90a,code=sm_90a', + '-gencode=arch=compute_90a,code=sm_90a', '-cubin', '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases '--diag-suppress=39,174,177,940'] @@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime: # Write the code os.makedirs(path, exist_ok=True) - args_path = f'{path}/kernel.args' src_path = f'{path}/kernel.cu' - put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) put(src_path, code) - # Compile into a temporary SO file - so_path = f'{path}/kernel.so' - tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so' + # Compile into a temporary CU file + cubin_path = f'{path}/kernel.cubin' + tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' # Compile command = [get_nvcc_compiler()[0], - src_path, '-o', tmp_so_path, + src_path, '-o', tmp_cubin_path, *flags, *[f'-I{d}' for d in include_dirs]] if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): print(f'Compiling JIT runtime {name} with command {command}') + + start_time = time.time() return_code = subprocess.check_call(command) + end_time = time.time() assert return_code == 0, f'Failed to compile {src_path}' + # Print elapsed time if debug is enabled + elapsed_time = end_time - start_time + print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') + # Interleave FFMA reuse if enable_sass_opt: - interleave_ffma.process(tmp_so_path) + interleave_ffma.process(tmp_cubin_path) - # Atomic replace SO file - os.replace(tmp_so_path, so_path) + # Atomic replace CU file + os.replace(tmp_cubin_path, cubin_path) # Put cache and return runtime_cache[path] = Runtime(path) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 66c370a..e060eeb 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,16 +1,18 @@ -import ctypes import os +import time +from typing import Any, Dict, Optional + +import cuda.bindings.driver as cuda +import cuda.bindings.nvrtc as nvrtc import torch -from typing import Optional - -from .template import map_ctype +from .utils import run_gemm class Runtime: def __init__(self, path: str) -> None: self.path = path self.lib = None - self.args = None + self.kernel = None assert self.is_path_valid(self.path) @@ -21,29 +23,66 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cu', 'kernel.args', 'kernel.so'] + files = ['kernel.cu', 'kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) - def __call__(self, *args) -> int: - # Load SO file - if self.lib is None or self.args is None: - self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so')) - with open(os.path.join(self.path, 'kernel.args'), 'r') as f: - self.args = eval(f.read()) + def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: + # Load CUBIN + if self.lib is None: + start_time = time.time_ns() + res, lib = cuda.cuLibraryLoadFromFile( + bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to load library: {res}") - # Check args and launch - assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' - cargs = [] - for arg, (name, dtype) in zip(args, self.args): - if isinstance(arg, torch.Tensor): - assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' + res, kernel_count = cuda.cuLibraryGetKernelCount(lib) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to get kernel count: {res}") + + res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to enumerate kernels: {res}") + + for kernel in kernels: + res, kernel_name = cuda.cuKernelGetName(kernel) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to get kernel name: {res}") + if b"fp8" in kernel_name: + self.kernel = kernel + break + + if self.kernel is not None: + self.lib = lib else: - assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' - cargs.append(map_ctype(arg)) + raise Exception("Failed to find fp8 gemm kernel") - return_code = ctypes.c_int(0) - self.lib.launch(*cargs, ctypes.byref(return_code)) - return return_code.value + end_time = time.time_ns() + elapsed_time = (end_time - start_time) / 1000 + print( + f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') + + return run_gemm( + self.kernel, + kwargs['NUM_TMA_MULTICAST'], + kwargs['M'], + kwargs['BLOCK_M'], + kwargs['GMEM_D'], + kwargs['SCALES_B'], + kwargs['GROUPED_LAYOUT'], + kwargs['NUM_SMS'], + kwargs['SMEM_SIZE'], + kwargs['TENSOR_MAP_A'], + kwargs['TENSOR_MAP_B'], + kwargs['TENSOR_MAP_SCALES_A'], + kwargs['TENSOR_MAP_D'], + kwargs['STREAM'], + ) + + def __del__(self) -> None: + if self.lib is not None: + res = cuda.cuLibraryUnload(self.lib)[0] + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to unload library {self.path}: {res}") class RuntimeCache: diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index ead37f5..2a99322 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -1,111 +1,48 @@ -import copy -import ctypes import os -import torch -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict -# Name map for Python `eval` -typename_map: Dict[Any, str] = { - **{t: t.__name__ for t in (bool, int, float)}, - torch.int: 'torch.int', - torch.float: 'torch.float', - torch.bfloat16: 'torch.bfloat16', - torch.float8_e4m3fn: 'torch.float8_e4m3fn', - torch.cuda.Stream: 'torch.cuda.Stream', -} +def generate(**kwargs: Dict[str, Any]) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#ifndef NVRTC_JIT_COMPILATION +#define NVRTC_JIT_COMPILATION +#endif -# `ctype` map for Python casting -ctype_map: Dict[Any, Any] = { - **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, - **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, -} +#include +#else -# Type map for both Python API and source code usages -genc_map = { - bool: ('bool', 'bool'), - int: ('int', 'int'), - float: ('float', 'float'), - torch.int: ('void*', 'int*'), - torch.float: ('void*', 'float*'), - torch.bfloat16: ('void*', '__nv_bfloat16*'), - torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), - torch.cuda.Stream: ('void*', 'cudaStream_t'), -} +#include +#include +#endif -def map_ctype(value: Any) -> Any: - if hasattr(value, 'data_ptr'): - if value.dtype == torch.int: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.bfloat16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float8_e4m3fn: - return ctypes.c_void_p(value.data_ptr()) - else: - return ctypes.c_void_p(value.data_ptr()) +#include +#include +#include - if hasattr(value, 'cuda_stream'): - return ctypes.c_void_p(value.cuda_stream) - - if isinstance(value, bool): - return ctypes.c_bool(value) - elif isinstance(value, int): - return ctypes.c_int(value) - elif isinstance(value, float): - return ctypes.c_float(value) - - return ctype_map[type(value)](value) - - -def cpp_format(template: str, keys: Dict[str, Any]) -> str: - # We don't use `str.format` because it's not safe for C++ {} braces - new_template = copy.deepcopy(template) - for key, value in keys.items(): - value_str = str(value) - if isinstance(value, bool): - value_str = value_str.lower() - new_template = new_template.replace(f'{{{key}}}', f'{value_str}') - return new_template - - -def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: - # Common prefix - code = '// DeepGEMM auto-generated JIT CUDA source file\n\n' - - # Includes - preload_sys_includes = ['', '', '', ''] - preload_package_includes = ['"cutlass/cutlass.h"'] - - assert isinstance(includes, list) or isinstance(includes, tuple) - sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')]))) - package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')]))) - code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n' - code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n' - - # Function signature - raw = '__raw_' - get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n - code += f'extern "C" void launch(' - code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ]) - code += ') {\n' - - # Cast raw types - code += ' // Cast raw types (if needed)\n' - for arg_name, arg_type in arg_defs: - if genc_map[arg_type][0] != genc_map[arg_type][1]: - code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n' - - # Function body - code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')]) - - # End the function - code += '}\n\n' +namespace deep_gemm {{ +__global__ void dummy_kernel() {{ + void *ptr = (void *)&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >; +}} +}} +''' # Debug print if os.getenv('DG_JIT_DEBUG', None): diff --git a/deep_gemm/jit/utils.py b/deep_gemm/jit/utils.py new file mode 100644 index 0000000..f1feefa --- /dev/null +++ b/deep_gemm/jit/utils.py @@ -0,0 +1,164 @@ +import ctypes +from enum import Enum +from typing import Any, Dict, Tuple + +import cuda.bindings.driver as cuda +import torch + + +class Layout(Enum): + RowMajor = 0 + ColMajor = 1 + + +class GemmType(Enum): + Normal = 0 + GroupedContiguous = 1 + GroupedMasked = 2 + + def __str__(self) -> str: + return { + 0: 'Normal', + 1: 'GroupedContiguous', + 2: 'GroupedMasked', + }[self.value] + + +typename_map: Dict[Any, str] = { + torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, + torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, + torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, + torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, + torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +} + +swizzle_map = { + 128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + 64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, + 32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, + 0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, +} + +def get_num_math_warpgroups(block_m: int) -> int: + return 1 if block_m == 64 else 2 + +def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: + assert num_math_threads_per_group == 128, "Only support 128 threads per math group" + return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads + + +def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap: + tensor_dtype = typename_map[global_address.dtype] + res, tensor_map = cuda.cuTensorMapEncodeTiled( + tensor_dtype, + 2, # tensor rank + global_address.data_ptr(), + gmem_dim, + (stride_in_bytes,), # global strides + smem_dim, + (cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle_type, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to encode tensor map: {res}") + + return tensor_map + + +def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap: + if layout == Layout.RowMajor: + gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows)) + smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) + else: + gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols)) + smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) + + +def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k) + + +def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n) + + +def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: + # Swizzling requires the inner box dim less or equal than `kSwizzleDMode` + # bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode]) + + +def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: + # Make TMA aligned to 16 bytes + kAlignment = 16 / global_address.element_size() + shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment + + return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0] + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"Failed to set max dynamic shared memory size: {res}") + + attr_val = cuda.CUlaunchAttributeValue() + attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cuda.CUlaunchAttribute() + attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cuda.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = num_sms + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = smem_size + config.hStream = stream + + kernelValues = ( + gmem_d.data_ptr(), + scales_b.data_ptr(), + grouped_layout.data_ptr(), + shape_m, + tensor_map_a, + tensor_map_b, + tensor_map_scales_a, + tensor_map_d, + ) + kernelTypes = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + None, + None, + None, + None, + ) + + return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index f52dc48..e55baef 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -3,40 +3,11 @@ import torch from functools import lru_cache from typing import Tuple +from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc + from .tuner import jit_tuner from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = 1; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool: if num_tma_multicast == 1: @@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: # Try swizzle first, as it does not waste shared memory swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 smem_d = block_m * (block_n + block_n_padding) * 2 smem_a_per_stage = block_m * block_k @@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_size += num_stages * smem_a_per_stage smem_size += num_stages * smem_scales_a_per_stage smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += ceil_div(smem_scales_b * (1 if block_k % + block_n == 0 else 2), 8) * 8 smem_size += smem_barrier # Swizzle and padding are not compatible @@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) + (144, 160, ) - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) + def fix_wave_saturate(x): return num_sms if x == 0 else x + + def get_num_waves(bm, bn): return (ceil_div( + ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) + + def get_last_wave_util(bm, bn): return fix_wave_saturate( + (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) # Decide block sizes by waves best_block_m, best_block_n = None, None @@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # NOTES: the block sizes can not be too large, so at least one dim less than 128 for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) + num_waves, best_num_waves = get_num_waves( + block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: @@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, success |= block_n == best_block_n and block_m < best_block_m # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better success |= block_m != best_block_m and block_n > best_block_n - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) + best_block_m, best_block_n = (block_m, block_n) if success else ( + best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None # Always pick the longest one @@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3) for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) + best_smem_config = get_smem_config( + num_stages, k, best_block_m, best_block_n) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break @@ -158,8 +138,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + num_min_sms = ceil_div(ceil_div(m, best_block_m) * + ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div( + num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] assert num_min_sms <= num_sms return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config @@ -210,11 +192,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.Normal, lhs, m, k, block_m, block_k) + tensor_map_b = make_2d_tma_b_desc( + GemmType.Normal, rhs, k, n, block_k, block_n) + tensor_map_d = make_2d_tma_d_desc( + GemmType.Normal, smem_config[1], out, m, n, block_m, block_n) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.Normal, lhs_scales, m, k, block_m, block_k) + + kwargs = { + 'GEMM_TYPE': GemmType.Normal, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'NUM_GROUPS': 1, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -223,14 +236,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 3b518c9..22281b1 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,41 +1,12 @@ import torch from typing import Tuple +from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc + from .gemm import get_best_configs, get_block_n_padding_for_smem_d from .tuner import jit_tuner from .utils import get_col_major_tma_aligned_tensor, get_num_sms -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = {NUM_GROUPS}; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated grouped GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, grouped_layout, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], @@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) - args = (lhs, lhs_scales, rhs, rhs_scales, out, - m_indices, m, num_groups, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms, is_grouped_contiguous=True) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': m_indices, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedContiguous'}, + 'GEMM_TYPE': GemmType.GroupedContiguous}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) # Extra checks for TMA store if num_groups > 1 and m > block_m: assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - args = (lhs, lhs_scales, rhs, rhs_scales, out, - masked_m, m, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': masked_m, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedMasked'}, + 'GEMM_TYPE': GemmType.GroupedMasked}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py index 6ed6749..4dcd6cd 100644 --- a/deep_gemm/jit_kernels/tuner.py +++ b/deep_gemm/jit_kernels/tuner.py @@ -3,15 +3,16 @@ import os import torch from typing import Any, Dict -from ..jit import build, cpp_format, generate, Runtime +import cuda.bindings.driver as cuda + +from ..jit import build, generate, Runtime class JITTuner: def __init__(self) -> None: self.tuned = {} - def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, - includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: + def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime: # NOTES: we always assume the space and template will not change # We also assume the GPU device will not be changed # NOTES: the function must have no accumulated side effects @@ -26,7 +27,7 @@ class JITTuner: print(f'Auto-tuning JIT kernel {name} with keys {keys}') assert signature not in self.tuned - assert args is not None + assert kwargs is not None space = (dict(), ) if len(space) == 0 else space kernels = [] @@ -34,30 +35,31 @@ class JITTuner: assert isinstance(tuned_keys, dict) full_keys = copy.deepcopy(keys) full_keys.update(tuned_keys) - code = generate(includes, arg_defs, cpp_format(template, full_keys)) - - # Illegal build must raise errors - kernels.append((build(name, arg_defs, code), tuned_keys)) + code = generate(**kwargs, **full_keys) + kernels.append((build(name, code), full_keys)) best_runtime, best_time, best_keys = None, None, None for runtime, tuned_keys in kernels: if len(space) > 1: # Check kernel validity - return_code = runtime(*args) - if return_code != 0: + return_code = runtime(**tuned_keys, **kwargs) + if return_code != cuda.CUresult.CUDA_SUCCESS: # Pass illegal kernels, e.g. insufficient shared memory capacity if os.getenv('DG_JIT_DEBUG', None): - print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') + print( + f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') continue # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() - torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') + torch.empty(int(256e6 // 4), dtype=torch.int, + device='cuda').zero_() + torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn( + (8192, 8192), dtype=torch.float, device='cuda') start_event.record() for i in range(20): - assert runtime(*args) == 0 + assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS end_event.record() end_event.synchronize() elapsed_time = start_event.elapsed_time(end_event) @@ -68,14 +70,16 @@ class JITTuner: if best_time is None or elapsed_time < best_time: best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys if os.getenv('DG_JIT_DEBUG', None): - print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') + print( + f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' # Cache the best runtime and return if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): - print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') - self.tuned[signature] = best_runtime - return best_runtime + print( + f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') + self.tuned[signature] = (best_runtime, best_keys) + return best_runtime, best_keys jit_tuner = JITTuner()