Merge pull request #78 from deepseek-ai/tma-3d-padding

Solving bank conflict via padding and TMA 3D store
This commit is contained in:
Chenggang Zhao 2025-04-03 16:06:10 +08:00 committed by GitHub
commit c187c23ba8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 121 additions and 51 deletions

View File

@ -128,6 +128,7 @@ The library also provides some environment variables, which may be useful:
- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default - `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default
- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler
- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization - `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output - `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details - `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details

View File

@ -40,6 +40,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
template <uint32_t SHAPE_N, uint32_t SHAPE_K, template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kNumGroups, uint32_t kNumStages, uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup, uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA, uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
@ -50,11 +51,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
const __grid_constant__ CUtensorMap tensor_map_a, const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b, const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a, const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) { const __grid_constant__ std::pair<CUtensorMap, CUtensorMap> tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks // Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types // Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type; using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@ -62,7 +63,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Shared memory // Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
@ -82,7 +83,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a)); cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b)); cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a)); cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d)); if constexpr (SHAPE_N >= BLOCK_N)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.first));
if constexpr (SHAPE_N % BLOCK_N != 0)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.second));
} }
__syncwarp(); __syncwarp();
@ -141,8 +145,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
struct DivisibleK {}; struct DivisibleK {};
struct NotDivisibleK {}; struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func, int num_former_iters) { auto launch_k_iterations = [](const auto& func, int num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8; constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization // NOTES: for too-many branches (> 5), we disable this optimization
@ -340,14 +344,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
); );
} }
if constexpr (WGMMA::kNumAccum % 8 != 0) { if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy( SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
); );
} }
cute::tma_store_fence(); cute::tma_store_fence();
@ -355,8 +359,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Use TMA store to write back to global memory // Use TMA store to write back to global memory
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, if (n_block_idx < SHAPE_N / BLOCK_N) {
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); // Except the last unaligned block
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
} else {
// The last unaligned block
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
}
cute::tma_store_arrive(); cute::tma_store_arrive();
cute::tma_store_wait<0>(); cute::tma_store_wait<0>();
} }
@ -371,6 +382,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
template <uint32_t SHAPE_N, uint32_t SHAPE_K, template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kNumGroups, uint32_t kNumStages, uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA, uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType> GemmType kGemmType>
@ -386,14 +398,17 @@ public:
const CUtensorMap& tma_a_desc, const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc, const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc, const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc, const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
cudaStream_t stream, cudaStream_t stream,
int num_sms, uint32_t smem_size) { int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128; constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128; constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup, BLOCK_M, BLOCK_N, BLOCK_K,
BLOCK_N_PADDING,
kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>; kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
@ -433,11 +448,26 @@ public:
} }
template <typename T> template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor, // NOTES: must be row-major
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
min(BLOCK_M, shape_m), BLOCK_N, uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
// `SHAPE_N % BLOCK_N` maybe not zero, dividing them into two parts
CUtensorMap aligned;
if constexpr (SHAPE_N >= BLOCK_N) {
uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
}
CUtensorMap unaligned;
if constexpr (SHAPE_N % BLOCK_N != 0) {
uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
gmem_dim, gmem_strides, smem_dim);
}
return {aligned, unaligned};
} }
template <typename T> template <typename T>

View File

@ -63,16 +63,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2], uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type, CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) { PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{}; CUtensorMap tensor_map = {};
constexpr uint32_t rank = 2; uint64_t global_stride[1] = {stride_in_bytes};
uint64_t global_stride[rank - 1] = {stride_in_bytes}; uint32_t elem_strides[2] = {1, 1};
uint32_t elem_strides[rank] = {1, 1};
if (encode_func == nullptr) if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled(); encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func( auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank, &tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
global_address, gmem_dim, global_stride, smem_dim, elem_strides, global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
@ -81,6 +80,27 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
return tensor_map; return tensor_map;
} }
template <typename T>
CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
uint64_t gmem_strides[2], uint32_t smem_dim[3],
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map = {};
uint32_t elem_strides[3] = {1, 1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 3,
global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}
template <uint32_t kNumTMAMulticast = 1> template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void __device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,

View File

@ -48,6 +48,6 @@ __device__ __host__ constexpr T ceil_div(T a, T b) {
} }
template <typename T> template <typename T>
__device__ __host__ constexpr T gcd(T a, T b) { __device__ __host__ constexpr T constexpr_gcd(T a, T b) {
return b == 0 ? a : gcd(b, a % b); return b == 0 ? a : constexpr_gcd(b, a % b);
} }

View File

@ -96,7 +96,8 @@ def put(path, data, is_binary=False):
def build(name: str, arg_defs: tuple, code: str) -> Runtime: def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Compiler flags # Compiler flags
nvcc_flags = ['-std=c++20', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a', '-gencode=arch=compute_90a,code=sm_90a',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases

View File

@ -14,22 +14,24 @@ using namespace deep_gemm;
constexpr auto N = {N}, K = {K}; constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated GEMM // Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>; using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
// Launch kernel // Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr, gemm_t::run(out, rhs_scales, nullptr,
m, m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size); stream, num_sms, smem_size);
""" """
@ -39,8 +41,16 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def get_block_n_padding_for_smem_d(block_n: int) -> int:
elem_size, requirement = 2, (4, 8)
bank_stride = (block_n * elem_size) // 4
padding = (requirement[0] - bank_stride) % requirement[1]
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
smem_d = block_m * block_n * 2 block_n_padding = get_block_n_padding_for_smem_d(block_n)
smem_d = block_m * (block_n + block_n_padding) * 2
smem_a_per_stage = block_m * block_k smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4 smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k smem_b_per_stage = block_n * block_k
@ -91,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Always pick the longest one # Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced # NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
stage_candidates = (8, 7, 6, 5, 4) stage_candidates = (8, 7, 6, 5, 4, 3)
if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
# Unrolling both stages and `num_former_iters` will cause large code size # Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, ) stage_candidates = (4, 3)
for num_stages in stage_candidates: for num_stages in stage_candidates:
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity: if best_smem_size <= sm90_capacity:
@ -119,7 +129,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# NOTES: less L2 cache usage and less GPU frequency drop # NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n) num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0] num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms assert num_min_sms <= num_sms
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
@ -177,6 +187,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
runtime = jit_tuner.compile_and_tune( runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt', name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
'NUM_STAGES': num_stages, 'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0], 'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},

View File

@ -1,7 +1,7 @@
import torch import torch
from typing import Tuple from typing import Tuple
from .gemm import get_best_configs from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .tuner import jit_tuner from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@ -14,22 +14,25 @@ using namespace deep_gemm;
constexpr auto N = {N}, K = {K}; constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kNumGroups = {NUM_GROUPS};
constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated grouped GEMM // Make a templated grouped GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>; using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
// Launch kernel // Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, grouped_layout, gemm_t::run(out, rhs_scales, grouped_layout,
m, m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size); stream, num_sms, smem_size);
""" """
@ -91,7 +94,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
torch.cuda.current_stream(), num_sms, smem_size) torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune( runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt', name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0], 'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
@ -172,7 +177,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
torch.cuda.current_stream(), num_sms, smem_size) torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune( runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt', name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0], 'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],