From 816b39053adb1921e6517c95847e14a4cc2eddbd Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:14:21 +0800 Subject: [PATCH 1/5] Refactor launch-related structures --- README.md | 2 +- .../include/deep_gemm/fp8_wgrad_gemm.cuh | 8 +- deep_gemm/jit/runtime.py | 9 +- deep_gemm/jit_kernels/gemm.py | 60 ++--- deep_gemm/jit_kernels/m_grouped_gemm.py | 100 ++++----- deep_gemm/jit_kernels/runtime.py | 211 +++++++----------- deep_gemm/jit_kernels/tuner.py | 82 ------- deep_gemm/jit_kernels/wgrad_gemm.py | 93 +++----- tests/test_jit.py | 30 +-- 9 files changed, 199 insertions(+), 396 deletions(-) delete mode 100644 deep_gemm/jit_kernels/tuner.py diff --git a/README.md b/README.md index 170c271..5f5388b 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful: - Post optimization - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default - Heuristic selection - - `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default + - `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default - Testing - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh index 4bf179e..eb2282d 100644 --- a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh @@ -18,7 +18,7 @@ namespace deep_gemm { template __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) @@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k, struct DivisibleK {}; struct NotDivisibleK {}; auto launch_k_iterations = [&](const auto& func) { - if constexpr (kLastStages == 0) { + if constexpr (kNumLastStages == 0) { for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) func(k_iter, DivisibleK{}); } else { @@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](int k_iter, auto type) { constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); // Assign TMA multicast number into A and B @@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k, // Launch MMAs launch_k_iterations([&](int k_iter, auto type) { constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; + constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : kNumLastStages; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); #pragma unroll diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 74ceff5..041e23f 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -8,11 +8,10 @@ from torch.utils.cpp_extension import CUDA_HOME class Runtime: - def __init__(self, path: str, args: List[str] = None) -> None: + def __init__(self, path: str) -> None: self.path = path self.lib = None self.kernel = None - self.args = args assert self.is_path_valid(self.path) @staticmethod @@ -48,8 +47,10 @@ class Runtime: command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) assert result.returncode == 0 + illegal_names = ['vprintf', '__instantiate_kernel', '__internal'] + check_illegal = lambda line: any([name in line for name in illegal_names]) kernel_names = [line.split()[-1] for line in result.stdout.splitlines() - if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line] + if line.startswith('STT_FUNC') and not check_illegal(line)] assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' # Load kernel from the library @@ -62,7 +63,7 @@ class Runtime: print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') # noinspection PyArgumentList - return self.launch(self.kernel, *[kwargs[arg] for arg in self.args]) + return self.launch(self.kernel, kwargs) def __del__(self) -> None: if self.lib is not None: diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index c782f28..3adef72 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -3,11 +3,11 @@ import torch from functools import lru_cache from typing import Tuple +from ..jit import build from .runtime import ( FP8GemmRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_a_desc) -from .tuner import jit_tuner + make_2d_tma_d_desc, make_2d_tma_scales_desc) from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout @@ -18,7 +18,6 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in def get_swizzle_mode(block_n: int) -> int: - # TODO: remove some candidates if slow elem_size = 2 for mode_bytes in (128, 64, 32): if (block_n * elem_size) % mode_bytes == 0: @@ -187,13 +186,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], assert out.dtype == torch.bfloat16 assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - lhs_stride = lhs.stride(0) - rhs_stride = rhs.stride(0) - out_stride = out.stride(0) - - # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes - assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0 - # LHS scales must be transposed for TMA loads, but not for RHS scales # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) @@ -208,29 +200,30 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Auto-tuning with compilation num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - m, n, k, 1, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 
- tensor_map_a = make_2d_tma_a_desc( - GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride) - tensor_map_b = make_2d_tma_b_desc( - GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride) - tensor_map_d = make_2d_tma_d_desc( - GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride) - tensor_map_scales_a = make_2d_tma_scales_a_desc( - GemmType.Normal, lhs_scales, m, k, block_m, block_k) + tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) + tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) + tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) kwargs = { + # Templated arguments 'GEMM_TYPE': GemmType.Normal, 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, + 'M': m, 'N': n, 'K': aligned_k, 'NUM_GROUPS': 1, - 'BLOCK_K': block_k, - 'GMEM_D': out, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + # Runtime arguments 'SCALES_B': rhs_scales, 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), 'NUM_SMS': num_sms, @@ -240,21 +233,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_D': tensor_map_d, 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index } - - runtime, best_keys = jit_tuner.compile_and_tune( - name='gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': aligned_k, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) - # Run the kernel - runtime(**best_keys, **kwargs) + # Generate, build and run the kernel + code = FP8GemmRuntime.generate(**kwargs) + runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index e8c1922..ef1a088 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,12 +1,12 @@ import torch from typing import Tuple +from ..jit import build from .gemm import get_best_configs from .runtime import ( FP8GemmRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_a_desc) -from .tuner import jit_tuner + make_2d_tma_d_desc, make_2d_tma_scales_desc) from .utils import get_col_major_tma_aligned_tensor, get_num_sms @@ -69,21 +69,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten num_tma_threads = 128 num_math_threads_per_group = 128 - tensor_map_a = make_2d_tma_a_desc( - GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc( - GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) - tensor_map_d = make_2d_tma_d_desc( - GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = 
make_2d_tma_scales_a_desc( - GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) + tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) kwargs = { + # Templated arguments 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, - 'BLOCK_K': block_k, - 'GMEM_D': out, + 'M': m, 'N': n, 'K': k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedContiguous, + # Runtime arguments 'SCALES_B': rhs_scales, 'GROUPED_LAYOUT': m_indices, 'NUM_SMS': num_sms, @@ -93,25 +97,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_D': tensor_map_d, 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index } - runtime, best_keys = jit_tuner.compile_and_tune( - name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedContiguous}, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) - - # Run the kernel - runtime(**best_keys, **kwargs) + # Generate, build and run the kernel + code = FP8GemmRuntime.generate(**kwargs) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime(**kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -176,21 +168,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] num_tma_threads = 128 num_math_threads_per_group = 128 - tensor_map_a = make_2d_tma_a_desc( - GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc( - GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) - tensor_map_d = make_2d_tma_d_desc( - GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_a_desc( - GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) + tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) kwargs = { + # Templated arguments 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, - 'BLOCK_K': block_k, - 'GMEM_D': out, + 'M': m, 'N': n, 'K': k, + 'BLOCK_M': block_m, 
'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedMasked, + # Runtime arguments 'SCALES_B': rhs_scales, 'GROUPED_LAYOUT': masked_m, 'NUM_SMS': num_sms, @@ -200,22 +196,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_D': tensor_map_d, 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index } - runtime, best_keys = jit_tuner.compile_and_tune( - name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedMasked}, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) - - # Run the kernel - runtime(**best_keys, **kwargs) + # Generate, build and run the kernel + code = FP8GemmRuntime.generate(**kwargs) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 1ac0fe1..a7b0e66 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -5,14 +5,10 @@ import torch import cuda.bindings.driver as cbd from typing import Any, Dict, Tuple +from .utils import get_tma_aligned_size from ..jit.runtime import Runtime -class Layout(enum.Enum): - RowMajor = 0 - ColMajor = 1 - - class GemmType(enum.Enum): Normal = 0 GroupedContiguous = 1 @@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads -def make_2d_tma_copy_desc(global_address: torch.Tensor, - gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], - stride_in_bytes: cbd.cuuint64_t, - smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], +def make_2d_tma_copy_desc(t: torch.Tensor, + gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t, + smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: - tensor_dtype = tmap_type_map[global_address.dtype] + tensor_dtype = tmap_type_map[t.dtype] res, tensor_map = cbd.cuTensorMapEncodeTiled( tensor_dtype, 2, - global_address.data_ptr(), - gmem_dim, - (stride_in_bytes, ), - smem_dim, + t.data_ptr(), + gmem_dims, + (gmem_outer_stride,), + smem_dims, (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, @@ -86,90 +81,61 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor, return tensor_map -def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, - gmem_rows: int, gmem_cols: int, gmem_stride: int, - smem_rows: int, smem_cols: int, +def make_2d_tma_desc(t: torch.Tensor, + gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int, + smem_inner_dim: int, smem_outer_dim: int, swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: - if layout == Layout.RowMajor: - gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows)) - smem_dim = (cbd.cuuint32_t(smem_cols), 
cbd.cuuint32_t(smem_rows)) - return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type) - else: - gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols)) - smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols)) - return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type) + gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim)) + smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim)) + return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type) -def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, - shape_m: int, shape_k: int, +def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_k: int, m_stride: int, block_m: int, block_k: int, - num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap: - a_stride = shape_k if a_stride == 0 else a_stride - return make_2d_tma_desc(global_address, Layout.RowMajor, - shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride, - block_m, block_k) + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, + block_k, block_m) -def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, - shape_k: int, shape_n: int, - block_k: int, block_n: int, - num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap: - b_stride = shape_k if b_stride == 0 else b_stride - return make_2d_tma_desc(global_address, Layout.ColMajor, - shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride, +def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor, + shape_n: int, shape_k: int, n_stride: int, + block_n: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride, block_k, block_n) -def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, - shape_m: int, shape_n: int, +def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_n: int, m_stride: int, block_m: int, block_n: int, - num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap: + num_groups: int, + swizzle_mode: int) -> cbd.CUtensorMap: # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - d_stride = shape_n if d_stride == 0 else d_stride - return make_2d_tma_desc(global_address, Layout.RowMajor, - shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride, - block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), + return make_2d_tma_desc(t, + shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, + block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m, swizzle_type_map[swizzle_mode]) -def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: +def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, + shape_mn: int, shape_k: int, + block_mn: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: # Make TMA aligned to 16 
bytes - tma_alignment = 16 / global_address.element_size() - shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment - - return make_2d_tma_desc(global_address, Layout.ColMajor, - shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m, - block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) - - -def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: - # Make TMA aligned to 16 bytes - tma_alignment = 16 / global_address.element_size() - shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment - - return make_2d_tma_desc(global_address, Layout.ColMajor, - shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, - block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + shape_mn = get_tma_aligned_size(shape_mn, t.element_size()) + return make_2d_tma_desc(t, + shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn, + block_mn, 1, + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) class FP8GemmRuntime(Runtime): def __init__(self, path: str) -> None: - super().__init__(path, [ - 'NUM_TMA_MULTICAST', - 'M', - 'BLOCK_M', - 'GMEM_D', - 'SCALES_B', - 'GROUPED_LAYOUT', - 'NUM_SMS', - 'SMEM_SIZE', - 'TENSOR_MAP_A', - 'TENSOR_MAP_B', - 'TENSOR_MAP_SCALES_A', - 'TENSOR_MAP_D', - 'STREAM', - ]) + super().__init__(path) @staticmethod def generate(**kwargs) -> str: @@ -213,21 +179,16 @@ static void __instantiate_kernel() {{ # noinspection PyMethodOverriding @staticmethod - def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int, - block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, - grouped_layout: torch.Tensor, num_sms: int, smem_size: int, - tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap, - tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap, - stream: cbd.CUstream) -> cbd.CUresult: + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: num_tma_threads = 128 num_math_threads_per_group = 128 - res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to set max dynamic shared memory size: {res}') + result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] attr_val.clusterDim.y = 1 attr_val.clusterDim.z = 1 attr = cbd.CUlaunchAttribute() @@ -237,23 +198,23 @@ static void __instantiate_kernel() {{ config = cbd.CUlaunchConfig() config.numAttrs = 1 config.attrs = [attr] - config.gridDimX = num_sms + config.gridDimX = kwargs['NUM_SMS'] config.gridDimY = 1 config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) config.blockDimY = 1 config.blockDimZ = 1 - config.sharedMemBytes = smem_size - 
config.hStream = stream + config.sharedMemBytes = kwargs['SMEM_SIZE'] + config.hStream = kwargs['STREAM'] arg_values = ( - scales_b.data_ptr(), - grouped_layout.data_ptr(), - shape_m, - tensor_map_a, - tensor_map_b, - tensor_map_scales_a, - tensor_map_d, + kwargs['SCALES_B'].data_ptr(), + kwargs['GROUPED_LAYOUT'].data_ptr(), + kwargs['M'], + kwargs['TENSOR_MAP_A'], + kwargs['TENSOR_MAP_B'], + kwargs['TENSOR_MAP_SCALES_A'], + kwargs['TENSOR_MAP_D'], ) arg_types = ( ctypes.c_void_p, @@ -269,20 +230,7 @@ static void __instantiate_kernel() {{ class FP8WGradGemmRuntime(Runtime): def __init__(self, path: str) -> None: - super().__init__(path, [ - 'NUM_TMA_MULTICAST', - 'K', - 'BLOCK_M', - 'GMEM_D', - 'NUM_SMS', - 'SMEM_SIZE', - 'TENSOR_MAP_A', - 'TENSOR_MAP_B', - 'TENSOR_MAP_SCALES_A', - 'TENSOR_MAP_SCALES_B', - 'TENSOR_MAP_D', - 'STREAM', - ]) + super().__init__(path) @staticmethod def generate(**kwargs) -> str: @@ -309,7 +257,7 @@ static void __instantiate_kernel() {{ {kwargs['BLOCK_N']}, {kwargs['BLOCK_K']}, {kwargs['NUM_STAGES']}, - {kwargs['LAST_STAGES']}, + {kwargs['NUM_LAST_STAGES']}, {kwargs['NUM_TMA_THREADS']}, {kwargs['NUM_MATH_THREADS_PER_GROUP']}, {kwargs['NUM_TMA_MULTICAST']}, @@ -323,21 +271,16 @@ static void __instantiate_kernel() {{ # noinspection PyMethodOverriding @staticmethod - def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int, - block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int, - tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap, - tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap, - tensor_map_d: cbd.CUtensorMap, - stream: cbd.CUstream) -> cbd.CUresult: + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: num_tma_threads = 128 num_math_threads_per_group = 128 - res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to set max dynamic shared memory size: {res}') + result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] attr_val.clusterDim.y = 1 attr_val.clusterDim.z = 1 attr = cbd.CUlaunchAttribute() @@ -347,22 +290,22 @@ static void __instantiate_kernel() {{ config = cbd.CUlaunchConfig() config.numAttrs = 1 config.attrs = [attr] - config.gridDimX = num_sms + config.gridDimX = kwargs['NUM_SMS'] config.gridDimY = 1 config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) config.blockDimY = 1 config.blockDimZ = 1 - config.sharedMemBytes = smem_size - config.hStream = stream + config.sharedMemBytes = kwargs['SMEM_SIZE'] + config.hStream = kwargs['STREAM'] arg_values = ( - shape_k, - tensor_map_a, - tensor_map_b, - tensor_map_scales_a, - tensor_map_scales_b, - tensor_map_d, + kwargs['K'], + kwargs['TENSOR_MAP_A'], + kwargs['TENSOR_MAP_B'], + kwargs['TENSOR_MAP_SCALES_A'], + kwargs['TENSOR_MAP_SCALES_B'], + kwargs['TENSOR_MAP_D'], ) arg_types = ( ctypes.c_uint32, diff --git 
a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py deleted file mode 100644 index 4fc9283..0000000 --- a/deep_gemm/jit_kernels/tuner.py +++ /dev/null @@ -1,82 +0,0 @@ -import copy -import os -import torch -import cuda.bindings.driver as cbd -from typing import Any, Callable, Dict, Type, Tuple - -from ..jit import build, Runtime - - -class JITTuner: - def __init__(self) -> None: - self.tuned = {} - - def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, - kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]: - # NOTES: we always assume the space, template and GPU devices will not change - # NOTES: the function must have no accumulated side effects - keys = {k: keys[k] for k in sorted(keys.keys())} - signature = (name, f'{keys}') - if signature in self.tuned: - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Using cached JIT kernel {name} with keys {keys}') - return self.tuned[signature] - - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Auto-tuning JIT kernel {name} with keys {keys}') - - assert signature not in self.tuned - assert kwargs is not None - space = (dict(), ) if len(space) == 0 else space - - kernels = [] - for tuned_keys in space: - assert isinstance(tuned_keys, dict) - full_keys = copy.deepcopy(keys) - full_keys.update(tuned_keys) - code = runtime_cls.generate(**kwargs, **full_keys) - kernels.append((build(name, code, runtime_cls), full_keys)) - - # TODO: fix tuning with space > 1 - best_runtime, best_time, best_keys = None, None, None - for runtime, tuned_keys in kernels: - if len(space) > 1: - # Check kernel validity - return_code = runtime(**tuned_keys, **kwargs) - if return_code != cbd.CUresult.CUDA_SUCCESS: - # Pass illegal kernels, e.g., insufficient shared memory capacity - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') - continue - - # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.empty(int(256e6 // 4), dtype=torch.int, - device='cuda').zero_() - torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn( - (8192, 8192), dtype=torch.float, device='cuda') - start_event.record() - for i in range(20): - assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS - end_event.record() - end_event.synchronize() - elapsed_time = start_event.elapsed_time(end_event) - else: - elapsed_time = 0 - - # Compare if better - if best_time is None or elapsed_time < best_time: - best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') - assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' - - # Cache the best runtime and return - if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)): - print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') - self.tuned[signature] = (best_runtime, best_keys) - return best_runtime, best_keys - - -jit_tuner = JITTuner() diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index 7dd5fc5..dea91e2 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -1,18 +1,18 @@ import torch 
from typing import List, Tuple +from ..jit import build from .runtime import ( FP8WGradGemmRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc) + make_2d_tma_d_desc, make_2d_tma_scales_desc) from .gemm import get_best_configs -from .tuner import jit_tuner from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], - out: Tuple[torch.Tensor, torch.Tensor]): + out: torch.Tensor): """ Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. Results will be accumulated into the output tensor. @@ -21,8 +21,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4. RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. + The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format. + If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, @@ -47,13 +47,6 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], assert out.dtype == torch.float assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - lhs_stride = lhs.stride(0) - rhs_stride = rhs.stride(0) - out_stride = out.stride(0) - - # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes - assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0 - # LHS and RHS scales must be transposed for TMA load # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels if lhs_scales.shape == ((k + 127) // 128, m): @@ -81,30 +74,30 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) - last_stages = (k + 127) // 128 % num_stages + num_last_stages = (k + 127) // 128 % num_stages block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 - tensor_map_a = make_2d_tma_a_desc( - GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride) - tensor_map_b = make_2d_tma_b_desc( - GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride) - tensor_map_d = make_2d_tma_d_desc( - GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride) - tensor_map_scales_a = make_2d_tma_scales_a_desc( - GemmType.Normal, lhs_scales, m, k, block_m, block_k) - tensor_map_scales_b = make_2d_tma_scales_b_desc( - GemmType.Normal, rhs_scales, n, k, block_n, block_k) + tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) + tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) + tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) + 
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) + tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1) kwargs = { + # Templated arguments 'GEMM_TYPE': GemmType.Normal, 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'K': aligned_k, + 'M': m, 'N': n, 'K': aligned_k, 'NUM_GROUPS': 1, - 'BLOCK_K': block_k, - 'GMEM_D': out, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'NUM_STAGES': num_stages, + 'NUM_LAST_STAGES': num_last_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + # Runtime arguments 'NUM_SMS': num_sms, 'SMEM_SIZE': smem_config[0], 'TENSOR_MAP_A': tensor_map_a, @@ -113,23 +106,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'TENSOR_MAP_SCALES_B': tensor_map_scales_b, 'TENSOR_MAP_D': tensor_map_d, 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index } - runtime, best_keys = jit_tuner.compile_and_tune( - name='wgrad_gemm_fp8_fp8_fp32_nt', - keys={'M': m, 'N': n, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'NUM_STAGES': num_stages, - 'LAST_STAGES': last_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, - space=(), - kwargs=kwargs, - runtime_cls=FP8WGradGemmRuntime, - ) - - # Run the kernel - runtime(**best_keys, **kwargs) + # Generate, build and run the kernel + code = FP8WGradGemmRuntime.generate(**kwargs) + runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime) + runtime(**kwargs) def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -144,16 +127,16 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], This function handles multiple batches with varying k-dimensions, processing each batch sequentially. Each batch's LHS, RHS, and output tensors must be contiguous. The RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensors require TMA-aligned transposed format. + The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format. Arguments: - lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, + lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. - the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, + The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, representing the per-128-channel scaling factors. - rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, + rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. - the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, + The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, representing the per-128-channel scaling factors. out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated. 
batch_sizes: A list of integers specifying the k-dimension for each batch. @@ -164,16 +147,14 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], lhs_offset, rhs_offset, scales_offset = 0, 0, 0 - for idx in range(num_batches): - k = batch_sizes[idx] - A = lhs[lhs_offset:lhs_offset + m * k].view(m, k) - B = rhs[rhs_offset:rhs_offset + n * k].view(n, k) - A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] - B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] - D = out[idx] - - wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D) + for i in range(num_batches): + k = batch_sizes[i] + lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) + rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) + lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] + rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] + wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) lhs_offset += m * k rhs_offset += n * k - scales_offset += (k + 127) // 128 \ No newline at end of file + scales_offset += (k + 127) // 128 diff --git a/tests/test_jit.py b/tests/test_jit.py index fbd84e1..a1bf583 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -2,6 +2,7 @@ import ctypes import os import torch import cuda.bindings.driver as cbd +from typing import Any, Dict from deep_gemm import jit @@ -12,12 +13,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') class VectorAddRuntime(jit.Runtime): def __init__(self, path: str) -> None: - super().__init__(path, [ - 'A', - 'B', - 'C', - 'STREAM', - ]) + super().__init__(path) @staticmethod def generate(**kwargs) -> str: @@ -46,27 +42,25 @@ static void __instantiate_kernel() {{ # noinspection PyShadowingNames,PyMethodOverriding @staticmethod - def launch(kernel: cbd.CUkernel, - a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, - stream: cbd.CUstream) -> cbd.CUresult: - assert a.shape == b.shape == c.shape - assert a.device == b.device == c.device - assert a.dim() == 1 + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: + assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape + assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device + assert kwargs['A'].dim() == 1 config = cbd.CUlaunchConfig() - config.gridDimX = (a.numel() + 127) // 128 + config.gridDimX = (kwargs['A'].numel() + 127) // 128 config.gridDimY = 1 config.gridDimZ = 1 config.blockDimX = 128 config.blockDimY = 1 config.blockDimZ = 1 - config.hStream = stream + config.hStream = kwargs['STREAM'] arg_values = ( - a.data_ptr(), - b.data_ptr(), - c.data_ptr(), - a.numel(), + kwargs['A'].data_ptr(), + kwargs['B'].data_ptr(), + kwargs['C'].data_ptr(), + kwargs['A'].numel(), ) arg_types = ( ctypes.c_void_p, From 4373af2e82ab083f1f6f50bd8add978fee7f23e6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:36:40 +0800 Subject: [PATCH 2/5] Add `DG_PRINT_CONFIGS` --- README.md | 2 +- deep_gemm/jit/compiler.py | 14 +++++++------- deep_gemm/jit/runtime.py | 16 ++++++++++++++-- deep_gemm/jit_kernels/gemm.py | 2 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 4 ++-- deep_gemm/jit_kernels/wgrad_gemm.py | 2 +- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5f5388b..db89b2d 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful: - Post optimization - 
`DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default - Heuristic selection - - `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default - Testing - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 2ab6b25..54e3ab2 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -5,7 +5,7 @@ import re import subprocess import time import uuid -from typing import List, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import cuda.bindings import cuda.bindings.nvrtc as nvrtc @@ -128,7 +128,7 @@ class Compiler: return [get_jit_include_dir()] @classmethod - def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: + def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: # Compiler flags flags = cls.flags() @@ -140,7 +140,7 @@ class Compiler: # Check runtime cache or file system hit global runtime_cache - cached_runtime = runtime_cache.get(path, runtime_cls) + cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs) if cached_runtime is not None: if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Using cached JIT runtime {name} during build') @@ -166,8 +166,8 @@ class Compiler: os.replace(tmp_cubin_path, cubin_path) # Put cache and return - runtime = runtime_cls(path) - runtime_cache[path] = runtime + runtime = runtime_cache.get(path, runtime_cls, name, kwargs) + assert runtime is not None return runtime @@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler): assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' -def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: +def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler - return compiler_cls.build(name, code, runtime_cls=runtime_cls) + return compiler_cls.build(name, code, runtime_cls, kwargs) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 041e23f..ffcd0b3 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,9 +1,11 @@ +import copy import os import subprocess import time +import torch import cuda.bindings.driver as cbd -from typing import List, Optional, Type +from typing import Any, Dict, Optional, Type from torch.utils.cpp_extension import CUDA_HOME @@ -79,13 +81,23 @@ class RuntimeCache: def __setitem__(self, path: str, runtime: Runtime) -> None: self.cache[path] = runtime - def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: + def get(self, path: str, runtime_cls: Type[Runtime], + name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): + # Print heuristic for the first time + if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): + simplified_kwargs = dict() + for key, value in kwargs.items(): + value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value + value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, 
cbd.CUtensorMap) else value + simplified_kwargs[key] = value + print(f'Put kernel {name} with {simplified_kwargs} into runtime cache') + runtime = runtime_cls(path) self.cache[path] = runtime return runtime diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 3adef72..343e84a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -238,5 +238,5 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index ef1a088..73fd2f1 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -102,7 +102,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) @@ -201,5 +201,5 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index dea91e2..8a38578 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -111,7 +111,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8WGradGemmRuntime.generate(**kwargs) - runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime) + runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) runtime(**kwargs) From 350989eef34bc0e48a33a7fcae4656ebc53fe4c0 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:48:32 +0800 Subject: [PATCH 3/5] Unify `ceil_div`s --- deep_gemm/jit_kernels/gemm.py | 8 ++--- deep_gemm/jit_kernels/m_grouped_gemm.py | 10 +++--- deep_gemm/jit_kernels/wgrad_gemm.py | 42 ++++++++++++------------- tests/test_core.py | 8 ++--- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 343e84a..5f7a123 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Type and shape checks assert m == m_ and n == n_ and k == k_ assert n > 0 and k > 0 - assert lhs_scales.shape == (m, (k + 127) // 128) - assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (m, ceil_div(k, 128)) + assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 # LHS scales must be transposed for TMA loads, but not for RHS scales - # NOTES: `get_tma_aligned_lhs_scales` may launch 
a kernel if not processed by previous kernels + # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() @@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], return # K must be aligned to 128 - aligned_k = (k + 127) // 128 * 128 + aligned_k = ceil_div(k, 128) * 128 # Auto-tuning with compilation num_sms = get_num_sms() diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 73fd2f1..c2f2d93 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -7,7 +7,7 @@ from .runtime import ( FP8GemmRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import get_col_major_tma_aligned_tensor, get_num_sms +from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Type and shape checks assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == (m, (k + 127) // 128) - assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 @@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] assert num_groups == num_groups_ == num_groups__ == num_groups___ assert m == m_ and n == n_ and k == k_ assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) - assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) + assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index 8a38578..658f005 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -7,7 +7,7 @@ from .runtime import ( make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_desc) from .gemm import get_best_configs -from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size +from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -40,41 +40,39 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Type and shape checks assert m == m_ and n == n_ and k == k_ assert n > 0 and m > 0 - assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m) - assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n) + assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m) + assert 
rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.float assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 # LHS and RHS scales must be transposed for TMA load - # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels - if lhs_scales.shape == ((k + 127) // 128, m): - lhs_scales = lhs_scales.permute(1, 0) - assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m - else: - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert lhs_scales.stride(0) == 1 - - if rhs_scales.shape == ((k + 127) // 128, n): - rhs_scales = rhs_scales.permute(1, 0) - assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n - else: - rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales) - assert rhs_scales.stride(0) == 1 + # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels + def get_valid_scales(scales: torch.Tensor, mn: int): + if scales.shape == (ceil_div(k, 128), mn): + # For k-grouped GEMMs + scales = scales.permute(1, 0) + assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn + else: + scales = get_col_major_tma_aligned_tensor(scales) + return scales + + lhs_scales = get_valid_scales(lhs_scales, m) + rhs_scales = get_valid_scales(rhs_scales, n) # Do nothing if `k` is zero if k == 0: return # K must be aligned to 128 - aligned_k = (k + 127) // 128 * 128 + aligned_k = ceil_div(k, 128) * 128 # Auto-tuning with compilation num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) - num_last_stages = (k + 127) // 128 % num_stages + num_last_stages = ceil_div(k, 128) % num_stages block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 @@ -151,10 +149,10 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], k = batch_sizes[i] lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) - lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] - rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] + lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] + rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) lhs_offset += m * k rhs_offset += n * k - scales_offset += (k + 127) // 128 + scales_offset += ceil_div(k, 128) diff --git a/tests/test_core.py b/tests/test_core.py index 36c1c34..03038db 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: assert m % 4 == 0, f'TMA alignment error: {m}' x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) @@ -87,7 +87,7 @@ def 
construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ assert m % 4 == 0, f'TMA alignment error: {m}' x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) @@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) - total_scale_factors = sum((k + 127) // 128 for k in k_sizes) + total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) @@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) - num_scales = (k + 127) // 128 + num_scales = ceil_div(k, 128) x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) From 3b412f458a4e2cf816fc51d8d3a7e1a8549f915c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:53:52 +0800 Subject: [PATCH 4/5] Unify `kwargs` usages --- deep_gemm/jit/runtime.py | 7 +++---- deep_gemm/jit_kernels/gemm.py | 4 ++-- deep_gemm/jit_kernels/m_grouped_gemm.py | 8 ++++---- deep_gemm/jit_kernels/runtime.py | 4 ++-- deep_gemm/jit_kernels/wgrad_gemm.py | 4 ++-- tests/test_jit.py | 2 +- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index ffcd0b3..52af8e1 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,4 +1,3 @@ -import copy import os import subprocess import time @@ -27,14 +26,14 @@ class Runtime: return all(os.path.exists(os.path.join(path, file)) for file in files) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: raise NotImplemented @staticmethod - def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: raise NotImplemented - def __call__(self, **kwargs) -> cbd.CUresult: + def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult: # Load CUBIN if self.kernel is None: start_time = time.time_ns() diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 5f7a123..9cb01c3 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -237,6 +237,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index c2f2d93..b072060 100644 --- 
a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -101,9 +101,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -200,6 +200,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index a7b0e66..e65e85a 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -138,7 +138,7 @@ class FP8GemmRuntime(Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: code = f''' #ifdef __CUDACC_RTC__ #include @@ -233,7 +233,7 @@ class FP8WGradGemmRuntime(Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: code = f''' #ifdef __CUDACC_RTC__ #include diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index 658f005..d0655cc 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -108,9 +108,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], } # Generate, build and run the kernel - code = FP8WGradGemmRuntime.generate(**kwargs) + code = FP8WGradGemmRuntime.generate(kwargs) runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], diff --git a/tests/test_jit.py b/tests/test_jit.py index a1bf583..413bd01 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -16,7 +16,7 @@ class VectorAddRuntime(jit.Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: return f""" #ifdef __CUDACC_RTC__ #include From 104a6ec109d8fe5e39141d19639c4f2067f506fb Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 17:04:21 +0800 Subject: [PATCH 5/5] Add `__assertfail` --- deep_gemm/jit/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 52af8e1..3646d26 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -48,7 +48,7 @@ class Runtime: command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) assert result.returncode == 0 - illegal_names = ['vprintf', '__instantiate_kernel', '__internal'] + illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail'] check_illegal = lambda line: any([name in line for name in illegal_names]) kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC') and not check_illegal(line)]
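A few illustrative sketches follow; they are not part of the patch series but spell out helpers and conventions it relies on. The refactored descriptor code above calls `ceil_div` and `get_tma_aligned_size` from `deep_gemm/jit_kernels/utils.py`. Below is a minimal sketch of those two helpers, assuming the 16-byte TMA alignment mentioned in the docstrings; the real implementations may differ in detail.

def ceil_div(x: int, y: int) -> int:
    # Round-up integer division, e.g. ceil_div(1000, 128) == 8
    return (x + y - 1) // y


def get_tma_aligned_size(x: int, element_size: int) -> int:
    # Pad a row count so that its byte extent is a multiple of the 16-byte TMA alignment
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment


# FP32 scales (4-byte elements): 1022 rows are padded up to 1024
assert ceil_div(1000, 128) == 8
assert get_tma_aligned_size(1022, 4) == 1024

In the unified `make_2d_tma_scales_desc`, for instance, the inner global dimension becomes `get_tma_aligned_size(shape_mn, t.element_size())` and the outer dimension `ceil_div(shape_k, block_k)`, multiplied by `num_groups` for masked grouped GEMMs.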
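Patches 1 and 4 change the `Runtime` interface so that `generate`, `launch`, and `__call__` all consume a single `kwargs` dict instead of a fixed positional argument list, and each caller simply does generate, build, run. Here is a self-contained mock of that calling convention with no CUDA involved; the dict keys are a subset of those used by `gemm_fp8_fp8_bf16_nt`, and the method bodies are placeholders rather than the real implementations.

from typing import Any, Dict


class MockRuntime:
    # Mirrors the refactored interface: one dict carries both the template-time
    # values (baked into the generated source) and the launch-time values
    def __init__(self, path: str) -> None:
        self.path = path
        self.kernel = None

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        return f"// BLOCK_M={kwargs['BLOCK_M']}, BLOCK_N={kwargs['BLOCK_N']}"

    @staticmethod
    def launch(kernel: Any, kwargs: Dict[str, Any]) -> int:
        # The real FP8GemmRuntime.launch fills a CUlaunchConfig from these keys
        _ = (kwargs['NUM_SMS'], kwargs['SMEM_SIZE'], kwargs['BLOCK_M'])
        return 0  # stands in for CUresult.CUDA_SUCCESS

    def __call__(self, kwargs: Dict[str, Any]) -> int:
        return self.launch(self.kernel, kwargs)


kwargs = {'BLOCK_M': 128, 'BLOCK_N': 128, 'NUM_SMS': 132, 'SMEM_SIZE': 0}
code = MockRuntime.generate(kwargs)
runtime = MockRuntime('/tmp/example_kernel_dir')  # path is illustrative only
assert runtime(kwargs) == 0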
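Kernel discovery parses `cuobjdump -symbols` output; patch 1 introduces an `illegal_names` filter (and patch 5 adds `__assertfail` to it) so that internal device symbols are not mistaken for the kernel entry point. The standalone sketch below exercises that filter over a fabricated symbol listing; the sample lines are illustrative, not real cuobjdump output.

from typing import List

ILLEGAL_NAMES = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']


def extract_kernel_names(symbol_dump: str) -> List[str]:
    # Keep only STT_FUNC symbols whose names contain none of the illegal substrings
    def is_illegal(line: str) -> bool:
        return any(name in line for name in ILLEGAL_NAMES)
    return [line.split()[-1]
            for line in symbol_dump.splitlines()
            if line.startswith('STT_FUNC') and not is_illegal(line)]


sample = '\n'.join([
    'STT_FUNC STB_GLOBAL STO_ENTRY vprintf',
    'STT_FUNC STB_GLOBAL STO_ENTRY __assertfail',
    'STT_FUNC STB_GLOBAL STO_ENTRY fp8_gemm_kernel',
    'STT_OBJECT STB_LOCAL STO_DEFAULT some_constant',
])
assert extract_kernel_names(sample) == ['fp8_gemm_kernel']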
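Patch 2 renames `DG_PRINT_AUTOTUNE` to `DG_PRINT_CONFIGS` and prints a simplified view of the kwargs the first time a kernel enters the runtime cache, so tensors and TMA descriptors are shown as short type tags rather than full contents. The following is a hedged sketch of just that simplification step; the real code lives in `RuntimeCache.get` and also honors `DG_JIT_DEBUG`.

import os
from typing import Any, Dict

import torch


def simplify_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Replace bulky values with short type tags before printing
    simplified = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            value = f'torch.Tensor<{value.dtype}>'
        elif type(value).__name__ == 'CUtensorMap':
            # The real code uses isinstance(value, cuda.bindings.driver.CUtensorMap);
            # a name check keeps this sketch importable without the CUDA bindings
            value = 'cuda.bindings.driver.CUtensorMap'
        simplified[key] = value
    return simplified


if int(os.getenv('DG_PRINT_CONFIGS', 0)):
    example = {'BLOCK_M': 128, 'SCALES_B': torch.empty(0)}
    print(f'Put kernel gemm_fp8_fp8_bf16_nt with {simplify_kwargs(example)} into runtime cache')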