Solve STSM bank conflict via padding and 3D TMA

This commit is contained in:
Chenggang Zhao
2025-04-03 15:39:35 +08:00
parent c57699ac93
commit 6db7e1863b
6 changed files with 119 additions and 51 deletions

View File

@@ -40,6 +40,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
@@ -50,11 +51,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
const __grid_constant__ std::pair<CUtensorMap, CUtensorMap> tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@@ -62,7 +63,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
@@ -82,7 +83,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
if constexpr (SHAPE_N >= BLOCK_N)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.first));
if constexpr (SHAPE_N % BLOCK_N != 0)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.second));
}
__syncwarp();
@@ -141,8 +145,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
@@ -340,14 +344,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
@@ -355,8 +359,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
if (n_block_idx < SHAPE_N / BLOCK_N) {
// Except the last unaligned block
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
} else {
// The last unaligned block
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
}
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
@@ -371,6 +382,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
@@ -386,14 +398,17 @@ public:
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc,
const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
BLOCK_M, BLOCK_N, BLOCK_K,
BLOCK_N_PADDING,
kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
@@ -433,11 +448,26 @@ public:
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
min(BLOCK_M, shape_m), BLOCK_N,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
// NOTES: must be row-major
auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
// `SHAPE_N % BLOCK_N` maybe not zero, dividing them into two parts
CUtensorMap aligned;
if constexpr (SHAPE_N >= BLOCK_N) {
uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
}
CUtensorMap unaligned;
if constexpr (SHAPE_N % BLOCK_N != 0) {
uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
gmem_dim, gmem_strides, smem_dim);
}
return {aligned, unaligned};
}
template <typename T>

View File

@@ -63,16 +63,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
uint64_t global_stride[rank - 1] = {stride_in_bytes};
uint32_t elem_strides[rank] = {1, 1};
CUtensorMap tensor_map = {};
uint64_t global_stride[1] = {stride_in_bytes};
uint32_t elem_strides[2] = {1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
@@ -81,6 +80,27 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
return tensor_map;
}
template <typename T>
CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
uint64_t gmem_strides[2], uint32_t smem_dim[3],
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map = {};
uint32_t elem_strides[3] = {1, 1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 3,
global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}
template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,

View File

@@ -48,6 +48,6 @@ __device__ __host__ constexpr T ceil_div(T a, T b) {
}
template <typename T>
__device__ __host__ constexpr T gcd(T a, T b) {
return b == 0 ? a : gcd(b, a % b);
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
return b == 0 ? a : constexpr_gcd(b, a % b);
}