Performance: BlockTile 256x128 optimizations enable 1500+ TFLOPS FP8 performance on the H800-SXM platform

This commit is contained in:
sazc
2025-04-08 17:42:23 +08:00
parent b4ecf9c3ff
commit 97575bf1c6
3 changed files with 168 additions and 12 deletions

View File

@@ -257,7 +257,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum*2] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
@@ -306,9 +306,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
@@ -325,6 +322,48 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
if constexpr (BLOCK_M == 256) {
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_2 = ld_shared(smem_scales_a[s] + r_0 + 2 * WGMMA::M), scale_a_3 = ld_shared(smem_scales_a[s] + r_1 + 2 * WGMMA::M);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_2_0 = scale_a_2 * scale_b_0, scale_3_0 = scale_a_3 * scale_b_0;
float scale_2_1, scale_3_1;
if constexpr (not kMustUseUniformedScaleB)
scale_2_1 = scale_a_2 * scale_b_1, scale_3_1 = scale_a_3 * scale_b_1;
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K + 2 * WGMMA::M * BLOCK_K , 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// #pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or (i + WGMMA::kNumAccum / 4) < num_former_iters;
final_accum[i * 4 + 0 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 3];
}
}
}
// Wait unaligned cases
@@ -347,12 +386,34 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
);
if constexpr (BLOCK_M == 256) {
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0 + WGMMA::kNumAccum], final_accum[i * 8 + 1 + WGMMA::kNumAccum]}),
__float22bfloat162_rn({final_accum[i * 8 + 2 + WGMMA::kNumAccum], final_accum[i * 8 + 3 + WGMMA::kNumAccum]}),
__float22bfloat162_rn({final_accum[i * 8 + 4 + WGMMA::kNumAccum], final_accum[i * 8 + 5 + WGMMA::kNumAccum]}),
__float22bfloat162_rn({final_accum[i * 8 + 6 + WGMMA::kNumAccum], final_accum[i * 8 + 7 + WGMMA::kNumAccum]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + BLOCK_M / 2 * (BLOCK_N + BLOCK_N_PADDING)
);
}
}
if constexpr (BLOCK_M == 256) {
if constexpr (WGMMA::kNumAccum * 2 % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 0], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 2], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum * 2 / 8 * 16
);
}
} else {
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
);
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();