From 97575bf1c6009f7e8d1ae6388aeafcd8f57cf73a Mon Sep 17 00:00:00 2001
From: sazc
Date: Tue, 8 Apr 2025 17:42:23 +0800
Subject: [PATCH] Performance: BlockTile 256x128 optimizations enable 1500+
 TFLOPS FP8 performance on the H800-SXM platform

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 84 ++++++++++++++++++----
 deep_gemm/include/deep_gemm/mma_utils.cuh | 90 +++++++++++++++++++++++
 deep_gemm/jit_kernels/gemm.py             |  9 ++-
 3 files changed, 172 insertions(+), 11 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 41e563e..22a36dd 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -257,7 +257,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
         // Accumulation for WGMMA or CUDA promotion
-        float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
+        // NOTES: `final_accum` is doubled so that BLOCK_M == 256 can hold both 128-row halves of the M block
+        float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * 2] = {0};
 
         // Empty barrier arrival
         auto empty_barrier_arrive = [&](int s) {
@@ -306,9 +307,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_wait<0>();
 
-                // Notify barrier arrival
-                empty_barrier_arrive(s);
+                // Notify barrier arrival
+                // NOTES: for BLOCK_M == 256, arrival is deferred until the second WGMMA pass on this stage has drained
+                if constexpr (BLOCK_M != 256)
+                    empty_barrier_arrive(s);
 
                 // Promote with scales
                 // NOTES: making it as predicates is very important for performance, comparing to two loops
                 float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
@@ -325,6 +328,48 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
                     final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                 }
+                if constexpr (BLOCK_M == 256) {
+                    // Read A scales for the second 128-row half of the M block
+                    // NOTES: all shared memory reads must be prior to `warpgroup_arrive` to avoid the next scheduled block polluting the results
+                    auto scale_a_2 = ld_shared(smem_scales_a[s] + r_0 + 2 * WGMMA::M), scale_a_3 = ld_shared(smem_scales_a[s] + r_1 + 2 * WGMMA::M);
+                    // Promote with scales
+                    // NOTES: making these predicated is very important for performance, compared to two separate loops
+                    float scale_2_0 = scale_a_2 * scale_b_0, scale_3_0 = scale_a_3 * scale_b_0;
+                    float scale_2_1, scale_3_1;
+                    if constexpr (not kMustUseUniformedScaleB)
+                        scale_2_1 = scale_a_2 * scale_b_1, scale_3_1 = scale_a_3 * scale_b_1;
+
+                    // Commit WGMMA instructions for the second half, reusing the `accum` registers
+                    #pragma unroll
+                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        warpgroup_fence_operand(accum[i]);
+                    warpgroup_arrive();
+                    #pragma unroll
+                    for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K + 2 * WGMMA::M * BLOCK_K, 1);
+                        auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
+                        WGMMA::wgmma(desc_a, desc_b, accum, k);
+                    }
+                    warpgroup_commit_batch();
+                    #pragma unroll
+                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        warpgroup_fence_operand(accum[i]);
+
+                    warpgroup_wait<0>();
+
+                    // Notify barrier arrival
+                    empty_barrier_arrive(s);
+
+                    #pragma unroll
+                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
+                        bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
+                        final_accum[i * 4 + 0 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 0];
+                        final_accum[i * 4 + 1 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 1];
+                        final_accum[i * 4 + 2 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 2];
+                        final_accum[i * 4 + 3 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 3];
+                    }
+                }
             }
 
             // Wait unaligned cases
@@ -347,12 +392,33 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
             );
         }
-        if constexpr (WGMMA::kNumAccum % 8 != 0) {
-            SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
-            );
+        if constexpr (BLOCK_M == 256) {
+            // Store the second 128-row half, offset by `BLOCK_M / 2` rows in shared memory
+            #pragma unroll
+            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
+                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({final_accum[i * 8 + 0 + WGMMA::kNumAccum], final_accum[i * 8 + 1 + WGMMA::kNumAccum]}),
+                    __float22bfloat162_rn({final_accum[i * 8 + 2 + WGMMA::kNumAccum], final_accum[i * 8 + 3 + WGMMA::kNumAccum]}),
+                    __float22bfloat162_rn({final_accum[i * 8 + 4 + WGMMA::kNumAccum], final_accum[i * 8 + 5 + WGMMA::kNumAccum]}),
+                    __float22bfloat162_rn({final_accum[i * 8 + 6 + WGMMA::kNumAccum], final_accum[i * 8 + 7 + WGMMA::kNumAccum]}),
+                    smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + BLOCK_M / 2 * (BLOCK_N + BLOCK_N_PADDING)
+                );
+            }
+            if constexpr (WGMMA::kNumAccum * 2 % 8 != 0) {
+                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 0], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 1]}),
+                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 2], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 3]}),
+                    smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum * 2 / 8 * 16
+                );
+            }
+        } else {
+            if constexpr (WGMMA::kNumAccum % 8 != 0) {
+                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                    smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
+                );
+            }
         }
         cute::tma_store_fence();
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();
diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index 0cc554a..c57a609 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -866,6 +866,95 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int K = 32;
     static constexpr int kNumAccum = M * N / 128;
 };
+
+struct SM90_64x256x32_F32E4M3E4M3_SS {
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
+                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
+                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
+                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
+                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
+                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
+                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
+                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
+                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
+                                 float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
+                                 float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
+                                 float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
+                                 float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
+                                 float& d96, float& d97, float& d98, float& d99, float& d100, float& d101, float& d102, float& d103,
+                                 float& d104, float& d105, float& d106, float& d107, float& d108, float& d109, float& d110, float& d111,
+                                 float& d112, float& d113, float& d114, float& d115, float& d116, float& d117, float& d118, float& d119,
+                                 float& d120, float& d121, float& d122, float& d123, float& d124, float& d125, float& d126, float& d127,
+                                 bool scale_d) {
+        asm volatile("{\n"
+                     ".reg .pred p;\n"
+                     "setp.ne.b32 p, %130, 0;\n"
+                     "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 "
+                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
+                     " %8, %9, %10, %11, %12, %13, %14, %15, "
+                     " %16, %17, %18, %19, %20, %21, %22, %23, "
+                     " %24, %25, %26, %27, %28, %29, %30, %31, "
+                     " %32, %33, %34, %35, %36, %37, %38, %39, "
+                     " %40, %41, %42, %43, %44, %45, %46, %47, "
+                     " %48, %49, %50, %51, %52, %53, %54, %55, "
+                     " %56, %57, %58, %59, %60, %61, %62, %63, "
+                     " %64, %65, %66, %67, %68, %69, %70, %71, "
+                     " %72, %73, %74, %75, %76, %77, %78, %79, "
+                     " %80, %81, %82, %83, %84, %85, %86, %87, "
+                     " %88, %89, %90, %91, %92, %93, %94, %95, "
+                     " %96, %97, %98, %99, %100, %101, %102, %103, "
+                     " %104, %105, %106, %107, %108, %109, %110, %111, "
+                     " %112, %113, %114, %115, %116, %117, %118, %119, "
+                     " %120, %121, %122, %123, %124, %125, %126, %127},"
+                     " %128,"
+                     " %129,"
+                     " p, 1, 1;\n"
+                     "}\n"
+                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+                       "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+                       "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+                       "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
+                       "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95),
+                       "+f"(d96), "+f"(d97), "+f"(d98), "+f"(d99), "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+                       "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+                       "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
+                       "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
+                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
+    }
+
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
+        wgmma(desc_a, desc_b,
+              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
+              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
+              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
+              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
+              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
+              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
+              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
+              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
+              d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
+              d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
+              d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
+              d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
+              d[96], d[97], d[98], d[99], d[100], d[101], d[102], d[103],
+              d[104], d[105], d[106], d[107], d[108], d[109], d[110], d[111],
+              d[112], d[113], d[114], d[115], d[116], d[117], d[118], d[119],
+              d[120], d[121], d[122], d[123], d[124], d[125], d[126], d[127],
+              scale_d);
+    }
+
+    static constexpr int M = 64;
+    static constexpr int N = 256;
+    static constexpr int K = 32;
+    static constexpr int kNumAccum = M * N / 128;
+};
 
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
@@ -1008,6 +1097,7 @@ struct FP8MMASelector {
         if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS();
         if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS();
         if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
+        if constexpr (N == 256) return SM90_64x256x32_F32E4M3E4M3_SS();
     }
 
     using type = decltype(select_type());
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 3f031c1..9c6e0b2 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -74,10 +74,15 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
         Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
         # TODO: for some cases, smaller M block is better, add them into tuning space
-        block_ms = (64 if m <= 64 else 128, )
+        if m <= 64:
+            block_ms = (64, )
+        elif m <= 128:
+            block_ms = (64, 128, )
+        else:
+            block_ms = (64, 128, 256, )
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(16, 129, 8)) + (144, 160, )
+    block_ns = tuple(range(16, 129, 8))
 
     fix_wave_saturate = lambda x: num_sms if x == 0 else x
     get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
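-- 
Reviewer note (not part of the patch): the BLOCK_M == 256 path keeps both 128-row
halves of the M block in one doubled `final_accum`, reusing the same `accum`
register fragment for two back-to-back WGMMA passes per stage. The standalone C++
sketch below models only that promotion step on the CPU so the indexing can be
checked in isolation. It is illustrative, not the kernel: `promote_half`, the
scale values, and `num_former_iters = 16` are assumptions made for the example;
`kNumAccum` and `kMustUseUniformedScaleB` mirror the kernel's names.

// cpu_promotion_model.cc -- compile with any C++11 compiler.
#include <cstdio>

constexpr int kNumAccum = 128;                  // accumulators of one 64x256 WGMMA tile
constexpr bool kMustUseUniformedScaleB = false;

// Model of one K-step promotion: fresh partial products in `accum` are scaled
// (per-row A scales times per-128-column B scales) and added into the
// persistent `final_accum`. `half` = 0/1 selects the 128-row half of the
// M block; the second half lives at offset kNumAccum, as in the patch.
void promote_half(float* final_accum, const float* accum,
                  float scale_a_lo, float scale_a_hi,
                  float scale_b_0, float scale_b_1,
                  int num_former_iters, int half) {
    const float s_lo_0 = scale_a_lo * scale_b_0, s_lo_1 = scale_a_lo * scale_b_1;
    const float s_hi_0 = scale_a_hi * scale_b_0, s_hi_1 = scale_a_hi * scale_b_1;
    float* dst = final_accum + half * kNumAccum;
    for (int i = 0; i < kNumAccum / 4; ++i) {
        // Column groups before `num_former_iters` use scale_b_0, the rest use
        // scale_b_1. The predicate depends only on the column group, so it is
        // identical for both row halves.
        const bool p = kMustUseUniformedScaleB || i < num_former_iters;
        dst[i * 4 + 0] += (p ? s_lo_0 : s_lo_1) * accum[i * 4 + 0];
        dst[i * 4 + 1] += (p ? s_lo_0 : s_lo_1) * accum[i * 4 + 1];
        dst[i * 4 + 2] += (p ? s_hi_0 : s_hi_1) * accum[i * 4 + 2];
        dst[i * 4 + 3] += (p ? s_hi_0 : s_hi_1) * accum[i * 4 + 3];
    }
}

int main() {
    float final_accum[kNumAccum * 2] = {};      // doubled, as in the patch
    float accum[kNumAccum];
    for (int i = 0; i < kNumAccum; ++i) accum[i] = 1.0f;

    // One K-step for each 128-row half of the 256-row M block.
    promote_half(final_accum, accum, 0.5f, 0.25f, 2.0f, 4.0f, 16, 0);
    promote_half(final_accum, accum, 0.125f, 0.0625f, 2.0f, 4.0f, 16, 1);

    // First half, former column group: 0.5 * 2.0 = 1.0
    // Second half, former column group: 0.125 * 2.0 = 0.25
    std::printf("%g %g\n", final_accum[0], final_accum[kNumAccum]);
    return 0;
}

The point of doubling only `final_accum` is that the two WGMMA passes can reuse
the same `accum` register fragment; only the promoted FP32 results need the
2x storage, which keeps register pressure manageable at BLOCK_M = 256.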