From 981cc58932c3d5fbd320ba72581563223df4d364 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 6 May 2025 17:23:35 +0800 Subject: [PATCH] Some lints and refactor --- README.md | 4 +- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 3 +- deep_gemm/include/deep_gemm/mma_utils.cuh | 2 +- deep_gemm/include/deep_gemm/nvrtc_std.cuh | 16 +- deep_gemm/include/deep_gemm/scheduler.cuh | 5 +- deep_gemm/include/deep_gemm/tma_utils.cuh | 10 +- deep_gemm/include/deep_gemm/utils.cuh | 35 +-- deep_gemm/jit/__init__.py | 3 +- deep_gemm/jit/compiler.py | 119 +++++----- deep_gemm/jit/interleave_ffma.py | 6 +- deep_gemm/jit/runtime.py | 38 +--- deep_gemm/jit/template.py | 51 ----- deep_gemm/jit/utils.py | 164 -------------- deep_gemm/jit_kernels/gemm.py | 22 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 25 ++- deep_gemm/jit_kernels/runtime.py | 256 ++++++++++++++++++++++ deep_gemm/jit_kernels/tuner.py | 40 ++-- tests/test_jit.py | 71 +++--- 18 files changed, 421 insertions(+), 449 deletions(-) delete mode 100644 deep_gemm/jit/template.py delete mode 100644 deep_gemm/jit/utils.py create mode 100644 deep_gemm/jit_kernels/runtime.py diff --git a/README.md b/README.md index 07bf7f1..6d0689f 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] Fix TMA multicast compatibility for indivisible shapes - [ ] Skip useless computation on M - [x] NVRTC as a faster compiler +- [ ] Fully remove NVCC compilation - [ ] Sanitizer for testing - [ ] Weight gradient kernels for dense models - [ ] Weight gradient kernels for MoE models @@ -105,12 +106,13 @@ The library provides some utility functions besides the above kernels: The library also provides some environment variables, which may be useful: - `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default +- `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default - `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default - `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler - `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization - `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output - `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details -- `DG_JIT_PRINT_NVCC_COMMAND`: 0 or 1, print NVCC compilation command +- `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print NVCC compilation command - `DG_JIT_DEBUG`: 0 or 1, print more debugging information For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 56c9073..c57691b 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -84,8 +84,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_id(); - // Prefetch TMA descriptors at very beginning + // Prefetch TMA descriptors at the very beginning if (threadIdx.x == kNumMathThreads) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index f07e540..c6c7e28 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -1,6 +1,6 @@ #pragma once -#ifndef NVRTC_JIT_COMPILATION +#ifndef __CUDACC_RTC__ #include #endif diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh index d9f52dc..00ce734 100644 --- a/deep_gemm/include/deep_gemm/nvrtc_std.cuh +++ b/deep_gemm/include/deep_gemm/nvrtc_std.cuh @@ -17,7 +17,7 @@ #pragma once -#ifdef NVRTC_JIT_COMPILATION +#ifdef __CUDACC_RTC__ using int8_t = signed char; using uint8_t = unsigned char; @@ -32,8 +32,7 @@ using cuuint64_t = unsigned long long; #ifndef CU_TENSOR_MAP_NUM_QWORDS #define CU_TENSOR_MAP_NUM_QWORDS 16 -struct CUtensorMap_st -{ +struct CUtensorMap_st { #if defined(__cplusplus) && (__cplusplus >= 201103L) alignas(64) #elif __STDC_VERSION__ >= 201112L @@ -46,16 +45,16 @@ using CUtensorMap = CUtensorMap_st; #endif namespace std { + template struct integral_constant { static constexpr T value = v; + using value_type = T; - using type = integral_constant; // using injected-class-name + using type = integral_constant; __device__ constexpr operator value_type() const noexcept { return value; } - __device__ constexpr value_type operator()() const noexcept { - return value; - } // since c++14 + __device__ constexpr value_type operator()() const noexcept { return value; } }; using false_type = integral_constant; @@ -69,6 +68,7 @@ template inline constexpr bool is_same_v = is_same::value; namespace index_sequence_impl { + // Based on https://stackoverflow.com/a/32223343/11717224 template struct index_sequence { using type = index_sequence; @@ -89,6 +89,7 @@ struct make_index_sequence template <> struct make_index_sequence<0> : index_sequence<> {}; template <> struct make_index_sequence<1> : index_sequence<0> {}; + } // namespace index_sequence_impl template @@ -96,6 +97,7 @@ using index_sequence = index_sequence_impl::index_sequence; template using make_index_sequence = index_sequence_impl::make_index_sequence; + } // namespace std #endif diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 9743871..c213d57 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -46,7 +46,7 @@ struct Scheduler { } } - __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) { + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { if (num_blocks_in_group == 1) return false; if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { @@ -63,7 +63,8 @@ struct Scheduler { } } - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, + uint32_t& m_block_idx, uint32_t& n_block_idx) { DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); // Swizzle for better L2 usages diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh index 6b8ebda..795dca6 100644 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ b/deep_gemm/include/deep_gemm/tma_utils.cuh @@ -1,18 +1,10 @@ #pragma once -#ifndef NVRTC_JIT_COMPILATION -#include -#include -#include -#include -#endif - -#include - #include "utils.cuh" namespace deep_gemm { +// TODO: move this function to other files __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 8edf35b..598a414 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -1,39 +1,14 @@ #pragma once -#ifndef NVRTC_JIT_COMPILATION -#include - #ifdef __CLION_IDE__ -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); } + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + #define printf host_device_printf #endif -class AssertionException : public std::exception { -private: - std::string message{}; - -public: - explicit AssertionException(const std::string& message) : message(message) {} - - const char *what() const noexcept override { return message.c_str(); } -}; -#endif - -#ifndef DG_HOST_ASSERT -#ifdef NVRTC_JIT_COMPILATION -#define DG_HOST_ASSERT(cond) ((void)0) -#else -#define DG_HOST_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", \ - __FILE__, __LINE__, #cond); \ - throw AssertionException("Assertion failed: " #cond); \ - } \ -} while (0) -#endif -#endif - #ifndef DG_DEVICE_ASSERT #define DG_DEVICE_ASSERT(cond) \ do { \ diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py index 8e1ba3a..06a5194 100644 --- a/deep_gemm/jit/__init__.py +++ b/deep_gemm/jit/__init__.py @@ -1,3 +1,2 @@ -from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler -from .template import generate +from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 36a3361..238c0d1 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -1,4 +1,3 @@ -import abc import functools import hashlib import os @@ -14,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc from torch.utils.cpp_extension import CUDA_HOME from . import interleave_ffma -from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache +from .runtime import Runtime, RuntimeCache runtime_cache = RuntimeCache() @@ -32,11 +31,11 @@ def get_jit_include_dir() -> str: @functools.lru_cache(maxsize=None) def get_deep_gemm_version() -> str: + md5 = hashlib.md5() + # Update include directories include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') - assert os.path.exists( - include_dir), f'Cannot find GEMM include directory {include_dir}' - md5 = hashlib.md5() + assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): with open(os.path.join(include_dir, filename), 'rb') as f: md5.update(f.read()) @@ -98,24 +97,20 @@ def make_tmp_dir(): def put(path, data): - is_binary = isinstance(data, bytes) - # Write and do POSIX atomic replace tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') - with open(tmp_file_path, 'wb' if is_binary else 'w') as f: + with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: f.write(data) os.replace(tmp_file_path, path) -class Compiler(abc.ABC): +class Compiler: @staticmethod - @abc.abstractmethod def __version__() -> Tuple[int, int]: pass @classmethod - @abc.abstractmethod - def compile(cls, name: str, code: str, target_path: str) -> str: + def compile(cls, name: str, code: str, target_path: str) -> None: pass @staticmethod @@ -132,13 +127,12 @@ class Compiler(abc.ABC): return [get_jit_include_dir()] @classmethod - def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: + def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: # Compiler flags flags = cls.flags() # Build signature - enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int( - os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 + enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' name = f'kernel.{name}.{hash_to_hex(signature)}' path = os.path.join(get_cache_dir(), name) @@ -147,7 +141,7 @@ class Compiler(abc.ABC): global runtime_cache cached_runtime = runtime_cache.get(path, runtime_cls) if cached_runtime is not None: - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Using cached JIT runtime {name} during build') return cached_runtime @@ -160,9 +154,8 @@ class Compiler(abc.ABC): cls.compile(name, code, tmp_cubin_path) end_time = time.time() elapsed_time = end_time - start_time - if os.getenv('DG_JIT_DEBUG', None): - print( - f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') # Interleave FFMA reuse if enable_sass_opt: @@ -177,12 +170,12 @@ class Compiler(abc.ABC): return runtime -class NvccCompiler(Compiler): +class NVCCCompiler(Compiler): @staticmethod def __version__() -> Tuple[int, int]: _, version = get_nvcc_compiler() major, minor = map(int, version.split('.')) - return (major, minor) + return major, minor @classmethod def flags(cls) -> List[str]: @@ -197,7 +190,7 @@ class NvccCompiler(Compiler): f'--compiler-options={",".join(cxx_flags)}'] @classmethod - def compile(cls, name: str, code: str, target_path: str): + def compile(cls, name: str, code: str, target_path: str) -> None: # Write the code path = os.path.join(get_cache_dir(), name) src_path = os.path.join(path, 'kernel.cu') @@ -205,26 +198,23 @@ class NvccCompiler(Compiler): command = [get_nvcc_compiler()[0], src_path, '-o', target_path, *cls.flags()] - if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): print(f'Compiling JIT runtime {name} with command {command}') - result = subprocess.run(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, text=True) - if os.getenv('DG_JIT_DEBUG', None): - print(result.stdout) - print(result.stderr) - - assert result.returncode == 0, f'Failed to compile {src_path}' + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') + assert False, f'Failed to compile {src_path}' -class NvrtcCompiler(Compiler): +class NVRTCCompiler(Compiler): @staticmethod def __version__() -> Tuple[int, int]: res, major, minor = nvrtc.nvrtcVersion() if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - # Failed to get actual NVRTC version, use bindings version instead + # Failed to get the actual NVRTC version, use cuda-bindings version instead major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) - return (major, minor) + return major, minor @staticmethod def include_dirs() -> List[str]: @@ -238,54 +228,51 @@ class NvrtcCompiler(Compiler): '--gpu-architecture=sm_90a', '-default-device'] if cls.__version__() >= (12, 8): base_flags += ['--pch'] - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): base_flags += ['--pch-verbose=true'] return base_flags @classmethod - def compile(cls, name: str, code: str, target_path: str) -> str: + def compile(cls, name: str, code: str, target_path: str) -> None: + # Create program code_bytes = bytes(code, 'utf-8') - res, program = nvrtc.nvrtcCreateProgram( + result, program = nvrtc.nvrtcCreateProgram( code_bytes, bytes(name, 'utf-8'), 0, [], []) - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to create program: {res}') + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' + # Compile options = [bytes(flag, 'utf-8') for flag in cls.flags()] - compile_res = nvrtc.nvrtcCompileProgram( - program, len(options), options)[0] + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): + print(f'Compiling JIT runtime {name} with options: {options}') + compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] + + # Print compiler log + if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + result, log_size = nvrtc.nvrtcGetProgramLogSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' - if os.getenv('DG_JIT_DEBUG', None): - res, log_size = nvrtc.nvrtcGetProgramLogSize(program) - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to get program log size: {res}') log_bytes = bytes(log_size) - res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to get program log: {res}') - log_str = log_bytes.decode('utf-8') - print(log_str) + result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' + print(f'Compiler log: {log_bytes.decode("utf-8")}') - if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to compile program: {compile_res}') - - res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to get CUBIN size: {res}') + # Exit if failed + assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' + # Create CUBIN + result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' cubin_bytes = bytes(cubin_size) - res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to get CUBIN: {res}') + result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] + assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' + # Write into the file system put(target_path, cubin_bytes) - res = nvrtc.nvrtcDestroyProgram(program)[0] - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f'Failed to destroy program: {res}') + # Destroy handler + assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' -def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: - if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: - return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls) - else: - return NvccCompiler.build(name, code, runtime_cls=runtime_cls) +def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: + compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler + return compiler_cls.build(name, code, runtime_cls=runtime_cls) diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py index fcb377e..12baa0d 100644 --- a/deep_gemm/jit/interleave_ffma.py +++ b/deep_gemm/jit/interleave_ffma.py @@ -37,7 +37,7 @@ def extract_ffma(sass): collected.append((f'{arch_name}::{func_name}', current)) current = [] - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_PRINT_REG_REUSE', 0)): print(f'Found {len(collected)} FFMA segments') return collected @@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines): dst_reg_set.add(dst_reg) new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) last_reused, last_dst_reg = reused, dst_reg - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_PRINT_REG_REUSE', 0)): print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') # Find the offset @@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines): def process(path): - if os.getenv('DG_PRINT_REG_REUSE', None): + if int(os.getenv('DG_PRINT_REG_REUSE', 0)): print(f'Processing {path}') output = run_cuobjdump(path) segments = extract_ffma(output) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index e5f7bfb..012f4da 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Type import cuda.bindings.driver as cuda -from .utils import run_gemm - class Runtime: - def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None: + def __init__(self, path: str, kernel_name: str = None, + caller: Callable[..., cuda.CUresult] = None, + args: List[str] = None) -> None: self.path = path self.lib = None self.kernel = None @@ -27,7 +27,7 @@ class Runtime: files = ['kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) - def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: + def __call__(self, **kwargs) -> cuda.CUresult: # Load CUBIN if self.kernel is None: start_time = time.time_ns() @@ -59,9 +59,8 @@ class Runtime: end_time = time.time_ns() elapsed_time = (end_time - start_time) / 1000 - if os.getenv('DG_JIT_DEBUG', None): - print( - f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') return self.caller( self.kernel, @@ -75,25 +74,6 @@ class Runtime: raise Exception(f'Failed to unload library {self.path}: {res}') -class Fp8GemmRuntime(Runtime): - def __init__(self, path: str) -> None: - super().__init__(path, 'fp8_gemm', run_gemm, [ - 'NUM_TMA_MULTICAST', - 'M', - 'BLOCK_M', - 'GMEM_D', - 'SCALES_B', - 'GROUPED_LAYOUT', - 'NUM_SMS', - 'SMEM_SIZE', - 'TENSOR_MAP_A', - 'TENSOR_MAP_B', - 'TENSOR_MAP_SCALES_A', - 'TENSOR_MAP_D', - 'STREAM', - ]) - - class RuntimeCache: def __init__(self) -> None: self.cache = {} @@ -101,14 +81,14 @@ class RuntimeCache: def __setitem__(self, path, runtime) -> None: self.cache[path] = runtime - def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]: + def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled - if os.path.exists(path) and Runtime.is_path_valid(path): + if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): runtime = runtime_cls(path) self.cache[path] = runtime return runtime - return None \ No newline at end of file + return None diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py deleted file mode 100644 index 461691f..0000000 --- a/deep_gemm/jit/template.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from typing import Any, Dict - - -def generate(**kwargs: Dict[str, Any]) -> str: - code = f''' -#ifdef __CUDACC_RTC__ -#ifndef NVRTC_JIT_COMPILATION -#define NVRTC_JIT_COMPILATION -#endif - -#include - -#else - -#include -#include - -#endif - -#include -#include -#include - -using namespace deep_gemm; - -__global__ void dummy_kernel() {{ - void *ptr = (void *)&fp8_gemm_kernel< - {kwargs['N']}, - {kwargs['K']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['BLOCK_N_PADDING']}, - {kwargs['SWIZZLE_D_MODE']}, - {kwargs['NUM_GROUPS']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, - GemmType::{kwargs['GEMM_TYPE']} - >; -}} -''' - - # Debug print - if os.getenv('DG_JIT_DEBUG', None): - print(f'Generated code:\n{code}') - - return code diff --git a/deep_gemm/jit/utils.py b/deep_gemm/jit/utils.py deleted file mode 100644 index 1321f24..0000000 --- a/deep_gemm/jit/utils.py +++ /dev/null @@ -1,164 +0,0 @@ -import ctypes -from enum import Enum -from typing import Any, Dict, Tuple - -import cuda.bindings.driver as cuda -import torch - - -class Layout(Enum): - RowMajor = 0 - ColMajor = 1 - - -class GemmType(Enum): - Normal = 0 - GroupedContiguous = 1 - GroupedMasked = 2 - - def __str__(self) -> str: - return { - 0: 'Normal', - 1: 'GroupedContiguous', - 2: 'GroupedMasked', - }[self.value] - - -typename_map: Dict[Any, str] = { - torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, - torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, - torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, - torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, - torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, - torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, - torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, -} - -swizzle_map = { - 128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, - 64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, - 32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, - 0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, -} - -def get_num_math_warpgroups(block_m: int) -> int: - return 1 if block_m == 64 else 2 - -def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: - assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' - return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads - - -def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap: - tensor_dtype = typename_map[global_address.dtype] - res, tensor_map = cuda.cuTensorMapEncodeTiled( - tensor_dtype, - 2, # tensor rank - global_address.data_ptr(), - gmem_dim, - (stride_in_bytes,), # global strides - smem_dim, - (cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides - cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle_type, - cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, - ) - - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to encode tensor map: {res}') - - return tensor_map - - -def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap: - if layout == Layout.RowMajor: - gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows)) - smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows)) - return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) - else: - gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols)) - smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols)) - return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) - - -def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: - return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k) - - -def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: - return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n) - - -def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap: - # Swizzling requires the inner box dim less or equal than `kSwizzleDMode` - # bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode]) - - -def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap: - # Make TMA aligned to 16 bytes - kAlignment = 16 / global_address.element_size() - shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment - - return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) - - -def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult: - num_tma_threads = 128 - num_math_threads_per_group = 128 - - res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0] - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to set max dynamic shared memory size: {res}') - - attr_val = cuda.CUlaunchAttributeValue() - attr_val.clusterDim.x = num_tma_multicast - attr_val.clusterDim.y = 1 - attr_val.clusterDim.z = 1 - attr = cuda.CUlaunchAttribute() - attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value = attr_val - - config = cuda.CUlaunchConfig() - config.numAttrs = 1 - config.attrs = [attr] - config.gridDimX = num_sms - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) - config.blockDimY = 1 - config.blockDimZ = 1 - config.sharedMemBytes = smem_size - config.hStream = stream - - kernelValues = ( - gmem_d.data_ptr(), - scales_b.data_ptr(), - grouped_layout.data_ptr(), - shape_m, - tensor_map_a, - tensor_map_b, - tensor_map_scales_a, - tensor_map_d, - ) - kernelTypes = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_uint32, - None, - None, - None, - None, - ) - - return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index f27b025..825f313 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -3,8 +3,8 @@ import torch from functools import lru_cache from typing import Tuple -from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc - +from .runtime import FP8GemmRuntime, generate +from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc from .tuner import jit_tuner from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout @@ -122,7 +122,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, assert best_smem_config is not None assert best_num_stages is not None - # Decide the number of TMA multicast and whether broadcast on A + # Decide the number of TMA multicasts and whether broadcast on A best_tma_multicast_config = (1, True) # Try to multicast on the larger block side first @@ -155,13 +155,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[m, n]`, representing the result. """ @@ -183,7 +183,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], assert out.dtype == torch.bfloat16 assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() - # LHS scales must be transposed for TMA load, but not for RHS scales + # LHS scales must be transposed for TMA loads, but not for RHS scales # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() @@ -201,11 +201,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], num_math_threads_per_group = 128 tensor_map_a = make_2d_tma_a_desc( - GemmType.Normal, lhs, m, k, block_m, block_k) + GemmType.Normal, lhs, m, k, block_m, block_k, 1) tensor_map_b = make_2d_tma_b_desc( - GemmType.Normal, rhs, k, n, block_k, block_n) + GemmType.Normal, rhs, k, n, block_k, block_n, 1) tensor_map_d = make_2d_tma_d_desc( - GemmType.Normal, smem_config[1], out, m, n, block_m, block_n) + GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1]) tensor_map_scales_a = make_2d_tma_scales_a_desc( GemmType.Normal, lhs_scales, m, k, block_m, block_k) @@ -237,7 +237,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), - kwargs=kwargs + kwargs=kwargs, + generator=generate, + runtime_cls=FP8GemmRuntime, ) # Run the kernel diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 22281b1..3d12ff6 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,9 +1,9 @@ import torch from typing import Tuple -from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc - -from .gemm import get_best_configs, get_block_n_padding_for_smem_d +from .gemm import get_best_configs +from .runtime import FP8GemmRuntime, generate +from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc from .tuner import jit_tuner from .utils import get_col_major_tma_aligned_tensor, get_num_sms @@ -15,7 +15,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. On the M axis, inputs are grouped into several batches, of which batch sizes aligned to `get_m_alignment_for_contiguous_layout()` (128). @@ -23,11 +23,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. m_indices: a tensor of shape `[m_sum]` with type `torch.int`. - `m_indices[i]` records the group which the i-th row of the LHS belong to, + `m_indices[i]` records the group which the i-th row of the LHS belongs to, which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. Values of `m_indices` in every-m-alignment-block must also be the same. """ @@ -70,7 +70,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten tensor_map_b = make_2d_tma_b_desc( GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) tensor_map_d = make_2d_tma_d_desc( - GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups) + GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1]) tensor_map_scales_a = make_2d_tma_scales_a_desc( GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) @@ -103,6 +103,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten 'GEMM_TYPE': GemmType.GroupedContiguous}, space=(), kwargs=kwargs, + generator=generate, + runtime_cls=FP8GemmRuntime, ) # Run the kernel @@ -116,7 +118,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch should be separately transposed. @@ -125,7 +127,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute in the i-th group. @@ -157,7 +159,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] assert rhs_scales.is_contiguous() # Auto-tuning with compilation - global includes, template num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) @@ -175,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] tensor_map_b = make_2d_tma_b_desc( GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) tensor_map_d = make_2d_tma_d_desc( - GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups) + GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1]) tensor_map_scales_a = make_2d_tma_scales_a_desc( GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) @@ -208,6 +209,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] 'GEMM_TYPE': GemmType.GroupedMasked}, space=(), kwargs=kwargs, + generator=generate, + runtime_cls=FP8GemmRuntime, ) # Run the kernel diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py new file mode 100644 index 0000000..225b27a --- /dev/null +++ b/deep_gemm/jit_kernels/runtime.py @@ -0,0 +1,256 @@ +import ctypes +import os +import enum +import torch +import cuda.bindings.driver as cbd +from typing import Any, Dict, Tuple + +from ..jit.runtime import Runtime + + +def generate(**kwargs: Dict[str, Any]) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include +#include + +#include + +using namespace deep_gemm; + +__global__ void dummy_kernel() {{ + void *ptr = (void *)&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >; +}} +''' + + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Generated code:\n{code}') + return code + + +class FP8GemmRuntime(Runtime): + def __init__(self, path: str) -> None: + super().__init__(path, 'fp8_gemm', launch, [ + 'NUM_TMA_MULTICAST', + 'M', + 'BLOCK_M', + 'GMEM_D', + 'SCALES_B', + 'GROUPED_LAYOUT', + 'NUM_SMS', + 'SMEM_SIZE', + 'TENSOR_MAP_A', + 'TENSOR_MAP_B', + 'TENSOR_MAP_SCALES_A', + 'TENSOR_MAP_D', + 'STREAM', + ]) + + +class Layout(enum.Enum): + RowMajor = 0 + ColMajor = 1 + + +class GemmType(enum.Enum): + Normal = 0 + GroupedContiguous = 1 + GroupedMasked = 2 + + def __str__(self) -> str: + return { + 0: 'Normal', + 1: 'GroupedContiguous', + 2: 'GroupedMasked', + }[self.value] + + +tmap_type_map: Dict[Any, str] = { + torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, + torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, + torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, + torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, + torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +} + +swizzle_type_map = { + 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, + 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, + 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, + 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, +} + + +def get_num_math_warpgroups(block_m: int) -> int: + return 1 if block_m == 64 else 2 + + +def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: + assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' + return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads + + +def make_2d_tma_copy_desc(global_address: torch.Tensor, + gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], + stride_in_bytes: cbd.cuuint64_t, + smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], + swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: + tensor_dtype = tmap_type_map[global_address.dtype] + res, tensor_map = cbd.cuTensorMapEncodeTiled( + tensor_dtype, + 2, + global_address.data_ptr(), + gmem_dim, + (stride_in_bytes, ), + smem_dim, + (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), + cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle_type, + cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + + if res != cbd.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to encode tensor map: {res}') + return tensor_map + + +def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, + gmem_rows: int, gmem_cols: int, + smem_rows: int, smem_cols: int, + swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: + if layout == Layout.RowMajor: + gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows)) + smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type) + else: + gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols)) + smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols)) + return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type) + + +def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_m: int, shape_k: int, + block_m: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.RowMajor, + shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, + block_m, block_k) + + +def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_k: int, shape_n: int, + block_k: int, block_n: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(global_address, Layout.ColMajor, + shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), + block_k, block_n) + + +def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, + shape_m: int, shape_n: int, + block_m: int, block_n: int, + num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap: + # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` + # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(global_address, Layout.RowMajor, + shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, + block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), + swizzle_type_map[swizzle_mode]) + + +def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: + # Make TMA aligned to 16 bytes + tma_alignment = 16 / global_address.element_size() + shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment + + return make_2d_tma_desc(global_address, Layout.ColMajor, + shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), + block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int, + block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, + grouped_layout: torch.Tensor, num_sms: int, smem_size: int, + tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap, + tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap, + stream: cbd.CUstream) -> cbd.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] + if res != cbd.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to set max dynamic shared memory size: {res}') + + attr_val = cbd.CUlaunchAttributeValue() + attr_val.clusterDim.x = num_tma_multicast + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cbd.CUlaunchAttribute() + attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cbd.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = num_sms + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = smem_size + config.hStream = stream + + arg_values = ( + gmem_d.data_ptr(), + scales_b.data_ptr(), + grouped_layout.data_ptr(), + shape_m, + tensor_map_a, + tensor_map_b, + tensor_map_scales_a, + tensor_map_d, + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + None, + None, + None, + None, + ) + return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py index 4dcd6cd..18245f8 100644 --- a/deep_gemm/jit_kernels/tuner.py +++ b/deep_gemm/jit_kernels/tuner.py @@ -1,29 +1,28 @@ import copy import os import torch -from typing import Any, Dict +import cuda.bindings.driver as cbd +from typing import Any, Callable, Dict, Type, Tuple -import cuda.bindings.driver as cuda - -from ..jit import build, generate, Runtime +from ..jit import build, Runtime class JITTuner: def __init__(self) -> None: self.tuned = {} - def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime: - # NOTES: we always assume the space and template will not change - # We also assume the GPU device will not be changed + def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, + kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]: + # NOTES: we always assume the space, template and GPU devices will not change # NOTES: the function must have no accumulated side effects keys = {k: keys[k] for k in sorted(keys.keys())} signature = (name, f'{keys}') if signature in self.tuned: - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Using cached JIT kernel {name} with keys {keys}') return self.tuned[signature] - if os.getenv('DG_JIT_DEBUG', None): + if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Auto-tuning JIT kernel {name} with keys {keys}') assert signature not in self.tuned @@ -35,19 +34,19 @@ class JITTuner: assert isinstance(tuned_keys, dict) full_keys = copy.deepcopy(keys) full_keys.update(tuned_keys) - code = generate(**kwargs, **full_keys) - kernels.append((build(name, code), full_keys)) + code = generator(**kwargs, **full_keys) + kernels.append((build(name, code, runtime_cls), full_keys)) + # TODO: fix tuning with space > 1 best_runtime, best_time, best_keys = None, None, None for runtime, tuned_keys in kernels: if len(space) > 1: # Check kernel validity return_code = runtime(**tuned_keys, **kwargs) - if return_code != cuda.CUresult.CUDA_SUCCESS: - # Pass illegal kernels, e.g. insufficient shared memory capacity - if os.getenv('DG_JIT_DEBUG', None): - print( - f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') + if return_code != cbd.CUresult.CUDA_SUCCESS: + # Pass illegal kernels, e.g., insufficient shared memory capacity + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') continue # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels @@ -59,7 +58,7 @@ class JITTuner: (8192, 8192), dtype=torch.float, device='cuda') start_event.record() for i in range(20): - assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS + assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS end_event.record() end_event.synchronize() elapsed_time = start_event.elapsed_time(end_event) @@ -69,13 +68,12 @@ class JITTuner: # Compare if better if best_time is None or elapsed_time < best_time: best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys - if os.getenv('DG_JIT_DEBUG', None): - print( - f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' # Cache the best runtime and return - if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): + if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)): print( f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') self.tuned[signature] = (best_runtime, best_keys) diff --git a/tests/test_jit.py b/tests/test_jit.py index cced2d6..66c4fcf 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,14 +1,19 @@ import ctypes import os import torch -from typing import Any, Dict - import cuda.bindings.driver as cuda from deep_gemm import jit +# Essential debugging staffs +os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') +os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1') -def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult: + +# noinspection PyShadowingNames +def launch_vector_add(kernel: cuda.CUkernel, + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, + stream: cuda.CUstream) -> cuda.CUresult: assert a.shape == b.shape == c.shape assert a.device == b.device == c.device assert a.dim() == 1 @@ -24,28 +29,25 @@ def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: t config.blockDimZ = 1 config.hStream = stream - kernelValues = ( + arg_values = ( a.data_ptr(), b.data_ptr(), c.data_ptr(), n, ) - kernelTypes = ( + arg_types = ( ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint32, ) - return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0] + return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0] -def generate_vector_add(**kwargs: Dict[str, Any]) -> str: +def generate_vector_add(**kwargs) -> str: return f""" #ifdef __CUDACC_RTC__ -#ifndef NVRTC_JIT_COMPILATION -#define NVRTC_JIT_COMPILATION -#endif #include #else #include @@ -63,14 +65,14 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{ }} __global__ void dummy_kernel() {{ - void *ptr = (void *)&vector_add<{kwargs['T']}>; + auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>); }} """ class VectorAddRuntime(jit.Runtime): def __init__(self, path: str) -> None: - super().__init__(path, 'vector_add', run_vector_add, [ + super().__init__(path, 'vector_add', launch_vector_add, [ 'A', 'B', 'C', @@ -79,38 +81,25 @@ class VectorAddRuntime(jit.Runtime): if __name__ == '__main__': - # NVCC - print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n') print('Generated code:') code = generate_vector_add(T='float') print(code) - print('Building ...') - func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime) + print() - a = torch.randn((1024, ), dtype=torch.float32, device='cuda') - b = torch.randn((1024, ), dtype=torch.float32, device='cuda') - c = torch.empty_like(a) - ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) - assert ret == cuda.CUresult.CUDA_SUCCESS, ret - ref_output = a + b - torch.testing.assert_close(c, ref_output) + for compiler_name in ('NVCC', 'NVRTC'): + # Get compiler + compiler_cls = getattr(jit, f'{compiler_name}Compiler') + print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}') - print('JIT test for NVCC passed\n') + # Build + print('Building ...') + func = compiler_cls.build('test_func', code, VectorAddRuntime) - # NVRTC - print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n') - print('Generated code:') - code = generate_vector_add(T='__nv_bfloat16') - print(code) - print('Building ...') - func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime) - - a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda') - b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda') - c = torch.empty_like(a) - ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) - assert ret == cuda.CUresult.CUDA_SUCCESS, ret - ref_output = a + b - torch.testing.assert_close(c, ref_output) - - print('JIT test for NVRTC passed') \ No newline at end of file + # Run and check + a = torch.randn((1024, ), dtype=torch.float32, device='cuda') + b = torch.randn((1024, ), dtype=torch.float32, device='cuda') + c = torch.empty_like(a) + ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) + assert ret == cuda.CUresult.CUDA_SUCCESS, ret + torch.testing.assert_close(c, a + b) + print(f'JIT test for {compiler_name} passed\n')