Some lints and refactor

This commit is contained in:
Chenggang Zhao
2025-05-06 17:23:35 +08:00
parent 8aff6309d4
commit 981cc58932
18 changed files with 421 additions and 449 deletions

View File

@@ -19,6 +19,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] Fix TMA multicast compatibility for indivisible shapes
- [ ] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [ ] Fully remove NVCC compilation
- [ ] Sanitizer for testing
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
@@ -105,12 +106,13 @@ The library provides some utility functions besides the above kernels:
The library also provides some environment variables, which may be useful:
- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default
- `DG_NVCC_COMPILER`: string, the NVCC compiler path to use; by default, discovered via `torch.utils.cpp_extension.CUDA_HOME`
- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), override the C++ standard used for compilation, as a workaround for some older GCC versions
- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
- `DG_JIT_PRINT_NVCC_COMMAND`: 0 or 1, print NVCC compilation command
- `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print the compilation command (NVCC or NVRTC)
- `DG_JIT_DEBUG`: 0 or 1, print more debugging information
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.

View File

@@ -84,8 +84,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));

View File

@@ -1,6 +1,6 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif

View File

@@ -17,7 +17,7 @@
#pragma once
#ifdef NVRTC_JIT_COMPILATION
#ifdef __CUDACC_RTC__
using int8_t = signed char;
using uint8_t = unsigned char;
@@ -32,8 +32,7 @@ using cuuint64_t = unsigned long long;
#ifndef CU_TENSOR_MAP_NUM_QWORDS
#define CU_TENSOR_MAP_NUM_QWORDS 16
struct CUtensorMap_st
{
struct CUtensorMap_st {
#if defined(__cplusplus) && (__cplusplus >= 201103L)
alignas(64)
#elif __STDC_VERSION__ >= 201112L
@@ -46,16 +45,16 @@ using CUtensorMap = CUtensorMap_st;
#endif
namespace std {
template <class T, T v> struct integral_constant {
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
using type = integral_constant;
__device__ constexpr operator value_type() const noexcept { return value; }
__device__ constexpr value_type operator()() const noexcept {
return value;
} // since c++14
__device__ constexpr value_type operator()() const noexcept { return value; }
};
using false_type = integral_constant<bool, false>;
@@ -69,6 +68,7 @@ template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
template <size_t... Ints> struct index_sequence {
using type = index_sequence;
@@ -89,6 +89,7 @@ struct make_index_sequence
template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};
} // namespace index_sequence_impl
template <size_t... Ns>
@@ -96,6 +97,7 @@ using index_sequence = index_sequence_impl::index_sequence<Ns...>;
template <size_t N>
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
} // namespace std
#endif

View File

@@ -46,7 +46,7 @@ struct Scheduler {
}
}
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
@@ -63,7 +63,8 @@ struct Scheduler {
}
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages

View File

@@ -1,18 +1,10 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cassert>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudaTypedefs.h>
#endif
#include <cuda/barrier>
#include "utils.cuh"
namespace deep_gemm {
// TODO: move this function to other files
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {

View File

@@ -1,39 +1,14 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <exception>
#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
class AssertionException : public std::exception {
private:
std::string message{};
public:
explicit AssertionException(const std::string& message) : message(message) {}
const char *what() const noexcept override { return message.c_str(); }
};
#endif
#ifndef DG_HOST_ASSERT
#ifdef NVRTC_JIT_COMPILATION
#define DG_HOST_ASSERT(cond) ((void)0)
#else
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", \
__FILE__, __LINE__, #cond); \
throw AssertionException("Assertion failed: " #cond); \
} \
} while (0)
#endif
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \

View File

@@ -1,3 +1,2 @@
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
from .template import generate
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
from .runtime import Runtime

View File

@@ -1,4 +1,3 @@
import abc
import functools
import hashlib
import os
@@ -14,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache
from .runtime import Runtime, RuntimeCache
runtime_cache = RuntimeCache()
@@ -32,11 +31,11 @@ def get_jit_include_dir() -> str:
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
md5 = hashlib.md5()
# Update include directories
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(os.path.join(include_dir, filename), 'rb') as f:
md5.update(f.read())
@@ -98,24 +97,20 @@ def make_tmp_dir():
def put(path, data):
is_binary = isinstance(data, bytes)
# Write and do POSIX atomic replace
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
f.write(data)
os.replace(tmp_file_path, path)
class Compiler(abc.ABC):
class Compiler:
@staticmethod
@abc.abstractmethod
def __version__() -> Tuple[int, int]:
pass
@classmethod
@abc.abstractmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str) -> None:
pass
@staticmethod
@@ -132,13 +127,12 @@ class Compiler(abc.ABC):
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
# Compiler flags
flags = cls.flags()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = os.path.join(get_cache_dir(), name)
@@ -147,7 +141,7 @@ class Compiler(abc.ABC):
global runtime_cache
cached_runtime = runtime_cache.get(path, runtime_cls)
if cached_runtime is not None:
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT runtime {name} during build')
return cached_runtime
@@ -160,9 +154,8 @@ class Compiler(abc.ABC):
cls.compile(name, code, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
@@ -177,12 +170,12 @@ class Compiler(abc.ABC):
return runtime
class NvccCompiler(Compiler):
class NVCCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return (major, minor)
return major, minor
@classmethod
def flags(cls) -> List[str]:
@@ -197,7 +190,7 @@ class NvccCompiler(Compiler):
f'--compiler-options={",".join(cxx_flags)}']
@classmethod
def compile(cls, name: str, code: str, target_path: str):
def compile(cls, name: str, code: str, target_path: str) -> None:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
@@ -205,26 +198,23 @@ class NvccCompiler(Compiler):
command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with command {command}')
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
if os.getenv('DG_JIT_DEBUG', None):
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f'Failed to compile {src_path}'
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
assert False, f'Failed to compile {src_path}'
class NvrtcCompiler(Compiler):
class NVRTCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get actual NVRTC version, use bindings version instead
# Failed to get the actual NVRTC version, use cuda-bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor)
return major, minor
@staticmethod
def include_dirs() -> List[str]:
@@ -238,54 +228,51 @@ class NvrtcCompiler(Compiler):
'--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8):
base_flags += ['--pch']
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
base_flags += ['--pch-verbose=true']
return base_flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str) -> None:
# Create program
code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram(
result, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to create program: {res}')
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
# Compile
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0]
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with options: {options}')
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
# Print compiler log
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
if os.getenv('DG_JIT_DEBUG', None):
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get program log size: {res}')
log_bytes = bytes(log_size)
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get program log: {res}')
log_str = log_bytes.decode('utf-8')
print(log_str)
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
print(f'Compiler log: {log_bytes.decode("utf-8")}')
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to compile program: {compile_res}')
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get CUBIN size: {res}')
# Exit if failed
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
# Create CUBIN
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
cubin_bytes = bytes(cubin_size)
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get CUBIN: {res}')
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
# Write into the file system
put(target_path, cubin_bytes)
res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to destroy program: {res}')
# Destroy handler
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
else:
return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
return compiler_cls.build(name, code, runtime_cls=runtime_cls)

View File

@@ -37,7 +37,7 @@ def extract_ffma(sass):
collected.append((f'{arch_name}::{func_name}', current))
current = []
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f'Found {len(collected)} FFMA segments')
return collected
@@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines):
dst_reg_set.add(dst_reg)
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
# Find the offset
@@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines):
def process(path):
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f'Processing {path}')
output = run_cuobjdump(path)
segments = extract_ffma(output)

View File

@@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cuda
from .utils import run_gemm
class Runtime:
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
def __init__(self, path: str, kernel_name: str = None,
caller: Callable[..., cuda.CUresult] = None,
args: List[str] = None) -> None:
self.path = path
self.lib = None
self.kernel = None
@@ -27,7 +27,7 @@ class Runtime:
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
def __call__(self, **kwargs) -> cuda.CUresult:
# Load CUBIN
if self.kernel is None:
start_time = time.time_ns()
@@ -59,9 +59,8 @@ class Runtime:
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return self.caller(
self.kernel,
@@ -75,25 +74,6 @@ class Runtime:
raise Exception(f'Failed to unload library {self.path}: {res}')
class Fp8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'fp8_gemm', run_gemm, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
class RuntimeCache:
def __init__(self) -> None:
self.cache = {}
@@ -101,14 +81,14 @@ class RuntimeCache:
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
runtime = runtime_cls(path)
self.cache[path] = runtime
return runtime
return None
return None

View File

@@ -1,51 +0,0 @@
import os
from typing import Any, Dict
def generate(**kwargs: Any) -> str:
    """Generate the CUDA C++ source for a JIT-compiled FP8 GEMM instantiation.

    The emitted translation unit defines a `dummy_kernel` whose body takes the
    address of the templated `fp8_gemm_kernel`, forcing the template to be
    instantiated with the compile-time parameters supplied in `kwargs`.

    Keyword Args:
        N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, SWIZZLE_D_MODE,
        NUM_GROUPS, NUM_STAGES, NUM_TMA_THREADS, NUM_MATH_THREADS_PER_GROUP,
        NUM_TMA_MULTICAST: integer template parameters, rendered verbatim.
        IS_TMA_MULTICAST_ON_A: truthy value, rendered as `true`/`false`.
        GEMM_TYPE: rendered as a `GemmType::<value>` enumerator.

    Returns:
        The generated CUDA C++ source code as a string.
    """
    code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <string>
#include <cuda.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
'''
    # Debug print. Read `DG_JIT_DEBUG` as an integer flag for consistency with
    # the other `DG_*` environment variables (plain truthiness would treat
    # `DG_JIT_DEBUG=0` as enabled).
    if int(os.getenv('DG_JIT_DEBUG', 0)):
        print(f'Generated code:\n{code}')
    return code

View File

@@ -1,164 +0,0 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple
import cuda.bindings.driver as cuda
import torch
class Layout(Enum):
    """Memory layout of a 2D tensor handed to the TMA descriptor builders."""
    RowMajor = 0
    ColMajor = 1
class GemmType(Enum):
    """GEMM variants supported by the JIT kernels."""
    Normal = 0
    GroupedContiguous = 1
    GroupedMasked = 2

    def __str__(self) -> str:
        # The member name doubles as the C++ `GemmType::<name>` spelling.
        return self.name
# Maps a torch dtype to the CUDA tensor-map data type used when encoding TMA
# descriptors. Signed 8/16-bit integers and every FP8 variant are transferred
# as same-width unsigned types, i.e. only the bit pattern is copied.
typename_map: Dict[Any, Any] = {
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
# Maps a TMA swizzle size in bytes (0 = swizzling disabled) to the
# corresponding CUDA tensor-map swizzle enumerator.
swizzle_map = {
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}
def get_num_math_warpgroups(block_m: int) -> int:
    """Return the number of math warpgroups for a given block-M (1 for 64, otherwise 2)."""
    if block_m == 64:
        return 1
    return 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
    """Total threads launched per SM: all math warpgroups plus the TMA threads."""
    assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
    # One math warpgroup suffices for block_m == 64; larger blocks use two.
    num_math_warpgroups = 1 if block_m == 64 else 2
    return num_math_warpgroups * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
    """Encode a rank-2 tiled TMA descriptor for `global_address` via the CUDA driver.

    Raises:
        Exception: if `cuTensorMapEncodeTiled` does not return `CUDA_SUCCESS`.
    """
    status, tensor_map = cuda.cuTensorMapEncodeTiled(
        typename_map[global_address.dtype],
        2,  # tensor rank
        global_address.data_ptr(),
        gmem_dim,
        (stride_in_bytes,),  # global strides
        smem_dim,
        (cuda.cuuint32_t(1), cuda.cuuint32_t(1)),  # element strides
        cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle_type,
        cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    )
    if status != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to encode tensor map: {status}')
    return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
                     gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int,
                     swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
    """Build a 2D TMA descriptor, ordering dimensions innermost-first per layout.

    Row-major tensors put columns innermost; column-major tensors put rows
    innermost. The global stride passed to the driver is the innermost extent
    times the element size, in bytes.
    """
    if layout == Layout.RowMajor:
        gmem_inner, gmem_outer = gmem_cols, gmem_rows
        smem_inner, smem_outer = smem_cols, smem_rows
    else:
        gmem_inner, gmem_outer = gmem_rows, gmem_cols
        smem_inner, smem_outer = smem_rows, smem_cols
    gmem_dim = (cuda.cuuint64_t(gmem_inner), cuda.cuuint64_t(gmem_outer))
    smem_dim = (cuda.cuuint32_t(smem_inner), cuda.cuuint32_t(smem_outer))
    stride_in_bytes = cuda.cuuint64_t(gmem_inner * global_address.element_size())
    return make_2d_tma_copy_desc(global_address, gmem_dim, stride_in_bytes, smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
    """TMA descriptor for operand A, row-major with shape [m, k]."""
    # Masked grouped GEMMs stack all groups along the M dimension.
    rows = shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    return make_2d_tma_desc(global_address, Layout.RowMajor, rows, shape_k, block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
    """TMA descriptor for operand B, column-major with shape [k, n]."""
    # Every grouped variant stacks the groups along the N dimension.
    cols = shape_n * (num_groups if gemm_type != GemmType.Normal else 1)
    return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, cols, block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
    """TMA descriptor for the output D, row-major, optionally swizzled.

    Swizzling requires the inner box dimension to be at most `kSwizzleDMode`
    bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required.
    """
    rows = shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    # With swizzling disabled (mode 0), the inner box covers the whole block N.
    inner_box_dim = block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size()
    return make_2d_tma_desc(global_address, Layout.RowMajor, rows, shape_n, block_m, inner_box_dim, swizzle_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
    """TMA descriptor for the LHS scaling factors, column-major, unswizzled.

    `shape_m` is padded up so the TMA box stays 16-byte aligned; the K extent
    is the number of scale blocks (`ceil(k / block_k)`), replicated per group
    for masked grouped GEMMs.
    """
    # Make TMA aligned to 16 bytes. Use integer division: the original
    # `16 / element_size()` produced a float alignment, which made the padded
    # `shape_m` a float as well and leaked a non-integer dimension downstream.
    # NOTE(review): assumes element_size divides 16 (e.g. FP32 scales) — confirm.
    kAlignment = 16 // global_address.element_size()
    shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
    return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
    """Configure and launch the FP8 GEMM kernel via `cuLaunchKernelEx`.

    Launches `num_sms` blocks grouped into clusters of `num_tma_multicast`
    blocks on `stream`, and returns the driver's launch result.

    Raises:
        Exception: if raising the kernel's dynamic shared-memory limit fails.
    """
    num_tma_threads = 128
    num_math_threads_per_group = 128

    # Allow the kernel to use `smem_size` bytes of dynamic shared memory.
    status = cuda.cuKernelSetAttribute(
        cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
    if status != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to set max dynamic shared memory size: {status}')

    # Cluster dimension carries the TMA multicast width.
    attr_val = cuda.CUlaunchAttributeValue()
    attr_val.clusterDim.x = num_tma_multicast
    attr_val.clusterDim.y = 1
    attr_val.clusterDim.z = 1
    attr = cuda.CUlaunchAttribute()
    attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
    attr.value = attr_val

    config = cuda.CUlaunchConfig()
    config.numAttrs = 1
    config.attrs = [attr]
    config.gridDimX = num_sms
    config.gridDimY = 1
    config.gridDimZ = 1
    config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
    config.blockDimY = 1
    config.blockDimZ = 1
    config.sharedMemBytes = smem_size
    config.hStream = stream

    # Kernel argument pack: device pointers, then the runtime M extent, then
    # the four tensor maps (type `None` leaves their marshalling to the CUDA
    # bindings).
    arg_values = (
        gmem_d.data_ptr(),
        scales_b.data_ptr(),
        grouped_layout.data_ptr(),
        shape_m,
        tensor_map_a,
        tensor_map_b,
        tensor_map_scales_a,
        tensor_map_d,
    )
    arg_types = (
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_uint32,
        None,
        None,
        None,
        None,
    )
    return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -3,8 +3,8 @@ import torch
from functools import lru_cache
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -122,7 +122,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
assert best_smem_config is not None
assert best_num_stages is not None
# Decide the number of TMA multicast and whether broadcast on A
# Decide the number of TMA multicasts and whether broadcast on A
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
@@ -155,13 +155,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
@@ -183,7 +183,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
@@ -201,11 +201,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k)
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n)
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
@@ -237,7 +237,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel

View File

@@ -1,9 +1,9 @@
import torch
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .gemm import get_best_configs
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@@ -15,7 +15,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
@@ -23,11 +23,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`,
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
`m_indices[i]` records the group which the i-th row of the LHS belong to,
`m_indices[i]` records the group which the i-th row of the LHS belongs to,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` in every-m-alignment-block must also be the same.
"""
@@ -70,7 +70,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
@@ -103,6 +103,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
@@ -116,7 +118,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
@@ -125,7 +127,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
in the i-th group.
@@ -157,7 +159,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
@@ -175,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
@@ -208,6 +209,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel

View File

@@ -0,0 +1,256 @@
import ctypes
import os
import enum
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict, Tuple
from ..jit.runtime import Runtime
def generate(**kwargs: Dict[str, Any]) -> str:
    """Generate CUDA source instantiating one FP8 GEMM kernel specialization.

    The emitted translation unit defines only a `dummy_kernel` whose body takes
    the address of the fully-specialized `fp8_gemm_kernel` template; taking the
    address forces the compiler to instantiate and emit that kernel so it can
    be looked up after compilation.

    Arguments:
        kwargs: compile-time template parameters (`N`, `K`, `BLOCK_M`,
            `BLOCK_N`, `BLOCK_K`, `BLOCK_N_PADDING`, `SWIZZLE_D_MODE`,
            `NUM_GROUPS`, `NUM_STAGES`, `NUM_TMA_THREADS`,
            `NUM_MATH_THREADS_PER_GROUP`, `NUM_TMA_MULTICAST`,
            `IS_TMA_MULTICAST_ON_A`, `GEMM_TYPE`).

    Returns:
        The generated CUDA C++ source code as a string.
    """
    # `__CUDACC_RTC__` selects the NVRTC path, which lacks the host standard
    # headers and uses the bundled freestanding `nvrtc_std.cuh` instead
    code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
'''
    # Opt-in dump of the generated source for debugging
    if int(os.getenv('DG_JIT_DEBUG', 0)):
        print(f'Generated code:\n{code}')
    return code
class FP8GemmRuntime(Runtime):
    """Runtime wrapper binding a compiled `fp8_gemm` kernel to its launcher."""

    def __init__(self, path: str) -> None:
        # The keyword-argument names are declared in the exact order of the
        # parameters (after `kernel`) of the `launch` function below
        launch_arg_names = [
            'NUM_TMA_MULTICAST',
            'M',
            'BLOCK_M',
            'GMEM_D',
            'SCALES_B',
            'GROUPED_LAYOUT',
            'NUM_SMS',
            'SMEM_SIZE',
            'TENSOR_MAP_A',
            'TENSOR_MAP_B',
            'TENSOR_MAP_SCALES_A',
            'TENSOR_MAP_D',
            'STREAM',
        ]
        super().__init__(path, 'fp8_gemm', launch, launch_arg_names)
class Layout(enum.Enum):
    # Matrix memory layouts used when building 2D TMA descriptors:
    # row-major means columns are contiguous, column-major means rows are
    RowMajor = 0
    ColMajor = 1
class GemmType(enum.Enum):
    """GEMM scheduling variants: dense, grouped-contiguous, and grouped-masked."""
    Normal = 0
    GroupedContiguous = 1
    GroupedMasked = 2

    def __str__(self) -> str:
        # The Python member names deliberately match the C++ `GemmType::*`
        # enumerators, so the member name itself is the code-generation token
        return self.name
# Mapping from PyTorch dtypes to CUDA tensor-map (TMA) data types.
# NOTES: signed integer and FP8 variants are transferred as raw unsigned
# integers of the same width, since TMA copies only need the element size here
tmap_type_map: Dict[torch.dtype, Any] = {
    torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
    torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
    torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
    torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
    torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
# Mapping from swizzle size in bytes (0 means no swizzling) to the
# corresponding CUDA tensor-map swizzle enumerator
swizzle_type_map = {
    0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
    32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
    64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
    128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
}
def get_num_math_warpgroups(block_m: int) -> int:
    """Number of math warpgroups: a 64-row block needs one, larger blocks two."""
    if block_m == 64:
        return 1
    return 2


def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
    """Total threads per SM: all math-warpgroup threads plus the TMA threads."""
    assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
    math_threads = get_num_math_warpgroups(block_m) * num_math_threads_per_group
    return math_threads + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor,
                          gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
                          stride_in_bytes: cbd.cuuint64_t,
                          smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
                          swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
    """Encode a rank-2 tiled TMA descriptor with `cuTensorMapEncodeTiled`.

    Arguments:
        global_address: tensor supplying the global-memory base pointer and the
            dtype used to pick the tensor-map data type.
        gmem_dim: global extents; callers pass the contiguous dimension first.
        stride_in_bytes: byte stride of the outer dimension (only one stride is
            passed; the innermost dimension is implicitly contiguous).
        smem_dim: shared-memory box extents, same dimension order as `gmem_dim`.
        swizzle_type: shared-memory swizzle pattern for the box.

    Returns:
        The encoded `CUtensorMap`.

    Raises:
        Exception: if the driver call does not return `CUDA_SUCCESS`.
    """
    tensor_dtype = tmap_type_map[global_address.dtype]
    # Positional arguments follow the driver API order:
    # (dtype, rank, address, global dims, global strides, box dims,
    #  element strides, interleave, swizzle, L2 promotion, OOB fill)
    res, tensor_map = cbd.cuTensorMapEncodeTiled(
        tensor_dtype,
        2,  # Tensor rank is fixed to 2 here
        global_address.data_ptr(),
        gmem_dim,
        (stride_in_bytes, ),
        smem_dim,
        (cbd.cuuint32_t(1), cbd.cuuint32_t(1)),  # No element striding
        cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle_type,
        cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    )
    if res != cbd.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to encode tensor map: {res}')
    return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
                     gmem_rows: int, gmem_cols: int,
                     smem_rows: int, smem_cols: int,
                     swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
    """Build a rank-2 TMA descriptor for a row- or column-major matrix.

    The contiguous (inner) dimension is placed first in both the global and
    shared-memory box shapes, and the outer-dimension byte stride is derived
    from the inner extent times the element size.
    """
    if layout == Layout.RowMajor:
        # Columns are contiguous
        gmem_inner, gmem_outer = gmem_cols, gmem_rows
        smem_inner, smem_outer = smem_cols, smem_rows
    else:
        # Rows are contiguous
        gmem_inner, gmem_outer = gmem_rows, gmem_cols
        smem_inner, smem_outer = smem_rows, smem_cols
    gmem_dim = (cbd.cuuint64_t(gmem_inner), cbd.cuuint64_t(gmem_outer))
    smem_dim = (cbd.cuuint32_t(smem_inner), cbd.cuuint32_t(smem_outer))
    outer_stride = cbd.cuuint64_t(gmem_inner * global_address.element_size())
    return make_2d_tma_copy_desc(global_address, gmem_dim, outer_stride, smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
                       shape_m: int, shape_k: int,
                       block_m: int, block_k: int,
                       num_groups: int) -> cbd.CUtensorMap:
    """Make the TMA descriptor for the (row-major) LHS operand."""
    # Masked grouped GEMMs stack the per-group LHS matrices along M
    num_rows = shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    return make_2d_tma_desc(global_address, Layout.RowMajor,
                            num_rows, shape_k,
                            block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
                       shape_k: int, shape_n: int,
                       block_k: int, block_n: int,
                       num_groups: int) -> cbd.CUtensorMap:
    """Make the TMA descriptor for the (column-major) RHS operand."""
    # Both grouped variants stack the per-group RHS matrices along N
    num_cols = shape_n * (1 if gemm_type == GemmType.Normal else num_groups)
    return make_2d_tma_desc(global_address, Layout.ColMajor,
                            shape_k, num_cols,
                            block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
                       shape_m: int, shape_n: int,
                       block_m: int, block_n: int,
                       num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap:
    """Make the TMA descriptor for the (row-major) BF16 output operand."""
    # Swizzling requires the inner box dim to be less than or equal to
    # `kSwizzleDMode` bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA
    # stores are required
    if swizzle_mode == 0:
        inner_box_cols = block_n
    else:
        inner_box_cols = swizzle_mode // global_address.element_size()
    num_rows = shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    return make_2d_tma_desc(global_address, Layout.RowMajor,
                            num_rows, shape_n,
                            block_m, inner_box_cols,
                            swizzle_type_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
    """Make the TMA descriptor for the (column-major) LHS 1x128 scaling factors.

    Arguments:
        gemm_type: GEMM variant; masked grouped GEMMs stack groups along the
            K-block dimension.
        global_address: the FP32 scaling tensor.
        shape_m, shape_k: logical GEMM extents.
        block_m, block_k: block tile extents; one scale per `block_k` channels.
        num_groups: number of groups for grouped GEMMs (default 1).
    """
    # Make TMA aligned to 16 bytes.
    # NOTES: use floor division so `tma_alignment` (and hence the rounded-up
    # `shape_m`) stays an `int`; `16 / element_size()` yields a float, which
    # would propagate into the `cuuint64_t` dimension below
    tma_alignment = 16 // global_address.element_size()
    shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
    return make_2d_tma_desc(global_address, Layout.ColMajor,
                            shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
                            block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
           block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
           grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
           tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
           tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
           stream: cbd.CUstream) -> cbd.CUresult:
    """Configure and launch a compiled FP8 GEMM kernel via the CUDA driver API.

    Arguments:
        kernel: the loaded `fp8_gemm` kernel handle.
        num_tma_multicast: TMA multicast width, used as the cluster X dimension.
        shape_m: runtime M extent passed to the kernel.
        block_m: M-block size, determines the number of math warpgroups.
        gmem_d: output tensor (also provides the device for the smem attribute).
        scales_b: RHS scaling-factor tensor.
        grouped_layout: grouping metadata tensor.
        num_sms: number of SMs; used directly as the grid size.
        smem_size: dynamic shared memory bytes required by the kernel.
        tensor_map_a/b/scales_a/d: TMA descriptors built by the helpers above.
        stream: CUDA stream to launch on.

    Returns:
        The result of `cuLaunchKernelEx`.

    Raises:
        Exception: if raising the dynamic shared memory limit fails.
    """
    # Fixed thread counts; must agree with `get_num_threads_per_sm` below
    num_tma_threads = 128
    num_math_threads_per_group = 128
    # Raise the kernel's dynamic shared memory limit to `smem_size` on the
    # device that owns the output tensor
    res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
    if res != cbd.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to set max dynamic shared memory size: {res}')
    # TMA multicast is expressed as a thread-block cluster of
    # `num_tma_multicast` blocks along X
    attr_val = cbd.CUlaunchAttributeValue()
    attr_val.clusterDim.x = num_tma_multicast
    attr_val.clusterDim.y = 1
    attr_val.clusterDim.z = 1
    attr = cbd.CUlaunchAttribute()
    attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
    attr.value = attr_val
    config = cbd.CUlaunchConfig()
    config.numAttrs = 1
    config.attrs = [attr]
    # Grid size equals the SM count (persistent-kernel-style scheduling)
    config.gridDimX = num_sms
    config.gridDimY = 1
    config.gridDimZ = 1
    config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
    config.blockDimY = 1
    config.blockDimZ = 1
    config.sharedMemBytes = smem_size
    config.hStream = stream
    # Argument order must match the `fp8_gemm_kernel` parameter order:
    # (gmem_d, scales_b, grouped_layout, shape_m, then the four tensor maps)
    arg_values = (
        gmem_d.data_ptr(),
        scales_b.data_ptr(),
        grouped_layout.data_ptr(),
        shape_m,
        tensor_map_a,
        tensor_map_b,
        tensor_map_scales_a,
        tensor_map_d,
    )
    # NOTE(review): the `None` type entries appear to make cuda.bindings pass
    # the `CUtensorMap` structs through without a ctypes conversion — confirm
    # against the cuda-python `cuLaunchKernelEx` argument-packing docs
    arg_types = (
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_uint32,
        None,
        None,
        None,
        None,
    )
    return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -1,29 +1,28 @@
import copy
import os
import torch
from typing import Any, Dict
import cuda.bindings.driver as cbd
from typing import Any, Callable, Dict, Type, Tuple
import cuda.bindings.driver as cuda
from ..jit import build, generate, Runtime
from ..jit import build, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f'{keys}')
if signature in self.tuned:
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT kernel {name} with keys {keys}')
return self.tuned[signature]
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
@@ -35,19 +34,19 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(**kwargs, **full_keys)
kernels.append((build(name, code), full_keys))
code = generator(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cuda.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
if return_code != cbd.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g., insufficient shared memory capacity
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
@@ -59,7 +58,7 @@ class JITTuner:
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
@@ -69,13 +68,12 @@ class JITTuner:
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
print(
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)

View File

@@ -1,14 +1,19 @@
import ctypes
import os
import torch
from typing import Any, Dict
import cuda.bindings.driver as cuda
from deep_gemm import jit
# Essential debugging staffs
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
# noinspection PyShadowingNames
def launch_vector_add(kernel: cuda.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cuda.CUstream) -> cuda.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
@@ -24,28 +29,25 @@ def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: t
config.blockDimZ = 1
config.hStream = stream
kernelValues = (
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
n,
)
kernelTypes = (
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
def generate_vector_add(**kwargs) -> str:
return f"""
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
@@ -63,14 +65,14 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
}}
__global__ void dummy_kernel() {{
void *ptr = (void *)&vector_add<{kwargs['T']}>;
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', run_vector_add, [
super().__init__(path, 'vector_add', launch_vector_add, [
'A',
'B',
'C',
@@ -79,38 +81,25 @@ class VectorAddRuntime(jit.Runtime):
if __name__ == '__main__':
# NVCC
print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
print('Generated code:')
code = generate_vector_add(T='float')
print(code)
print('Building ...')
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
print()
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)
for compiler_name in ('NVCC', 'NVRTC'):
# Get compiler
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
print('JIT test for NVCC passed\n')
# Build
print('Building ...')
func = compiler_cls.build('test_func', code, VectorAddRuntime)
# NVRTC
print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
print('Generated code:')
code = generate_vector_add(T='__nv_bfloat16')
print(code)
print('Building ...')
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)
print('JIT test for NVRTC passed')
# Run and check
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
torch.testing.assert_close(c, a + b)
print(f'JIT test for {compiler_name} passed\n')