From bfe983c4c27634d5f696fd613154c965c110385e Mon Sep 17 00:00:00 2001 From: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Date: Wed, 7 May 2025 11:38:14 +0800 Subject: [PATCH 1/5] Refactor JIT compilation (+NVRTC support) (#94) * [wip] refactor: compile to .cubin Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * refactor: compile to .cubin and add NVRTC option Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: compiler version Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: compat for old drivers Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: save kernel name to file Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: fix win compat Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: windows compat Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> * feat: make API more general Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: drop support for CUDA<12.3 Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * doc: update README Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * Some lints and refactor * Refactor runtime * Several fixes * Refactor environment variables * Code format * Add a TODO * Compatible with CUDA 12.3 * Fix indent * Fix typing * Drop support for Windows * Add a TODO --------- Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Co-authored-by: Chenggang Zhao --- README.md | 34 ++- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 124 +--------- deep_gemm/include/deep_gemm/mma_utils.cuh | 2 + deep_gemm/include/deep_gemm/nvrtc_std.cuh | 103 ++++++++ deep_gemm/include/deep_gemm/scheduler.cuh | 5 +- deep_gemm/include/deep_gemm/tma_utils.cuh | 77 +----- deep_gemm/include/deep_gemm/utils.cuh | 29 +-- deep_gemm/jit/__init__.py | 3 +- deep_gemm/jit/compiler.py | 271 ++++++++++++++++------ deep_gemm/jit/interleave_ffma.py | 6 +- deep_gemm/jit/runtime.py | 88 ++++--- deep_gemm/jit/template.py | 114 --------- deep_gemm/jit_kernels/gemm.py | 101 ++++---- deep_gemm/jit_kernels/m_grouped_gemm.py | 153 ++++++------ deep_gemm/jit_kernels/runtime.py | 254 ++++++++++++++++++++ deep_gemm/jit_kernels/tuner.py | 50 ++-- deep_gemm/utils.py | 19 +- tests/test_core.py | 5 + tests/test_jit.py | 131 +++++++---- 19 files changed, 909 insertions(+), 660 deletions(-) create mode 100644 deep_gemm/include/deep_gemm/nvrtc_std.cuh delete mode 100644 deep_gemm/jit/template.py create mode 100644 deep_gemm/jit_kernels/runtime.py diff --git a/README.md b/README.md index dab1f05..f4a6a60 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,8 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [ ] Skip useless computation on M -- [ ] NVRTC as a faster compiler +- [x] NVRTC as a faster compiler +- [ ] Stolen JIT cache - [ ] Sanitizer for testing - [ ] Weight gradient kernels for dense models - [ ] Weight gradient kernels for MoE models @@ -104,14 +105,23 @@ The library provides some utility functions besides the above kernels: The library also provides some environment variables, which may be useful: -- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default -- `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default -- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler -- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization -- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output -- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details -- `DG_JIT_PRINT_NVCC_COMMAND`: 0 or 1, print NVCC compilation command -- `DG_JIT_DEBUG`: 0 or 1, print more debugging information +- General + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default +- JIT cache related + - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default + - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default +- NVCC/NVRTC selections + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default +- Compiler options + - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default +- Post optimization + - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default +- Testing + - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. @@ -138,9 +148,9 @@ The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups -- Larger block sizes -- Less bank conflicts via 3D TMA 🐳 -- Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳 +- Less bank conflicts via 3D TMA or swizzling +- Larger block sizes (up to 256x128 🐳) +- Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 #### A unified and optimized block scheduler diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index e8370af..c57691b 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -84,16 +84,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_id(); - // Prefetch TMA descriptors at very beginning + // Prefetch TMA descriptors at the very beginning if (threadIdx.x == kNumMathThreads) { - cute::prefetch_tma_descriptor(&tensor_map_a); - cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_scales_a); + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); // `tensor_map_d` is only used in swizzling mode // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode if constexpr (kSwizzleDMode > 0) - cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); } __syncwarp(); @@ -447,119 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #endif } -template -class Gemm { -private: - using Barrier = cuda::barrier; - -public: - Gemm() = default; - - static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, - uint32_t shape_m, - const CUtensorMap& tma_a_desc, - const CUtensorMap& tma_b_desc, - const CUtensorMap& tma_scales_a_desc, - const CUtensorMap& tma_d_desc, - cudaStream_t stream, - int num_sms, uint32_t smem_size) { - // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps - constexpr uint32_t kNumTMAThreads = 128; - constexpr uint32_t kNumMathThreadsPerGroup = 128; - auto kernel = fp8_gemm_kernel; - DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); - - // Cluster launch - cudaLaunchConfig_t config; - config.gridDim = num_sms; - config.blockDim = get_num_threads_per_sm(BLOCK_M); - config.dynamicSmemBytes = smem_size; - config.stream = stream; - - // Clusters for TMA multicast - // NOTES: `>= 4` cluster size will cause performance degradation - cudaLaunchAttribute attr; - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; - - // Launch - auto status = cudaLaunchKernelEx(&config, kernel, - gmem_d, scales_b, grouped_layout, - shape_m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); - DG_HOST_ASSERT(status == cudaSuccess); - } - - template - static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); - } - - template - static CUtensorMap make_2d_tma_b_desc(T* global_address) { - return make_2d_tma_desc(global_address, Layout::ColMajor, - SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); - } - - template - static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { - auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; - if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; - if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; - if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; - - // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes - // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, - BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T), - swizzle_mode); - } - - template - static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { - // Make TMA aligned to 16 bytes - constexpr uint32_t kAlignment = 16 / sizeof(T); - shape_m = ceil_div(shape_m, kAlignment) * kAlignment; - - return make_2d_tma_desc(global_address, Layout::ColMajor, - shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); - } - - template - static CUtensorMap make_2d_tma_desc( - T* global_address, Layout layout, - uint32_t gmem_rows, uint32_t gmem_cols, - uint32_t smem_rows, uint32_t smem_cols, - CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { - if (layout == Layout::RowMajor) { - uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; - uint32_t smem_dim[2] = {smem_cols, smem_rows}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); - } else { - uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; - uint32_t smem_dim[2] = {smem_rows, smem_cols}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); - } - } -}; - }; // namespace deep_gemm #pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index a442af7..c6c7e28 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -1,6 +1,8 @@ #pragma once +#ifndef __CUDACC_RTC__ #include +#endif #include #include diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh new file mode 100644 index 0000000..00ce734 --- /dev/null +++ b/deep_gemm/include/deep_gemm/nvrtc_std.cuh @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef __CUDACC_RTC__ + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using int32_t = signed int; +using uint32_t = unsigned int; +using int64_t = signed long long; +using uint64_t = unsigned long long; +using cuuint64_t = unsigned long long; + +#ifndef CU_TENSOR_MAP_NUM_QWORDS +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +}; + +using CUtensorMap = CUtensorMap_st; +#endif + +namespace std { + +template struct integral_constant { + static constexpr T value = v; + + using value_type = T; + using type = integral_constant; + + __device__ constexpr operator value_type() const noexcept { return value; } + + __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +template struct is_same : false_type {}; + +template struct is_same : true_type {}; + +template +inline constexpr bool is_same_v = is_same::value; + +namespace index_sequence_impl { + +// Based on https://stackoverflow.com/a/32223343/11717224 +template struct index_sequence { + using type = index_sequence; + using value_type = size_t; + static constexpr size_t size() noexcept { return sizeof...(Ints); } +}; + +template struct _merge_and_renumber; + +template +struct _merge_and_renumber, index_sequence> + : index_sequence {}; + +template +struct make_index_sequence + : _merge_and_renumber::type, + typename make_index_sequence::type> {}; + +template <> struct make_index_sequence<0> : index_sequence<> {}; +template <> struct make_index_sequence<1> : index_sequence<0> {}; + +} // namespace index_sequence_impl + +template +using index_sequence = index_sequence_impl::index_sequence; + +template +using make_index_sequence = index_sequence_impl::make_index_sequence; + +} // namespace std + +#endif diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 9743871..c213d57 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -46,7 +46,7 @@ struct Scheduler { } } - __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) { + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { if (num_blocks_in_group == 1) return false; if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { @@ -63,7 +63,8 @@ struct Scheduler { } } - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, + uint32_t& m_block_idx, uint32_t& n_block_idx) { DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); // Swizzle for better L2 usages diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh index 18cdb58..795dca6 100644 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ b/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -1,85 +1,10 @@ #pragma once -#include -#include -#include -#include -#include -#include - #include "utils.cuh" namespace deep_gemm { -template -constexpr CUtensorMapDataType get_CUtensorMapDataType() { - if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; - } -} - -inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { - // Get pointer to `cuTensorMapEncodeTiled` - cudaDriverEntryPointQueryResult driver_status; - void* cuTensorMapEncodeTiled_ptr = nullptr; - -#if CUDA_VERSION >= 12050 - cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, - cudaEnableDefault, &driver_status); -#else - cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, - cudaEnableDefault, &driver_status); -#endif - - if (driver_status != cudaDriverEntryPointSuccess) - throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); - return reinterpret_cast(cuTensorMapEncodeTiled_ptr); -} - -template -CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], - uint64_t stride_in_bytes, uint32_t smem_dim[2], - CUtensorMapSwizzle swizzle_type, - PFN_cuTensorMapEncodeTiled encode_func = nullptr) { - CUtensorMap tensor_map = {}; - uint64_t global_stride[1] = {stride_in_bytes}; - uint32_t elem_strides[2] = {1, 1}; - - if (encode_func == nullptr) - encode_func = get_cuTensorMapEncodeTiled(); - - auto result = encode_func( - &tensor_map, get_CUtensorMapDataType>(), 2, - global_address, gmem_dim, global_stride, smem_dim, elem_strides, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - DG_HOST_ASSERT(result == CUDA_SUCCESS); - return tensor_map; -} - +// TODO: move this function to other files __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 9b93af5..598a414 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -1,33 +1,14 @@ #pragma once -#include - #ifdef __CLION_IDE__ -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); } + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + #define printf host_device_printf #endif -class AssertionException : public std::exception { -private: - std::string message{}; - -public: - explicit AssertionException(const std::string& message) : message(message) {} - - const char *what() const noexcept override { return message.c_str(); } -}; - -#ifndef DG_HOST_ASSERT -#define DG_HOST_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", \ - __FILE__, __LINE__, #cond); \ - throw AssertionException("Assertion failed: " #cond); \ - } \ -} while (0) -#endif - #ifndef DG_DEVICE_ASSERT #define DG_DEVICE_ASSERT(cond) \ do { \ diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py index eb08b14..06a5194 100644 --- a/deep_gemm/jit/__init__.py +++ b/deep_gemm/jit/__init__.py @@ -1,3 +1,2 @@ -from .compiler import get_nvcc_compiler, build -from .template import cpp_format, generate +from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index c17d466..559a2f6 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -1,15 +1,18 @@ -import hashlib import functools +import hashlib import os import re import subprocess +import time import uuid +from typing import List, Tuple, Type + +import cuda.bindings +import cuda.bindings.nvrtc as nvrtc from torch.utils.cpp_extension import CUDA_HOME -from typing import Tuple from . import interleave_ffma from .runtime import Runtime, RuntimeCache -from .template import typename_map runtime_cache = RuntimeCache() @@ -22,21 +25,22 @@ def hash_to_hex(s: str) -> str: @functools.lru_cache(maxsize=None) def get_jit_include_dir() -> str: - return f'{os.path.dirname(os.path.abspath(__file__))}/../include' + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') @functools.lru_cache(maxsize=None) def get_deep_gemm_version() -> str: - # Update include directories - include_dir = f'{get_jit_include_dir()}/deep_gemm' - assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' md5 = hashlib.md5() + + # Update include directories + include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') + assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): - with open(f'{include_dir}/{filename}', 'rb') as f: + with open(os.path.join(include_dir, filename), 'rb') as f: md5.update(f.read()) # Update `interleave_ffma.py` - with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f: + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: md5.update(f.read()) return md5.hexdigest()[0:12] @@ -44,16 +48,20 @@ def get_deep_gemm_version() -> str: @functools.lru_cache(maxsize=None) def get_nvcc_compiler() -> Tuple[str, str]: paths = [] - if os.getenv('DG_NVCC_COMPILER'): - paths.append(os.getenv('DG_NVCC_COMPILER')) - paths.append(f'{CUDA_HOME}/bin/nvcc') + if os.getenv('DG_JIT_NVCC_COMPILER'): + paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) + + paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) # Try to find the first available NVCC compiler least_version_required = '12.3' version_pattern = re.compile(r'release (\d+\.\d+)') for path in paths: if os.path.exists(path): - match = version_pattern.search(os.popen(f'{path} --version').read()) + command = [path, '--version'] + result = subprocess.run(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True) + match = version_pattern.search(result.stdout) version = match.group(1) assert match, f'Cannot get the version of NVCC compiler {path}' assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' @@ -63,21 +71,21 @@ def get_nvcc_compiler() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def get_default_user_dir(): - if 'DG_CACHE_DIR' in os.environ: - path = os.getenv('DG_CACHE_DIR') + if 'DG_JIT_CACHE_DIR' in os.environ: + path = os.getenv('DG_JIT_CACHE_DIR') os.makedirs(path, exist_ok=True) return path - return os.path.expanduser('~') + '/.deep_gemm' + return os.path.join(os.path.expanduser('~'), '.deep_gemm') @functools.lru_cache(maxsize=None) def get_tmp_dir(): - return f'{get_default_user_dir()}/tmp' + return os.path.join(get_default_user_dir(), 'tmp') @functools.lru_cache(maxsize=None) def get_cache_dir(): - return f'{get_default_user_dir()}/cache' + return os.path.join(get_default_user_dir(), 'cache') def make_tmp_dir(): @@ -86,67 +94,192 @@ def make_tmp_dir(): return tmp_dir -def put(path, data, is_binary=False): +def put(path, data): # Write and do POSIX atomic replace - tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' - with open(tmp_file_path, 'wb' if is_binary else 'w') as f: + tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') + with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: f.write(data) os.replace(tmp_file_path, path) -def build(name: str, arg_defs: tuple, code: str) -> Runtime: - # Compiler flags - cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) - nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', - '-gencode=arch=compute_90a,code=sm_90a', - '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=39,174,177,940'] - cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] - flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] - include_dirs = [get_jit_include_dir()] +class Compiler: + @classmethod + def signature(cls) -> str: + pass - # Build signature - enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 - signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' - name = f'kernel.{name}.{hash_to_hex(signature)}' - path = f'{get_cache_dir()}/{name}' + @staticmethod + def __version__() -> Tuple[int, int]: + pass - # Check runtime cache or file system hit - global runtime_cache - if runtime_cache[path] is not None: - if os.getenv('DG_JIT_DEBUG', None): - print(f'Using cached JIT runtime {name} during build') - return runtime_cache[path] + @classmethod + def compile(cls, name: str, code: str, target_path: str) -> None: + pass - # Write the code - os.makedirs(path, exist_ok=True) - args_path = f'{path}/kernel.args' - src_path = f'{path}/kernel.cu' - put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) - put(src_path, code) + @staticmethod + def flags() -> List[str]: + cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) + return [f'-std=c++{cpp_standard}', + '--ptxas-options=--register-usage-level=10' + + (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), + # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases + '--diag-suppress=39,161,174,177,940'] - # Compile into a temporary SO file - so_path = f'{path}/kernel.so' - tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so' + @staticmethod + def include_dirs() -> List[str]: + return [get_jit_include_dir()] - # Compile - command = [get_nvcc_compiler()[0], - src_path, '-o', tmp_so_path, - *flags, - *[f'-I{d}' for d in include_dirs]] - if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): - print(f'Compiling JIT runtime {name} with command {command}') - return_code = subprocess.check_call(command) - assert return_code == 0, f'Failed to compile {src_path}' + @classmethod + def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: + # Compiler flags + flags = cls.flags() - # Interleave FFMA reuse - if enable_sass_opt: - interleave_ffma.process(tmp_so_path) + # Build signature + enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) + signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' + name = f'kernel.{name}.{hash_to_hex(signature)}' + path = os.path.join(get_cache_dir(), name) - # Atomic replace SO file - os.replace(tmp_so_path, so_path) + # Check runtime cache or file system hit + global runtime_cache + cached_runtime = runtime_cache.get(path, runtime_cls) + if cached_runtime is not None: + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Using cached JIT runtime {name} during build') + return cached_runtime - # Put cache and return - runtime_cache[path] = Runtime(path) - return runtime_cache[path] + # Compile into a temporary CU file + os.makedirs(path, exist_ok=True) + cubin_path = os.path.join(path, 'kernel.cubin') + tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') + + start_time = time.time() + cls.compile(name, code, tmp_cubin_path) + end_time = time.time() + elapsed_time = end_time - start_time + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') + + # Interleave FFMA reuse + if enable_sass_opt: + interleave_ffma.process(tmp_cubin_path) + + # Atomic replace files + os.replace(tmp_cubin_path, cubin_path) + + # Put cache and return + runtime = runtime_cls(path) + runtime_cache[path] = runtime + return runtime + + +class NVCCCompiler(Compiler): + @staticmethod + def __version__() -> Tuple[int, int]: + _, version = get_nvcc_compiler() + major, minor = map(int, version.split('.')) + return major, minor + + @classmethod + def signature(cls) -> str: + return f'nvcc+{cls.__version__()}' + + @classmethod + def flags(cls) -> List[str]: + cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] + return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], + '-gencode=arch=compute_90a,code=sm_90a', + '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', + f'--compiler-options={",".join(cxx_flags)}'] + + @classmethod + def compile(cls, name: str, code: str, target_path: str) -> None: + # Write the code + path = os.path.join(get_cache_dir(), name) + src_path = os.path.join(path, 'kernel.cu') + put(src_path, code) + command = [get_nvcc_compiler()[0], + src_path, '-o', target_path, + *cls.flags()] + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): + print(f'Compiling JIT runtime {name} with command {command}') + + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') + assert False, f'Failed to compile {src_path}' + + +class NVRTCCompiler(Compiler): + @staticmethod + def __version__() -> Tuple[int, int]: + res, major, minor = nvrtc.nvrtcVersion() + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + # Failed to get the actual NVRTC version, use cuda-bindings version instead + major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) + return major, minor + + @classmethod + def signature(cls) -> str: + return f'nvrtc+{cls.__version__()}' + + @staticmethod + def include_dirs() -> List[str]: + if CUDA_HOME is None: + raise RuntimeError('CUDA_HOME is required for NVRTC compilation') + return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] + + @classmethod + def flags(cls) -> List[str]: + flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], + '--gpu-architecture=sm_90a', '-default-device'] + # NOTES: PCH is vital for compilation speed + if cls.__version__() >= (12, 8): + flags += ['--pch'] + if int(os.getenv('DG_JIT_DEBUG', 0)): + flags += ['--pch-verbose=true'] + return flags + + @classmethod + def compile(cls, name: str, code: str, target_path: str) -> None: + # Create program + code_bytes = bytes(code, 'utf-8') + result, program = nvrtc.nvrtcCreateProgram( + code_bytes, bytes(name, 'utf-8'), 0, [], []) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' + + # Compile + options = [bytes(flag, 'utf-8') for flag in cls.flags()] + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): + print(f'Compiling JIT runtime {name} with options: {options}') + compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] + + # Print compiler log + if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + result, log_size = nvrtc.nvrtcGetProgramLogSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' + + log_bytes = bytes(log_size) + result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' + print(f'Compiler log: {log_bytes.decode("utf-8")}') + + # Exit if failed + assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' + + # Create CUBIN + result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' + cubin_bytes = bytes(cubin_size) + result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' + + # Write into the file system + put(target_path, cubin_bytes) + + # Destroy handler + assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' + + +def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: + compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler + return compiler_cls.build(name, code, runtime_cls=runtime_cls) diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py index fcb377e..7899a22 100644 --- a/deep_gemm/jit/interleave_ffma.py +++ b/deep_gemm/jit/interleave_ffma.py @@ -37,7 +37,7 @@ def extract_ffma(sass): collected.append((f'{arch_name}::{func_name}', current)) current = [] - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f'Found {len(collected)} FFMA segments') return collected @@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines): dst_reg_set.add(dst_reg) new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) last_reused, last_dst_reg = reused, dst_reg - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') # Find the offset @@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines): def process(path): - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f'Processing {path}') output = run_cuobjdump(path) segments = extract_ffma(output) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 66c370a..b7c2f95 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,17 +1,18 @@ -import ctypes import os -import torch -from typing import Optional +import subprocess +import time +import cuda.bindings.driver as cbd -from .template import map_ctype +from typing import List, Optional, Type +from torch.utils.cpp_extension import CUDA_HOME class Runtime: - def __init__(self, path: str) -> None: + def __init__(self, path: str, args: List[str] = None) -> None: self.path = path self.lib = None - self.args = None - + self.kernel = None + self.args = args assert self.is_path_valid(self.path) @staticmethod @@ -21,46 +22,69 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cu', 'kernel.args', 'kernel.so'] + files = ['kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) - def __call__(self, *args) -> int: - # Load SO file - if self.lib is None or self.args is None: - self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so')) - with open(os.path.join(self.path, 'kernel.args'), 'r') as f: - self.args = eval(f.read()) + @staticmethod + def generate(**kwargs) -> str: + raise NotImplemented - # Check args and launch - assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' - cargs = [] - for arg, (name, dtype) in zip(args, self.args): - if isinstance(arg, torch.Tensor): - assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' - else: - assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' - cargs.append(map_ctype(arg)) + @staticmethod + def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: + raise NotImplemented - return_code = ctypes.c_int(0) - self.lib.launch(*cargs, ctypes.byref(return_code)) - return return_code.value + def __call__(self, **kwargs) -> cbd.CUresult: + # Load CUBIN + if self.kernel is None: + start_time = time.time_ns() + + # Load CUBIN + path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') + result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' + + # Extract the kernel name + # TODO: use `cuda-bindings` API to do this (requires at least 12.8) + command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + assert result.returncode == 0 + kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')] + assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' + + # Load kernel from the library + result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' + + end_time = time.time_ns() + elapsed_time = (end_time - start_time) / 1e6 + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') + + # noinspection PyArgumentList + return self.launch(self.kernel, *[kwargs[arg] for arg in self.args]) + + def __del__(self) -> None: + if self.lib is not None: + res = cbd.cuLibraryUnload(self.lib)[0] + if res != cbd.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to unload library {self.path}: {res}') class RuntimeCache: def __init__(self) -> None: self.cache = {} - def __getitem__(self, path: str) -> Optional[Runtime]: + def __setitem__(self, path: str, runtime: Runtime) -> None: + self.cache[path] = runtime + + def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled - if os.path.exists(path) and Runtime.is_path_valid(path): - runtime = Runtime(path) + if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): + runtime = runtime_cls(path) self.cache[path] = runtime return runtime return None - - def __setitem__(self, path, runtime) -> None: - self.cache[path] = runtime diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py deleted file mode 100644 index ead37f5..0000000 --- a/deep_gemm/jit/template.py +++ /dev/null @@ -1,114 +0,0 @@ -import copy -import ctypes -import os -import torch -from typing import Any, Dict, Iterable, Tuple - - -# Name map for Python `eval` -typename_map: Dict[Any, str] = { - **{t: t.__name__ for t in (bool, int, float)}, - torch.int: 'torch.int', - torch.float: 'torch.float', - torch.bfloat16: 'torch.bfloat16', - torch.float8_e4m3fn: 'torch.float8_e4m3fn', - torch.cuda.Stream: 'torch.cuda.Stream', -} - -# `ctype` map for Python casting -ctype_map: Dict[Any, Any] = { - **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, - **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, -} - - -# Type map for both Python API and source code usages -genc_map = { - bool: ('bool', 'bool'), - int: ('int', 'int'), - float: ('float', 'float'), - torch.int: ('void*', 'int*'), - torch.float: ('void*', 'float*'), - torch.bfloat16: ('void*', '__nv_bfloat16*'), - torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), - torch.cuda.Stream: ('void*', 'cudaStream_t'), -} - - -def map_ctype(value: Any) -> Any: - if hasattr(value, 'data_ptr'): - if value.dtype == torch.int: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.bfloat16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float8_e4m3fn: - return ctypes.c_void_p(value.data_ptr()) - else: - return ctypes.c_void_p(value.data_ptr()) - - if hasattr(value, 'cuda_stream'): - return ctypes.c_void_p(value.cuda_stream) - - if isinstance(value, bool): - return ctypes.c_bool(value) - elif isinstance(value, int): - return ctypes.c_int(value) - elif isinstance(value, float): - return ctypes.c_float(value) - - return ctype_map[type(value)](value) - - -def cpp_format(template: str, keys: Dict[str, Any]) -> str: - # We don't use `str.format` because it's not safe for C++ {} braces - new_template = copy.deepcopy(template) - for key, value in keys.items(): - value_str = str(value) - if isinstance(value, bool): - value_str = value_str.lower() - new_template = new_template.replace(f'{{{key}}}', f'{value_str}') - return new_template - - -def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: - # Common prefix - code = '// DeepGEMM auto-generated JIT CUDA source file\n\n' - - # Includes - preload_sys_includes = ['', '', '', ''] - preload_package_includes = ['"cutlass/cutlass.h"'] - - assert isinstance(includes, list) or isinstance(includes, tuple) - sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')]))) - package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')]))) - code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n' - code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n' - - # Function signature - raw = '__raw_' - get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n - code += f'extern "C" void launch(' - code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ]) - code += ') {\n' - - # Cast raw types - code += ' // Cast raw types (if needed)\n' - for arg_name, arg_type in arg_defs: - if genc_map[arg_type][0] != genc_map[arg_type][1]: - code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n' - - # Function body - code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')]) - - # End the function - code += '}\n\n' - - # Debug print - if os.getenv('DG_JIT_DEBUG', None): - print(f'Generated code:\n{code}') - - return code diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index c6fd29d..d8023fe 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -3,40 +3,13 @@ import torch from functools import lru_cache from typing import Tuple +from .runtime import ( + FP8GemmRuntime, GemmType, + make_2d_tma_a_desc, make_2d_tma_b_desc, + make_2d_tma_d_desc, make_2d_tma_scales_a_desc) from .tuner import jit_tuner from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = 1; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, require_divisible: bool = False) -> bool: @@ -64,7 +37,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: # Try swizzle first, as it does not waste shared memory swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 smem_d = block_m * (block_n + block_n_padding) * 2 smem_a_per_stage = block_m * block_k @@ -78,7 +52,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_size += num_stages * smem_a_per_stage smem_size += num_stages * smem_scales_a_per_stage smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += ceil_div(smem_scales_b * (1 if block_k % + block_n == 0 else 2), 8) * 8 smem_size += smem_barrier # Swizzle and padding are not compatible @@ -104,7 +79,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Decide block sizes by waves best_block_m, best_block_n = None, None for block_m in block_ms: - # NOTES: the block sizes can not be too large, so at least one dim less than 128 + # NOTES: the block sizes cannot be too large, so at least one dim less than 128 for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) @@ -142,7 +117,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, assert best_smem_config is not None assert best_num_stages is not None - # Decide the number of TMA multicast and whether broadcast on A + # Decide the number of TMA multicasts and whether broadcast on A best_tma_multicast_config = (1, True) # Try to multicast on the larger block side first @@ -173,13 +148,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[m, n]`, representing the result. """ @@ -201,7 +176,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], assert out.dtype == torch.bfloat16 assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() - # LHS scales must be transposed for TMA load, but not for RHS scales + # LHS scales must be transposed for TMA loads, but not for RHS scales # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() @@ -211,11 +186,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.Normal, lhs, m, k, block_m, block_k, 1) + tensor_map_b = make_2d_tma_b_desc( + GemmType.Normal, rhs, k, n, block_k, block_n, 1) + tensor_map_d = make_2d_tma_d_desc( + GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.Normal, lhs_scales, m, k, block_m, block_k) + + kwargs = { + 'GEMM_TYPE': GemmType.Normal, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'NUM_GROUPS': 1, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -224,14 +230,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, + runtime_cls=FP8GemmRuntime, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 3b518c9..24a2183 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,41 +1,14 @@ import torch from typing import Tuple -from .gemm import get_best_configs, get_block_n_padding_for_smem_d +from .gemm import get_best_configs +from .runtime import ( + FP8GemmRuntime, GemmType, + make_2d_tma_a_desc, make_2d_tma_b_desc, + make_2d_tma_d_desc, make_2d_tma_scales_a_desc) from .tuner import jit_tuner from .utils import get_col_major_tma_aligned_tensor, get_num_sms -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = {NUM_GROUPS}; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated grouped GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, grouped_layout, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], @@ -44,7 +17,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. On the M axis, inputs are grouped into several batches, of which batch sizes aligned to `get_m_alignment_for_contiguous_layout()` (128). @@ -52,11 +25,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. m_indices: a tensor of shape `[m_sum]` with type `torch.int`. - `m_indices[i]` records the group which the i-th row of the LHS belong to, + `m_indices[i]` records the group which the i-th row of the LHS belongs to, which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. Values of `m_indices` in every-m-alignment-block must also be the same. """ @@ -87,13 +60,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) - args = (lhs, lhs_scales, rhs, rhs_scales, out, - m_indices, m, num_groups, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms, is_grouped_contiguous=True) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': m_indices, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -102,20 +102,14 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedContiguous'}, + 'GEMM_TYPE': GemmType.GroupedContiguous}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, + runtime_cls=FP8GemmRuntime, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -125,7 +119,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch should be separately transposed. @@ -134,7 +128,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute in the i-th group. @@ -166,18 +160,45 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] assert rhs_scales.is_contiguous() # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) # Extra checks for TMA store if num_groups > 1 and m > block_m: assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - args = (lhs, lhs_scales, rhs, rhs_scales, out, - masked_m, m, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1]) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': masked_m, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -186,17 +207,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedMasked'}, + 'GEMM_TYPE': GemmType.GroupedMasked}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, + runtime_cls=FP8GemmRuntime, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py new file mode 100644 index 0000000..5396601 --- /dev/null +++ b/deep_gemm/jit_kernels/runtime.py @@ -0,0 +1,254 @@ +import ctypes +import os +import enum +import torch +import cuda.bindings.driver as cbd +from typing import Any, Dict, Tuple + +from ..jit.runtime import Runtime + + +class Layout(enum.Enum): + RowMajor = 0 + ColMajor = 1 + + +class GemmType(enum.Enum): + Normal = 0 + GroupedContiguous = 1 + GroupedMasked = 2 + + def __str__(self) -> str: + return { + 0: 'Normal', + 1: 'GroupedContiguous', + 2: 'GroupedMasked', + }[self.value] + + +tmap_type_map: Dict[Any, str] = { + torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, + torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, + torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, + torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, + torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +} + +swizzle_type_map = { + 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, + 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, + 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, + 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, +} + + +def get_num_math_warpgroups(block_m: int) -> int: + return 1 if block_m == 64 else 2 + + +def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: + assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' + return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads + + +def make_2d_tma_copy_desc(global_address: torch.Tensor, + gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], + stride_in_bytes: cbd.cuuint64_t, + smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], + swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: + tensor_dtype = tmap_type_map[global_address.dtype] + res, tensor_map = cbd.cuTensorMapEncodeTiled( + tensor_dtype, + 2, + global_address.data_ptr(), + gmem_dim, + (stride_in_bytes, ), + smem_dim, + (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), + cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle_type, + cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + + if res != cbd.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to encode tensor map: {res}') + return tensor_map + + +def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, + gmem_rows: int, gmem_cols: int, + smem_rows: int, smem_cols: int, + swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: + if layout == Layout.RowMajor: + gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows)) + smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) + else: + gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols)) + smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) + + +def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_m: int, shape_k: int, + block_m: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.RowMajor, + shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, + block_m, block_k) + + +def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_k: int, shape_n: int, + block_k: int, block_n: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.ColMajor, + shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), + block_k, block_n) + + +def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_m: int, shape_n: int, + block_m: int, block_n: int, + num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap: + # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` + # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(global_address, Layout.RowMajor, + shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, + block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), + swizzle_type_map[swizzle_mode]) + + +def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: + # Make TMA aligned to 16 bytes + tma_alignment = 16 / global_address.element_size() + shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment + + return make_2d_tma_desc(global_address, Layout.ColMajor, + shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), + block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +class FP8GemmRuntime(Runtime): + def __init__(self, path: str) -> None: + super().__init__(path, [ + 'NUM_TMA_MULTICAST', + 'M', + 'BLOCK_M', + 'GMEM_D', + 'SCALES_B', + 'GROUPED_LAYOUT', + 'NUM_SMS', + 'SMEM_SIZE', + 'TENSOR_MAP_A', + 'TENSOR_MAP_B', + 'TENSOR_MAP_SCALES_A', + 'TENSOR_MAP_D', + 'STREAM', + ]) + + @staticmethod + def generate(**kwargs) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include +#include + +#include + +using namespace deep_gemm; + +auto ptr = reinterpret_cast(&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >); +''' + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Generated FP8 GEMM code:\n{code}') + return code + + # noinspection PyMethodOverriding + @staticmethod + def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int, + block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, + grouped_layout: torch.Tensor, num_sms: int, smem_size: int, + tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap, + tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap, + stream: cbd.CUstream) -> cbd.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] + if res != cbd.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to set max dynamic shared memory size: {res}') + + attr_val = cbd.CUlaunchAttributeValue() + attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cbd.CUlaunchAttribute() + attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cbd.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = num_sms + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = smem_size + config.hStream = stream + + arg_values = ( + gmem_d.data_ptr(), + scales_b.data_ptr(), + grouped_layout.data_ptr(), + shape_m, + tensor_map_a, + tensor_map_b, + tensor_map_scales_a, + tensor_map_d, + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + None, + None, + None, + None, + ) + return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py index 6ed6749..9a8b6f2 100644 --- a/deep_gemm/jit_kernels/tuner.py +++ b/deep_gemm/jit_kernels/tuner.py @@ -1,9 +1,10 @@ import copy import os import torch -from typing import Any, Dict +import cuda.bindings.driver as cbd +from typing import Any, Callable, Dict, Type, Tuple -from ..jit import build, cpp_format, generate, Runtime +from ..jit import build, Runtime class JITTuner: @@ -11,22 +12,21 @@ class JITTuner: self.tuned = {} def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, - includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: - # NOTES: we always assume the space and template will not change - # We also assume the GPU device will not be changed + kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]: + # NOTES: we always assume the space, template and GPU devices will not change # NOTES: the function must have no accumulated side effects keys = {k: keys[k] for k in sorted(keys.keys())} signature = (name, f'{keys}') if signature in self.tuned: - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Using cached JIT kernel {name} with keys {keys}') return self.tuned[signature] - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Auto-tuning JIT kernel {name} with keys {keys}') assert signature not in self.tuned - assert args is not None + assert kwargs is not None space = (dict(), ) if len(space) == 0 else space kernels = [] @@ -34,30 +34,31 @@ class JITTuner: assert isinstance(tuned_keys, dict) full_keys = copy.deepcopy(keys) full_keys.update(tuned_keys) - code = generate(includes, arg_defs, cpp_format(template, full_keys)) - - # Illegal build must raise errors - kernels.append((build(name, arg_defs, code), tuned_keys)) + code = runtime_cls.generate(**kwargs, **full_keys) + kernels.append((build(name, code, runtime_cls), full_keys)) + # TODO: fix tuning with space > 1 best_runtime, best_time, best_keys = None, None, None for runtime, tuned_keys in kernels: if len(space) > 1: # Check kernel validity - return_code = runtime(*args) - if return_code != 0: - # Pass illegal kernels, e.g. insufficient shared memory capacity - if os.getenv('DG_JIT_DEBUG', None): + return_code = runtime(**tuned_keys, **kwargs) + if return_code != cbd.CUresult.CUDA_SUCCESS: + # Pass illegal kernels, e.g., insufficient shared memory capacity + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') continue # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() - torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') + torch.empty(int(256e6 // 4), dtype=torch.int, + device='cuda').zero_() + torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn( + (8192, 8192), dtype=torch.float, device='cuda') start_event.record() for i in range(20): - assert runtime(*args) == 0 + assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS end_event.record() end_event.synchronize() elapsed_time = start_event.elapsed_time(end_event) @@ -67,15 +68,16 @@ class JITTuner: # Compare if better if best_time is None or elapsed_time < best_time: best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' # Cache the best runtime and return - if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): - print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') - self.tuned[signature] = best_runtime - return best_runtime + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)): + print( + f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') + self.tuned[signature] = (best_runtime, best_keys) + return best_runtime, best_keys jit_tuner = JITTuner() diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py index d5cdd01..f99ecd4 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/utils.py @@ -80,25 +80,10 @@ class suppress_stdout_stderr: def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): # Conflict with Nsight Systems - using_nsys = os.environ.get('DG_NSYS_PROFILING', False) + using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle - # this avoid thermal throttling while keeping DVFS at max clocks (slight gain vs sleep / more consistent on GH200) - sleep_between_tests = 0.0 flush_l2_size = int(8e9 // 4) - if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False): - flush_l2 = False - if os.environ.get('DG_BENCH_POWER_LIMITED', False): - # if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time - # and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2) - num_tests = 2000 - flush_l2_size = int(80e6 // 4) - sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False) - if sleep_val: - try: - sleep_between_tests = float(sleep_val) - except ValueError: - pass # Keep default # For some auto-tuning kernels with prints fn() @@ -117,8 +102,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: lhs @ rhs dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): - if sleep_between_tests > 0.0: - time.sleep(sleep_between_tests) if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() fn() diff --git a/tests/test_core.py b/tests/test_core.py index bdc1841..de544c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,3 +1,8 @@ +# PyTorch has its own NVRTC, which may have a lower version than the system +# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch +import cuda.bindings.nvrtc as nvrtc +print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') + import random import torch from typing import Tuple diff --git a/tests/test_jit.py b/tests/test_jit.py index 78bc77b..37b8bc4 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,64 +1,101 @@ +import ctypes import os import torch -from typing import Any +import cuda.bindings.driver as cbd from deep_gemm import jit +# Essential debugging staffs +os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') +os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') -class Capture: - def __init__(self) -> None: - self.read_fd = None - self.write_fd = None - self.saved_stdout = None - self.captured = None - def __enter__(self) -> Any: - self.read_fd, self.write_fd = os.pipe() - self.saved_stdout = os.dup(1) - os.dup2(self.write_fd, 1) - return self +class VectorAddRuntime(jit.Runtime): + def __init__(self, path: str) -> None: + super().__init__(path, [ + 'A', + 'B', + 'C', + 'STREAM', + ]) - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - os.dup2(self.saved_stdout, 1) - os.close(self.write_fd) - with os.fdopen(self.read_fd, 'r') as f: - self.captured = f.read() + @staticmethod + def generate(**kwargs) -> str: + return f""" +#ifdef __CUDACC_RTC__ +#include +#else +#include +#endif - def capture(self) -> str: - return self.captured +#include +#include + +template +__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ + uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) {{ + c[i] = a[i] + b[i]; + }} +}} + +auto ptr = reinterpret_cast(&vector_add); +""" + + # noinspection PyShadowingNames,PyMethodOverriding + @staticmethod + def launch(kernel: cbd.CUkernel, + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, + stream: cbd.CUstream) -> cbd.CUresult: + assert a.shape == b.shape == c.shape + assert a.device == b.device == c.device + assert a.dim() == 1 + + config = cbd.CUlaunchConfig() + config.gridDimX = (a.numel() + 127) // 128 + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = 128 + config.blockDimY = 1 + config.blockDimZ = 1 + config.hStream = stream + + arg_values = ( + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + a.numel(), + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ) + + return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0] if __name__ == '__main__': - # Runtime - print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n') - - # Templates print('Generated code:') - args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16), - ('enable_double_streams', bool), ('stream', torch.cuda.Stream)) - body = "\n" - body += 'std::cout << reinterpret_cast(lhs) << std::endl;\n' - body += 'std::cout << reinterpret_cast(rhs) << std::endl;\n' - body += 'std::cout << reinterpret_cast(scale) << std::endl;\n' - body += 'std::cout << reinterpret_cast(out) << std::endl;\n' - body += 'std::cout << enable_double_streams << std::endl;\n' - body += 'std::cout << reinterpret_cast(stream) << std::endl;\n' - code = jit.generate((), args, body) + code = VectorAddRuntime.generate(T='float') print(code) + print() - # Build - print('Building ...') - func = jit.build('test_func', args, code) + for compiler_name in ('NVCC', 'NVRTC'): + # Get compiler + compiler_cls = getattr(jit, f'{compiler_name}Compiler') + print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}') - # Test correctness - print('Running ...') - fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda') - fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda') - bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda') - with Capture() as capture: - assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0 - output = capture.capture() - ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n' - assert output == ref_output, f'{output=}, {ref_output=}' + # Build + print('Building ...') + func = compiler_cls.build('test_func', code, VectorAddRuntime) - print('JIT test passed') + # Run and check + a = torch.randn((1024, ), dtype=torch.float32, device='cuda') + b = torch.randn((1024, ), dtype=torch.float32, device='cuda') + c = torch.empty_like(a) + ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) + assert ret == cbd.CUresult.CUDA_SUCCESS, ret + torch.testing.assert_close(c, a + b) + print(f'JIT test for {compiler_name} passed\n') From daec8fd2fc04c27ec141075e0f0c257e178ac72c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 11:40:34 +0800 Subject: [PATCH 2/5] Fix pipeline stage edge cases --- deep_gemm/jit_kernels/gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index d8023fe..2122683 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -105,10 +105,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 - stage_candidates = tuple(filter(lambda s: s <= k // 128, (8, 7, 6, 5, 4, 3))) + stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))) if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: # Unrolling both stages and `num_former_iters` will cause large code size - stage_candidates = (4, 3) + stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) for num_stages in stage_candidates: best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) if best_smem_config[0] <= sm90_capacity: From 085b4a15328bba684a1b428d3c2e9319dfc971ff Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 11:46:52 +0800 Subject: [PATCH 3/5] Add `DG_PRINT_AUTOTUNE` to README --- README.md | 2 ++ deep_gemm/jit_kernels/tuner.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f4a6a60..bb5b00c 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ The library also provides some environment variables, which may be useful: - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default - Post optimization - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default +- Heuristic selection + - `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default - Testing - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py index 9a8b6f2..4fc9283 100644 --- a/deep_gemm/jit_kernels/tuner.py +++ b/deep_gemm/jit_kernels/tuner.py @@ -74,8 +74,7 @@ class JITTuner: # Cache the best runtime and return if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)): - print( - f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') + print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') self.tuned[signature] = (best_runtime, best_keys) return best_runtime, best_keys From 8702f910e37cb95c7e822b49b26f3b27b8f79b44 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 13:23:40 +0800 Subject: [PATCH 4/5] Fix 12.9 compatibility --- deep_gemm/jit/compiler.py | 3 +-- deep_gemm/jit/runtime.py | 3 ++- deep_gemm/jit_kernels/runtime.py | 34 +++++++++++++++++--------------- tests/test_jit.py | 4 +++- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 559a2f6..2ab6b25 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]: paths = [] if os.getenv('DG_JIT_NVCC_COMPILER'): paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) - paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) # Try to find the first available NVCC compiler @@ -181,7 +180,7 @@ class NVCCCompiler(Compiler): @classmethod def signature(cls) -> str: - return f'nvcc+{cls.__version__()}' + return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' @classmethod def flags(cls) -> List[str]: diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index b7c2f95..74ceff5 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -48,7 +48,8 @@ class Runtime: command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) assert result.returncode == 0 - kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')] + kernel_names = [line.split()[-1] for line in result.stdout.splitlines() + if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line] assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' # Load kernel from the library diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 5396601..fa0a61d 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -175,22 +175,24 @@ class FP8GemmRuntime(Runtime): using namespace deep_gemm; -auto ptr = reinterpret_cast(&fp8_gemm_kernel< - {kwargs['N']}, - {kwargs['K']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['BLOCK_N_PADDING']}, - {kwargs['SWIZZLE_D_MODE']}, - {kwargs['NUM_GROUPS']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, - GemmType::{kwargs['GEMM_TYPE']} - >); +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >); +}}; ''' if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Generated FP8 GEMM code:\n{code}') diff --git a/tests/test_jit.py b/tests/test_jit.py index 37b8bc4..fbd84e1 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -39,7 +39,9 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ }} }} -auto ptr = reinterpret_cast(&vector_add); +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>); +}} """ # noinspection PyShadowingNames,PyMethodOverriding From d75b218b7b8f4a5dd5406ac87905039ead3ae42f Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 13:26:58 +0800 Subject: [PATCH 5/5] Update README with NVRTC news --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bb5b00c..f14601c 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## News +- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. ## Roadmap