diff --git a/README.md b/README.md index dab1f05..07bf7f1 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [ ] Skip useless computation on M -- [ ] NVRTC as a faster compiler +- [x] NVRTC as a faster compiler - [ ] Sanitizer for testing - [ ] Weight gradient kernels for dense models - [ ] Weight gradient kernels for MoE models diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index e8370af..56c9073 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Prefetch TMA descriptors at very beginning if (threadIdx.x == kNumMathThreads) { - cute::prefetch_tma_descriptor(&tensor_map_a); - cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_scales_a); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); // `tensor_map_d` is only used in swizzling mode // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode if constexpr (kSwizzleDMode > 0) - cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); } __syncwarp(); @@ -447,119 +447,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #endif } -template -class Gemm { -private: - using Barrier = cuda::barrier; - -public: - Gemm() = default; - - static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, - uint32_t shape_m, - const CUtensorMap& tma_a_desc, - const CUtensorMap& tma_b_desc, - const CUtensorMap& tma_scales_a_desc, - const CUtensorMap& tma_d_desc, - cudaStream_t stream, - int num_sms, uint32_t smem_size) { - // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps - constexpr uint32_t kNumTMAThreads = 128; - constexpr uint32_t kNumMathThreadsPerGroup = 128; - auto kernel = fp8_gemm_kernel; - DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); - - // Cluster launch - cudaLaunchConfig_t config; - config.gridDim = num_sms; - config.blockDim = get_num_threads_per_sm(BLOCK_M); - config.dynamicSmemBytes = smem_size; - config.stream = stream; - - // Clusters for TMA multicast - // NOTES: `>= 4` cluster size will cause performance degradation - cudaLaunchAttribute attr; - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; - - // Launch - auto status = cudaLaunchKernelEx(&config, kernel, - gmem_d, scales_b, grouped_layout, - shape_m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); - DG_HOST_ASSERT(status == cudaSuccess); - } - - template - static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); - } - - template - static CUtensorMap make_2d_tma_b_desc(T* global_address) { - return make_2d_tma_desc(global_address, Layout::ColMajor, - SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); - } - - template - static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { - auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; - if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; - if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; - if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; - - // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes - // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, - BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T), - swizzle_mode); - } - - template - static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { - // Make TMA aligned to 16 bytes - constexpr uint32_t kAlignment = 16 / sizeof(T); - shape_m = ceil_div(shape_m, kAlignment) * kAlignment; - - return make_2d_tma_desc(global_address, Layout::ColMajor, - shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); - } - - template - static CUtensorMap make_2d_tma_desc( - T* global_address, Layout layout, - uint32_t gmem_rows, uint32_t gmem_cols, - uint32_t smem_rows, uint32_t smem_cols, - CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { - if (layout == Layout::RowMajor) { - uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; - uint32_t smem_dim[2] = {smem_cols, smem_rows}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); - } else { - uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; - uint32_t smem_dim[2] = {smem_rows, smem_cols}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); - } - } -}; - }; // namespace deep_gemm #pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index a442af7..f07e540 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -1,6 +1,8 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include +#endif #include #include diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh new file mode 100644 index 0000000..d9f52dc --- /dev/null +++ b/deep_gemm/include/deep_gemm/nvrtc_std.cuh @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef NVRTC_JIT_COMPILATION + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using int32_t = signed int; +using uint32_t = unsigned int; +using int64_t = signed long long; +using uint64_t = unsigned long long; +using cuuint64_t = unsigned long long; + +#ifndef CU_TENSOR_MAP_NUM_QWORDS +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +struct CUtensorMap_st +{ +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +}; + +using CUtensorMap = CUtensorMap_st; +#endif + +namespace std { +template struct integral_constant { + static constexpr T value = v; + using value_type = T; + using type = integral_constant; // using injected-class-name + + __device__ constexpr operator value_type() const noexcept { return value; } + + __device__ constexpr value_type operator()() const noexcept { + return value; + } // since c++14 +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +template struct is_same : false_type {}; + +template struct is_same : true_type {}; + +template +inline constexpr bool is_same_v = is_same::value; + +namespace index_sequence_impl { +// Based on https://stackoverflow.com/a/32223343/11717224 +template struct index_sequence { + using type = index_sequence; + using value_type = size_t; + static constexpr size_t size() noexcept { return sizeof...(Ints); } +}; + +template struct _merge_and_renumber; + +template +struct _merge_and_renumber, index_sequence> + : index_sequence {}; + +template +struct make_index_sequence + : _merge_and_renumber::type, + typename make_index_sequence::type> {}; + +template <> struct make_index_sequence<0> : index_sequence<> {}; +template <> struct make_index_sequence<1> : index_sequence<0> {}; +} // namespace index_sequence_impl + +template +using index_sequence = index_sequence_impl::index_sequence; + +template +using make_index_sequence = index_sequence_impl::make_index_sequence; +} // namespace std + +#endif diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh index 18cdb58..6b8ebda 100644 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ b/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -1,85 +1,18 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include #include -#include -#include #include +#include +#endif + #include #include "utils.cuh" namespace deep_gemm { -template -constexpr CUtensorMapDataType get_CUtensorMapDataType() { - if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; - } -} - -inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { - // Get pointer to `cuTensorMapEncodeTiled` - cudaDriverEntryPointQueryResult driver_status; - void* cuTensorMapEncodeTiled_ptr = nullptr; - -#if CUDA_VERSION >= 12050 - cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, - cudaEnableDefault, &driver_status); -#else - cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, - cudaEnableDefault, &driver_status); -#endif - - if (driver_status != cudaDriverEntryPointSuccess) - throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); - return reinterpret_cast(cuTensorMapEncodeTiled_ptr); -} - -template -CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], - uint64_t stride_in_bytes, uint32_t smem_dim[2], - CUtensorMapSwizzle swizzle_type, - PFN_cuTensorMapEncodeTiled encode_func = nullptr) { - CUtensorMap tensor_map = {}; - uint64_t global_stride[1] = {stride_in_bytes}; - uint32_t elem_strides[2] = {1, 1}; - - if (encode_func == nullptr) - encode_func = get_cuTensorMapEncodeTiled(); - - auto result = encode_func( - &tensor_map, get_CUtensorMapDataType>(), 2, - global_address, gmem_dim, global_stride, smem_dim, elem_strides, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - DG_HOST_ASSERT(result == CUDA_SUCCESS); - return tensor_map; -} - __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 9b93af5..8edf35b 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -1,5 +1,6 @@ #pragma once +#ifndef NVRTC_JIT_COMPILATION #include #ifdef __CLION_IDE__ @@ -16,8 +17,12 @@ public: const char *what() const noexcept override { return message.c_str(); } }; +#endif #ifndef DG_HOST_ASSERT +#ifdef NVRTC_JIT_COMPILATION +#define DG_HOST_ASSERT(cond) ((void)0) +#else #define DG_HOST_ASSERT(cond) \ do { \ if (not (cond)) { \ @@ -27,6 +32,7 @@ do { \ } \ } while (0) #endif +#endif #ifndef DG_DEVICE_ASSERT #define DG_DEVICE_ASSERT(cond) \ diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py index eb08b14..8e1ba3a 100644 --- a/deep_gemm/jit/__init__.py +++ b/deep_gemm/jit/__init__.py @@ -1,3 +1,3 @@ -from .compiler import get_nvcc_compiler, build -from .template import cpp_format, generate +from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler +from .template import generate from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index c17d466..36a3361 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -1,15 +1,20 @@ -import hashlib +import abc import functools +import hashlib import os +import platform import re import subprocess +import time import uuid +from typing import List, Tuple, Type + +import cuda.bindings +import cuda.bindings.nvrtc as nvrtc from torch.utils.cpp_extension import CUDA_HOME -from typing import Tuple from . import interleave_ffma -from .runtime import Runtime, RuntimeCache -from .template import typename_map +from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache runtime_cache = RuntimeCache() @@ -22,21 +27,22 @@ def hash_to_hex(s: str) -> str: @functools.lru_cache(maxsize=None) def get_jit_include_dir() -> str: - return f'{os.path.dirname(os.path.abspath(__file__))}/../include' + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') @functools.lru_cache(maxsize=None) def get_deep_gemm_version() -> str: # Update include directories - include_dir = f'{get_jit_include_dir()}/deep_gemm' - assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' + include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') + assert os.path.exists( + include_dir), f'Cannot find GEMM include directory {include_dir}' md5 = hashlib.md5() for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): - with open(f'{include_dir}/{filename}', 'rb') as f: + with open(os.path.join(include_dir, filename), 'rb') as f: md5.update(f.read()) # Update `interleave_ffma.py` - with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f: + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: md5.update(f.read()) return md5.hexdigest()[0:12] @@ -46,14 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]: paths = [] if os.getenv('DG_NVCC_COMPILER'): paths.append(os.getenv('DG_NVCC_COMPILER')) - paths.append(f'{CUDA_HOME}/bin/nvcc') + + nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc' + paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin)) # Try to find the first available NVCC compiler least_version_required = '12.3' version_pattern = re.compile(r'release (\d+\.\d+)') for path in paths: if os.path.exists(path): - match = version_pattern.search(os.popen(f'{path} --version').read()) + command = [path, '--version'] + result = subprocess.run(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True) + match = version_pattern.search(result.stdout) version = match.group(1) assert match, f'Cannot get the version of NVCC compiler {path}' assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' @@ -67,17 +78,17 @@ def get_default_user_dir(): path = os.getenv('DG_CACHE_DIR') os.makedirs(path, exist_ok=True) return path - return os.path.expanduser('~') + '/.deep_gemm' + return os.path.join(os.path.expanduser('~'), '.deep_gemm') @functools.lru_cache(maxsize=None) def get_tmp_dir(): - return f'{get_default_user_dir()}/tmp' + return os.path.join(get_default_user_dir(), 'tmp') @functools.lru_cache(maxsize=None) def get_cache_dir(): - return f'{get_default_user_dir()}/cache' + return os.path.join(get_default_user_dir(), 'cache') def make_tmp_dir(): @@ -86,67 +97,195 @@ def make_tmp_dir(): return tmp_dir -def put(path, data, is_binary=False): +def put(path, data): + is_binary = isinstance(data, bytes) + # Write and do POSIX atomic replace - tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' + tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') with open(tmp_file_path, 'wb' if is_binary else 'w') as f: f.write(data) os.replace(tmp_file_path, path) -def build(name: str, arg_defs: tuple, code: str) -> Runtime: - # Compiler flags - cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) - nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', - '-gencode=arch=compute_90a,code=sm_90a', - '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=39,174,177,940'] - cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] - flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] - include_dirs = [get_jit_include_dir()] +class Compiler(abc.ABC): + @staticmethod + @abc.abstractmethod + def __version__() -> Tuple[int, int]: + pass - # Build signature - enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 - signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' - name = f'kernel.{name}.{hash_to_hex(signature)}' - path = f'{get_cache_dir()}/{name}' + @classmethod + @abc.abstractmethod + def compile(cls, name: str, code: str, target_path: str) -> str: + pass - # Check runtime cache or file system hit - global runtime_cache - if runtime_cache[path] is not None: + @staticmethod + def flags() -> List[str]: + cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) + return [f'-std=c++{cpp_standard}', + '--ptxas-options=--register-usage-level=10' + + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), + # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases + '--diag-suppress=39,161,174,177,940'] + + @staticmethod + def include_dirs() -> List[str]: + return [get_jit_include_dir()] + + @classmethod + def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: + # Compiler flags + flags = cls.flags() + + # Build signature + enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int( + os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 + signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' + name = f'kernel.{name}.{hash_to_hex(signature)}' + path = os.path.join(get_cache_dir(), name) + + # Check runtime cache or file system hit + global runtime_cache + cached_runtime = runtime_cache.get(path, runtime_cls) + if cached_runtime is not None: + if os.getenv('DG_JIT_DEBUG', None): + print(f'Using cached JIT runtime {name} during build') + return cached_runtime + + # Compile into a temporary CU file + os.makedirs(path, exist_ok=True) + cubin_path = os.path.join(path, 'kernel.cubin') + tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') + + start_time = time.time() + cls.compile(name, code, tmp_cubin_path) + end_time = time.time() + elapsed_time = end_time - start_time if os.getenv('DG_JIT_DEBUG', None): - print(f'Using cached JIT runtime {name} during build') - return runtime_cache[path] + print( + f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') - # Write the code - os.makedirs(path, exist_ok=True) - args_path = f'{path}/kernel.args' - src_path = f'{path}/kernel.cu' - put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) - put(src_path, code) + # Interleave FFMA reuse + if enable_sass_opt: + interleave_ffma.process(tmp_cubin_path) + + # Atomic replace files + os.replace(tmp_cubin_path, cubin_path) - # Compile into a temporary SO file - so_path = f'{path}/kernel.so' - tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so' + # Put cache and return + runtime = runtime_cls(path) + runtime_cache[path] = runtime + return runtime - # Compile - command = [get_nvcc_compiler()[0], - src_path, '-o', tmp_so_path, - *flags, - *[f'-I{d}' for d in include_dirs]] - if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): - print(f'Compiling JIT runtime {name} with command {command}') - return_code = subprocess.check_call(command) - assert return_code == 0, f'Failed to compile {src_path}' - # Interleave FFMA reuse - if enable_sass_opt: - interleave_ffma.process(tmp_so_path) +class NvccCompiler(Compiler): + @staticmethod + def __version__() -> Tuple[int, int]: + _, version = get_nvcc_compiler() + major, minor = map(int, version.split('.')) + return (major, minor) - # Atomic replace SO file - os.replace(tmp_so_path, so_path) + @classmethod + def flags(cls) -> List[str]: + if platform.system() != 'Windows': + cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] + else: + cxx_flags = ['/O2', '/std:c++20'] + + return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], + '-gencode=arch=compute_90a,code=sm_90a', + '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', + f'--compiler-options={",".join(cxx_flags)}'] - # Put cache and return - runtime_cache[path] = Runtime(path) - return runtime_cache[path] + @classmethod + def compile(cls, name: str, code: str, target_path: str): + # Write the code + path = os.path.join(get_cache_dir(), name) + src_path = os.path.join(path, 'kernel.cu') + put(src_path, code) + command = [get_nvcc_compiler()[0], + src_path, '-o', target_path, + *cls.flags()] + if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): + print(f'Compiling JIT runtime {name} with command {command}') + + result = subprocess.run(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True) + if os.getenv('DG_JIT_DEBUG', None): + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f'Failed to compile {src_path}' + + +class NvrtcCompiler(Compiler): + @staticmethod + def __version__() -> Tuple[int, int]: + res, major, minor = nvrtc.nvrtcVersion() + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + # Failed to get actual NVRTC version, use bindings version instead + major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) + return (major, minor) + + @staticmethod + def include_dirs() -> List[str]: + if CUDA_HOME is None: + raise RuntimeError('CUDA_HOME is required for NVRTC compilation') + return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] + + @classmethod + def flags(cls) -> List[str]: + base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], + '--gpu-architecture=sm_90a', '-default-device'] + if cls.__version__() >= (12, 8): + base_flags += ['--pch'] + if os.getenv('DG_JIT_DEBUG', None): + base_flags += ['--pch-verbose=true'] + return base_flags + + @classmethod + def compile(cls, name: str, code: str, target_path: str) -> str: + code_bytes = bytes(code, 'utf-8') + res, program = nvrtc.nvrtcCreateProgram( + code_bytes, bytes(name, 'utf-8'), 0, [], []) + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to create program: {res}') + + options = [bytes(flag, 'utf-8') for flag in cls.flags()] + compile_res = nvrtc.nvrtcCompileProgram( + program, len(options), options)[0] + + if os.getenv('DG_JIT_DEBUG', None): + res, log_size = nvrtc.nvrtcGetProgramLogSize(program) + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to get program log size: {res}') + log_bytes = bytes(log_size) + res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to get program log: {res}') + log_str = log_bytes.decode('utf-8') + print(log_str) + + if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to compile program: {compile_res}') + + res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to get CUBIN size: {res}') + + cubin_bytes = bytes(cubin_size) + res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to get CUBIN: {res}') + + put(target_path, cubin_bytes) + + res = nvrtc.nvrtcDestroyProgram(program)[0] + if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise Exception(f'Failed to destroy program: {res}') + + +def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: + if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: + return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls) + else: + return NvccCompiler.build(name, code, runtime_cls=runtime_cls) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 66c370a..e5f7bfb 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,17 +1,20 @@ -import ctypes import os -import torch -from typing import Optional +import time +from typing import Any, Callable, Dict, List, Optional, Type -from .template import map_ctype +import cuda.bindings.driver as cuda + +from .utils import run_gemm class Runtime: - def __init__(self, path: str) -> None: + def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None: self.path = path self.lib = None - self.args = None - + self.kernel = None + self.kernel_name = kernel_name + self.caller = caller + self.args = args assert self.is_path_valid(self.path) @staticmethod @@ -21,46 +24,91 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cu', 'kernel.args', 'kernel.so'] + files = ['kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) - def __call__(self, *args) -> int: - # Load SO file - if self.lib is None or self.args is None: - self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so')) - with open(os.path.join(self.path, 'kernel.args'), 'r') as f: - self.args = eval(f.read()) + def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: + # Load CUBIN + if self.kernel is None: + start_time = time.time_ns() + res, lib = cuda.cuLibraryLoadFromFile( + bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to load library: {res}') - # Check args and launch - assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' - cargs = [] - for arg, (name, dtype) in zip(args, self.args): - if isinstance(arg, torch.Tensor): - assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' + res, kernel_count = cuda.cuLibraryGetKernelCount(lib) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to get kernel count: {res}') + + res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to enumerate kernels: {res}') + + for kernel in kernels: + res, kernel_name = cuda.cuKernelGetName(kernel) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to get kernel name: {res}') + if bytes(self.kernel_name, encoding='utf-8') in kernel_name: + self.kernel = kernel + break + + if self.kernel is not None: + self.lib = lib else: - assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' - cargs.append(map_ctype(arg)) + raise Exception('Failed to find required kernel') - return_code = ctypes.c_int(0) - self.lib.launch(*cargs, ctypes.byref(return_code)) - return return_code.value + end_time = time.time_ns() + elapsed_time = (end_time - start_time) / 1000 + if os.getenv('DG_JIT_DEBUG', None): + print( + f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') + + return self.caller( + self.kernel, + *[kwargs[arg] for arg in self.args] + ) + + def __del__(self) -> None: + if self.lib is not None: + res = cuda.cuLibraryUnload(self.lib)[0] + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to unload library {self.path}: {res}') + + +class Fp8GemmRuntime(Runtime): + def __init__(self, path: str) -> None: + super().__init__(path, 'fp8_gemm', run_gemm, [ + 'NUM_TMA_MULTICAST', + 'M', + 'BLOCK_M', + 'GMEM_D', + 'SCALES_B', + 'GROUPED_LAYOUT', + 'NUM_SMS', + 'SMEM_SIZE', + 'TENSOR_MAP_A', + 'TENSOR_MAP_B', + 'TENSOR_MAP_SCALES_A', + 'TENSOR_MAP_D', + 'STREAM', + ]) class RuntimeCache: def __init__(self) -> None: self.cache = {} - def __getitem__(self, path: str) -> Optional[Runtime]: + def __setitem__(self, path, runtime) -> None: + self.cache[path] = runtime + + def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled if os.path.exists(path) and Runtime.is_path_valid(path): - runtime = Runtime(path) + runtime = runtime_cls(path) self.cache[path] = runtime return runtime - return None - - def __setitem__(self, path, runtime) -> None: - self.cache[path] = runtime + return None \ No newline at end of file diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index ead37f5..461691f 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -1,111 +1,48 @@ -import copy -import ctypes import os -import torch -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict -# Name map for Python `eval` -typename_map: Dict[Any, str] = { - **{t: t.__name__ for t in (bool, int, float)}, - torch.int: 'torch.int', - torch.float: 'torch.float', - torch.bfloat16: 'torch.bfloat16', - torch.float8_e4m3fn: 'torch.float8_e4m3fn', - torch.cuda.Stream: 'torch.cuda.Stream', -} +def generate(**kwargs: Dict[str, Any]) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#ifndef NVRTC_JIT_COMPILATION +#define NVRTC_JIT_COMPILATION +#endif -# `ctype` map for Python casting -ctype_map: Dict[Any, Any] = { - **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, - **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, -} +#include +#else -# Type map for both Python API and source code usages -genc_map = { - bool: ('bool', 'bool'), - int: ('int', 'int'), - float: ('float', 'float'), - torch.int: ('void*', 'int*'), - torch.float: ('void*', 'float*'), - torch.bfloat16: ('void*', '__nv_bfloat16*'), - torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), - torch.cuda.Stream: ('void*', 'cudaStream_t'), -} +#include +#include +#endif -def map_ctype(value: Any) -> Any: - if hasattr(value, 'data_ptr'): - if value.dtype == torch.int: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.bfloat16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float16: - return ctypes.c_void_p(value.data_ptr()) - elif value.dtype == torch.float8_e4m3fn: - return ctypes.c_void_p(value.data_ptr()) - else: - return ctypes.c_void_p(value.data_ptr()) +#include +#include +#include - if hasattr(value, 'cuda_stream'): - return ctypes.c_void_p(value.cuda_stream) +using namespace deep_gemm; - if isinstance(value, bool): - return ctypes.c_bool(value) - elif isinstance(value, int): - return ctypes.c_int(value) - elif isinstance(value, float): - return ctypes.c_float(value) - - return ctype_map[type(value)](value) - - -def cpp_format(template: str, keys: Dict[str, Any]) -> str: - # We don't use `str.format` because it's not safe for C++ {} braces - new_template = copy.deepcopy(template) - for key, value in keys.items(): - value_str = str(value) - if isinstance(value, bool): - value_str = value_str.lower() - new_template = new_template.replace(f'{{{key}}}', f'{value_str}') - return new_template - - -def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: - # Common prefix - code = '// DeepGEMM auto-generated JIT CUDA source file\n\n' - - # Includes - preload_sys_includes = ['', '', '', ''] - preload_package_includes = ['"cutlass/cutlass.h"'] - - assert isinstance(includes, list) or isinstance(includes, tuple) - sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')]))) - package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')]))) - code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n' - code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n' - - # Function signature - raw = '__raw_' - get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n - code += f'extern "C" void launch(' - code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ]) - code += ') {\n' - - # Cast raw types - code += ' // Cast raw types (if needed)\n' - for arg_name, arg_type in arg_defs: - if genc_map[arg_type][0] != genc_map[arg_type][1]: - code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n' - - # Function body - code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')]) - - # End the function - code += '}\n\n' +__global__ void dummy_kernel() {{ + void *ptr = (void *)&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >; +}} +''' # Debug print if os.getenv('DG_JIT_DEBUG', None): diff --git a/deep_gemm/jit/utils.py b/deep_gemm/jit/utils.py new file mode 100644 index 0000000..1321f24 --- /dev/null +++ b/deep_gemm/jit/utils.py @@ -0,0 +1,164 @@ +import ctypes +from enum import Enum +from typing import Any, Dict, Tuple + +import cuda.bindings.driver as cuda +import torch + + +class Layout(Enum): + RowMajor = 0 + ColMajor = 1 + + +class GemmType(Enum): + Normal = 0 + GroupedContiguous = 1 + GroupedMasked = 2 + + def __str__(self) -> str: + return { + 0: 'Normal', + 1: 'GroupedContiguous', + 2: 'GroupedMasked', + }[self.value] + + +typename_map: Dict[Any, str] = { + torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, + torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, + torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, + torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, + torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +} + +swizzle_map = { + 128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + 64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, + 32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, + 0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, +} + +def get_num_math_warpgroups(block_m: int) -> int: + return 1 if block_m == 64 else 2 + +def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: + assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' + return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads + + +def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap: + tensor_dtype = typename_map[global_address.dtype] + res, tensor_map = cuda.cuTensorMapEncodeTiled( + tensor_dtype, + 2, # tensor rank + global_address.data_ptr(), + gmem_dim, + (stride_in_bytes,), # global strides + smem_dim, + (cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle_type, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to encode tensor map: {res}') + + return tensor_map + + +def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap: + if layout == Layout.RowMajor: + gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows)) + smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) + else: + gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols)) + smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) + + +def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k) + + +def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n) + + +def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: + # Swizzling requires the inner box dim less or equal than `kSwizzleDMode` + # bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode]) + + +def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: + # Make TMA aligned to 16 bytes + kAlignment = 16 / global_address.element_size() + shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment + + return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0] + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to set max dynamic shared memory size: {res}') + + attr_val = cuda.CUlaunchAttributeValue() + attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cuda.CUlaunchAttribute() + attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cuda.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = num_sms + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = smem_size + config.hStream = stream + + kernelValues = ( + gmem_d.data_ptr(), + scales_b.data_ptr(), + grouped_layout.data_ptr(), + shape_m, + tensor_map_a, + tensor_map_b, + tensor_map_scales_a, + tensor_map_d, + ) + kernelTypes = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + None, + None, + None, + None, + ) + + return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index c6fd29d..f27b025 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -3,40 +3,11 @@ import torch from functools import lru_cache from typing import Tuple +from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc + from .tuner import jit_tuner from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = 1; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, require_divisible: bool = False) -> bool: @@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: # Try swizzle first, as it does not waste shared memory swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 smem_d = block_m * (block_n + block_n_padding) * 2 smem_a_per_stage = block_m * block_k @@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_size += num_stages * smem_a_per_stage smem_size += num_stages * smem_scales_a_per_stage smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += ceil_div(smem_scales_b * (1 if block_k % + block_n == 0 else 2), 8) * 8 smem_size += smem_barrier # Swizzle and padding are not compatible @@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) + (144, 160, ) - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) + def fix_wave_saturate(x): return num_sms if x == 0 else x + + def get_num_waves(bm, bn): return (ceil_div( + ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) + + def get_last_wave_util(bm, bn): return fix_wave_saturate( + (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) # Decide block sizes by waves best_block_m, best_block_n = None, None @@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # NOTES: the block sizes can not be too large, so at least one dim less than 128 for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) + num_waves, best_num_waves = get_num_waves( + block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: @@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, success |= block_n == best_block_n and block_m < best_block_m # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better success |= block_m != best_block_m and block_n > best_block_n - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) + best_block_m, best_block_n = (block_m, block_n) if success else ( + best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None # Always pick the longest one @@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3) for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) + best_smem_config = get_smem_config( + num_stages, k, best_block_m, best_block_n) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break @@ -159,8 +139,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + num_min_sms = ceil_div(ceil_div(m, best_block_m) * + ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div( + num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] assert num_min_sms <= num_sms return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config @@ -211,11 +193,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.Normal, lhs, m, k, block_m, block_k) + tensor_map_b = make_2d_tma_b_desc( + GemmType.Normal, rhs, k, n, block_k, block_n) + tensor_map_d = make_2d_tma_d_desc( + GemmType.Normal, smem_config[1], out, m, n, block_m, block_n) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.Normal, lhs_scales, m, k, block_m, block_k) + + kwargs = { + 'GEMM_TYPE': GemmType.Normal, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'NUM_GROUPS': 1, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -224,14 +237,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 3b518c9..22281b1 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,41 +1,12 @@ import torch from typing import Tuple +from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc + from .gemm import get_best_configs, get_block_n_padding_for_smem_d from .tuner import jit_tuner from .utils import get_col_major_tma_aligned_tensor, get_num_sms -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"', ) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto BLOCK_K = 128; -constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; -constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; -constexpr auto kNumGroups = {NUM_GROUPS}; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; -constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; - -// Make a templated grouped GEMM -using gemm_t = Gemm; - -// Launch kernel -auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); -gemm_t::run(out, rhs_scales, grouped_layout, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], @@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten return # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) - args = (lhs, lhs_scales, rhs, rhs_scales, out, - m_indices, m, num_groups, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + m, n, k, 1, num_sms, is_grouped_contiguous=True) + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': m_indices, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedContiguous'}, + 'GEMM_TYPE': GemmType.GroupedContiguous}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) # Extra checks for TMA store if num_groups > 1 and m > block_m: assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - args = (lhs, lhs_scales, rhs, rhs_scales, out, - masked_m, m, - torch.cuda.current_stream(), num_sms, smem_config[0]) - runtime = jit_tuner.compile_and_tune( + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc( + GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc( + GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) + tensor_map_d = make_2d_tma_d_desc( + GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups) + tensor_map_scales_a = make_2d_tma_scales_a_desc( + GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) + + kwargs = { + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, + 'BLOCK_K': block_k, + 'GMEM_D': out, + 'SCALES_B': rhs_scales, + 'GROUPED_LAYOUT': masked_m, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + } + + runtime, best_keys = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], @@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': 'GroupedMasked'}, + 'GEMM_TYPE': GemmType.GroupedMasked}, space=(), - includes=includes, - arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), - ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), - ('out', torch.bfloat16), - ('grouped_layout', torch.int32), ('m', int), - ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), - template=template, - args=args + kwargs=kwargs, ) # Run the kernel - runtime(*args) + runtime(**best_keys, **kwargs) diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py index 6ed6749..4dcd6cd 100644 --- a/deep_gemm/jit_kernels/tuner.py +++ b/deep_gemm/jit_kernels/tuner.py @@ -3,15 +3,16 @@ import os import torch from typing import Any, Dict -from ..jit import build, cpp_format, generate, Runtime +import cuda.bindings.driver as cuda + +from ..jit import build, generate, Runtime class JITTuner: def __init__(self) -> None: self.tuned = {} - def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, - includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: + def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime: # NOTES: we always assume the space and template will not change # We also assume the GPU device will not be changed # NOTES: the function must have no accumulated side effects @@ -26,7 +27,7 @@ class JITTuner: print(f'Auto-tuning JIT kernel {name} with keys {keys}') assert signature not in self.tuned - assert args is not None + assert kwargs is not None space = (dict(), ) if len(space) == 0 else space kernels = [] @@ -34,30 +35,31 @@ class JITTuner: assert isinstance(tuned_keys, dict) full_keys = copy.deepcopy(keys) full_keys.update(tuned_keys) - code = generate(includes, arg_defs, cpp_format(template, full_keys)) - - # Illegal build must raise errors - kernels.append((build(name, arg_defs, code), tuned_keys)) + code = generate(**kwargs, **full_keys) + kernels.append((build(name, code), full_keys)) best_runtime, best_time, best_keys = None, None, None for runtime, tuned_keys in kernels: if len(space) > 1: # Check kernel validity - return_code = runtime(*args) - if return_code != 0: + return_code = runtime(**tuned_keys, **kwargs) + if return_code != cuda.CUresult.CUDA_SUCCESS: # Pass illegal kernels, e.g. insufficient shared memory capacity if os.getenv('DG_JIT_DEBUG', None): - print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') + print( + f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') continue # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() - torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') + torch.empty(int(256e6 // 4), dtype=torch.int, + device='cuda').zero_() + torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn( + (8192, 8192), dtype=torch.float, device='cuda') start_event.record() for i in range(20): - assert runtime(*args) == 0 + assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS end_event.record() end_event.synchronize() elapsed_time = start_event.elapsed_time(end_event) @@ -68,14 +70,16 @@ class JITTuner: if best_time is None or elapsed_time < best_time: best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys if os.getenv('DG_JIT_DEBUG', None): - print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') + print( + f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' # Cache the best runtime and return if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): - print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') - self.tuned[signature] = best_runtime - return best_runtime + print( + f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') + self.tuned[signature] = (best_runtime, best_keys) + return best_runtime, best_keys jit_tuner = JITTuner() diff --git a/tests/test_jit.py b/tests/test_jit.py index 78bc77b..cced2d6 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,64 +1,116 @@ +import ctypes import os import torch -from typing import Any +from typing import Any, Dict + +import cuda.bindings.driver as cuda from deep_gemm import jit -class Capture: - def __init__(self) -> None: - self.read_fd = None - self.write_fd = None - self.saved_stdout = None - self.captured = None +def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult: + assert a.shape == b.shape == c.shape + assert a.device == b.device == c.device + assert a.dim() == 1 - def __enter__(self) -> Any: - self.read_fd, self.write_fd = os.pipe() - self.saved_stdout = os.dup(1) - os.dup2(self.write_fd, 1) - return self + n = a.numel() - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - os.dup2(self.saved_stdout, 1) - os.close(self.write_fd) - with os.fdopen(self.read_fd, 'r') as f: - self.captured = f.read() + config = cuda.CUlaunchConfig() + config.gridDimX = (n + 127) // 128 + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = 128 + config.blockDimY = 1 + config.blockDimZ = 1 + config.hStream = stream - def capture(self) -> str: - return self.captured + kernelValues = ( + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + n, + ) + kernelTypes = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ) + + return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0] + + +def generate_vector_add(**kwargs: Dict[str, Any]) -> str: + return f""" +#ifdef __CUDACC_RTC__ +#ifndef NVRTC_JIT_COMPILATION +#define NVRTC_JIT_COMPILATION +#endif +#include +#else +#include +#endif + +#include +#include + +template +__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{ + uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < N) {{ + c[i] = a[i] + b[i]; + }} +}} + +__global__ void dummy_kernel() {{ + void *ptr = (void *)&vector_add<{kwargs['T']}>; +}} +""" + + +class VectorAddRuntime(jit.Runtime): + def __init__(self, path: str) -> None: + super().__init__(path, 'vector_add', run_vector_add, [ + 'A', + 'B', + 'C', + 'STREAM', + ]) if __name__ == '__main__': - # Runtime - print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n') - - # Templates + # NVCC + print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n') print('Generated code:') - args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16), - ('enable_double_streams', bool), ('stream', torch.cuda.Stream)) - body = "\n" - body += 'std::cout << reinterpret_cast(lhs) << std::endl;\n' - body += 'std::cout << reinterpret_cast(rhs) << std::endl;\n' - body += 'std::cout << reinterpret_cast(scale) << std::endl;\n' - body += 'std::cout << reinterpret_cast(out) << std::endl;\n' - body += 'std::cout << enable_double_streams << std::endl;\n' - body += 'std::cout << reinterpret_cast(stream) << std::endl;\n' - code = jit.generate((), args, body) + code = generate_vector_add(T='float') print(code) - - # Build print('Building ...') - func = jit.build('test_func', args, code) + func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime) - # Test correctness - print('Running ...') - fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda') - fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda') - bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda') - with Capture() as capture: - assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0 - output = capture.capture() - ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n' - assert output == ref_output, f'{output=}, {ref_output=}' + a = torch.randn((1024, ), dtype=torch.float32, device='cuda') + b = torch.randn((1024, ), dtype=torch.float32, device='cuda') + c = torch.empty_like(a) + ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) + assert ret == cuda.CUresult.CUDA_SUCCESS, ret + ref_output = a + b + torch.testing.assert_close(c, ref_output) - print('JIT test passed') + print('JIT test for NVCC passed\n') + + # NVRTC + print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n') + print('Generated code:') + code = generate_vector_add(T='__nv_bfloat16') + print(code) + print('Building ...') + func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime) + + a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda') + b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda') + c = torch.empty_like(a) + ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) + assert ret == cuda.CUresult.CUDA_SUCCESS, ret + ref_output = a + b + torch.testing.assert_close(c, ref_output) + + print('JIT test for NVRTC passed') \ No newline at end of file