Use swizzling instead of padding (#86)

* Add swizzling params

* Add TMA D descriptor

* Always use STSMx2

* Swizzling draft

* Compatible with padding

* Fix bugs

* Optimize swizzle performance

* Optimize expression

* Optimize TMA issues

* Fix README

* Stricter assertions
This commit is contained in:
Chenggang Zhao
2025-04-14 15:20:58 +08:00
committed by GitHub
parent 2e7e58011b
commit 37aa127451
6 changed files with 147 additions and 104 deletions

View File

@@ -45,6 +45,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kSwizzleDMode,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
@@ -54,7 +55,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a) {
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -63,6 +65,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
@@ -86,6 +89,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
// `tensor_map_d` is only used in swizzling mode
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
if constexpr (kSwizzleDMode > 0)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
}
__syncwarp();
@@ -345,6 +353,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
}
}, num_former_iters);
// TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
"Swizzling and padding are not compatible");
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
@@ -352,38 +371,65 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr int kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling, just padding
// NOTES: padding must be zero for BF16 output
DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
// Issue TMA store
cute::tma_store_fence();
if (lane_idx < 16) {
uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
}
__syncwarp();
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Wait TMA to be finished
cute::tma_store_arrive();
cute::tma_store_wait<0>();
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
// Wait TMA to be finished
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
}
}
#else
@@ -395,6 +441,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kSwizzleDMode,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
@@ -410,6 +457,7 @@ public:
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
@@ -418,6 +466,7 @@ public:
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
BLOCK_M, BLOCK_N, BLOCK_K,
BLOCK_N_PADDING,
kSwizzleDMode,
kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
@@ -442,7 +491,7 @@ public:
auto status = cudaLaunchKernelEx(&config, kernel,
gmem_d, scales_b, grouped_layout,
shape_m,
tma_a_desc, tma_b_desc, tma_scales_a_desc);
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
@@ -458,6 +507,21 @@ public:
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
swizzle_mode);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
// Make TMA aligned to 16 bytes