From 97575bf1c6009f7e8d1ae6388aeafcd8f57cf73a Mon Sep 17 00:00:00 2001 From: sazc Date: Tue, 8 Apr 2025 17:42:23 +0800 Subject: [PATCH 1/9] Performance: BlockTile 256x128 optimizations enable 1500+ TFLOPS FP8 performance on the H800-SXM platform --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 81 ++++++++++++++++++--- deep_gemm/include/deep_gemm/mma_utils.cuh | 89 +++++++++++++++++++++++ deep_gemm/jit_kernels/gemm.py | 10 ++- 3 files changed, 168 insertions(+), 12 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 41e563e..22a36dd 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -257,7 +257,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum*2] = {0}; // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { @@ -306,9 +306,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, warpgroup_fence_operand(accum[i]); warpgroup_wait<0>(); - // Notify barrier arrival - empty_barrier_arrive(s); - // Promote with scales // NOTES: making it as predicates is very important for performance, comparing to two loops float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; @@ -325,6 +322,48 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; } + if constexpr (BLOCK_M == 256) { + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_2 = ld_shared(smem_scales_a[s] + r_0 + 2 * WGMMA::M), scale_a_3 = ld_shared(smem_scales_a[s] + r_1 + 2 * WGMMA::M); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_2_0 = scale_a_2 * scale_b_0, scale_3_0 = scale_a_3 * scale_b_0; + float scale_2_1, scale_3_1; + if constexpr (not kMustUseUniformedScaleB) + scale_2_1 = scale_a_2 * scale_b_1, scale_3_1 = scale_a_3 * scale_b_1; + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K + 2 * WGMMA::M * BLOCK_K , 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or (i + WGMMA::kNumAccum / 4) < num_former_iters; + final_accum[i * 4 + 0 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 3]; + } + } } // Wait unaligned cases @@ -347,12 +386,34 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) ); } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 - ); + if constexpr (BLOCK_M == 256) { + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0 + WGMMA::kNumAccum], final_accum[i * 8 + 1 + WGMMA::kNumAccum]}), + __float22bfloat162_rn({final_accum[i * 8 + 2 + WGMMA::kNumAccum], final_accum[i * 8 + 3 + WGMMA::kNumAccum]}), + __float22bfloat162_rn({final_accum[i * 8 + 4 + WGMMA::kNumAccum], final_accum[i * 8 + 5 + WGMMA::kNumAccum]}), + __float22bfloat162_rn({final_accum[i * 8 + 6 + WGMMA::kNumAccum], final_accum[i * 8 + 7 + WGMMA::kNumAccum]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + BLOCK_M / 2 * (BLOCK_N + BLOCK_N_PADDING) + ); + } + } + if constexpr (BLOCK_M == 256) { + if constexpr (WGMMA::kNumAccum * 2 % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 0], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 2], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum * 2 / 8 * 16 + ); + } + } else { + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 + ); + } } cute::tma_store_fence(); cutlass::arch::NamedBarrier(kNumMathThreads).sync(); diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index 0cc554a..c57a609 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -866,6 +866,94 @@ struct SM90_64x192x32_F32E4M3E4M3_SS { static constexpr int K = 32; static constexpr int kNumAccum = M * N / 128; }; +struct SM90_64x256x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, + float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, + float& d96, float& d97, float& d98, float& d99, float& d100,float& d101,float& d102,float& d103, + float& d104,float& d105,float& d106,float& d107,float& d108,float& d109,float& d110,float& d111, + float& d112,float& d113,float& d114,float& d115,float& d116,float& d117,float& d118,float& d119, + float& d120,float& d121,float& d122,float& d123,float& d124,float& d125,float& d126,float& d127, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95), + "+f"(d96), "+f"(d97), "+f"(d98), "+f"(d99), "+f"(d100),"+f"(d101),"+f"(d102),"+f"(d103), + "+f"(d104),"+f"(d105),"+f"(d106),"+f"(d107),"+f"(d108),"+f"(d109),"+f"(d110),"+f"(d111), + "+f"(d112),"+f"(d113),"+f"(d114),"+f"(d115),"+f"(d116),"+f"(d117),"+f"(d118),"+f"(d119), + "+f"(d120),"+f"(d121),"+f"(d122),"+f"(d123),"+f"(d124),"+f"(d125),"+f"(d126),"+f"(d127) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], + d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], + d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], + d[96], d[97], d[98], d[99], d[100],d[101],d[102],d[103], + d[104],d[105],d[106],d[107],d[108],d[109],d[110],d[111], + d[112],d[113],d[114],d[115],d[116],d[117],d[118],d[119], + d[120],d[121],d[122],d[123],d[124],d[125],d[126],d[127], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 256; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; template struct SM90_U32x2_STSM_N { @@ -1008,6 +1096,7 @@ struct FP8MMASelector { if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); + if constexpr (N == 256) return SM90_64x256x32_F32E4M3E4M3_SS(); } using type = decltype(select_type()); diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 3f031c1..9c6e0b2 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -74,10 +74,16 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: # TODO: for some cases, smaller M block is better, add them into tuning space - block_ms = (64 if m <= 64 else 128, ) + # block_ms = (64 if m <= 64 else 128, ) + if m <= 64: + block_ms = (64, ) + elif m <= 128: + block_ms = (64, 128, ) + else: + block_ms = (64, 128, 256, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + (144, 160, ) + block_ns = tuple(range(16, 129, 8)) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) From ce65d5e33c82185c6d13c27ab55b4b61b9c5c72c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 09:32:46 +0800 Subject: [PATCH 2/9] Remove unused x256 WGMMA --- deep_gemm/include/deep_gemm/mma_utils.cuh | 89 ----------------------- 1 file changed, 89 deletions(-) diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index c57a609..0cc554a 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -866,94 +866,6 @@ struct SM90_64x192x32_F32E4M3E4M3_SS { static constexpr int K = 32; static constexpr int kNumAccum = M * N / 128; }; -struct SM90_64x256x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, - float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, - float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, - float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, - float& d96, float& d97, float& d98, float& d99, float& d100,float& d101,float& d102,float& d103, - float& d104,float& d105,float& d106,float& d107,float& d108,float& d109,float& d110,float& d111, - float& d112,float& d113,float& d114,float& d115,float& d116,float& d117,float& d118,float& d119, - float& d120,float& d121,float& d122,float& d123,float& d124,float& d125,float& d126,float& d127, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95), - "+f"(d96), "+f"(d97), "+f"(d98), "+f"(d99), "+f"(d100),"+f"(d101),"+f"(d102),"+f"(d103), - "+f"(d104),"+f"(d105),"+f"(d106),"+f"(d107),"+f"(d108),"+f"(d109),"+f"(d110),"+f"(d111), - "+f"(d112),"+f"(d113),"+f"(d114),"+f"(d115),"+f"(d116),"+f"(d117),"+f"(d118),"+f"(d119), - "+f"(d120),"+f"(d121),"+f"(d122),"+f"(d123),"+f"(d124),"+f"(d125),"+f"(d126),"+f"(d127) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], - d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], - d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], - d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], - d[96], d[97], d[98], d[99], d[100],d[101],d[102],d[103], - d[104],d[105],d[106],d[107],d[108],d[109],d[110],d[111], - d[112],d[113],d[114],d[115],d[116],d[117],d[118],d[119], - d[120],d[121],d[122],d[123],d[124],d[125],d[126],d[127], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 256; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; template struct SM90_U32x2_STSM_N { @@ -1096,7 +1008,6 @@ struct FP8MMASelector { if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); - if constexpr (N == 256) return SM90_64x256x32_F32E4M3E4M3_SS(); } using type = decltype(select_type()); From 48a5f071beca106a636e82f502ce7cd4d9201220 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 10:01:15 +0800 Subject: [PATCH 3/9] Clean up config heuristics --- deep_gemm/jit_kernels/gemm.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 9c6e0b2..a57d348 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -73,14 +73,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - # block_ms = (64 if m <= 64 else 128, ) - if m <= 64: - block_ms = (64, ) - elif m <= 128: - block_ms = (64, 128, ) - else: - block_ms = (64, 128, 256, ) + block_ms = (64, 128, 256) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) @@ -103,7 +96,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) + success = util > best_util + if util == best_util: + # Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n + # Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m + # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better + success |= block_m != best_block_m and block_n > best_block_n best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None From a6524d411a0eb6a45b715d7b562eab565ec2d304 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 10:11:43 +0800 Subject: [PATCH 4/9] Larger block N candidates --- deep_gemm/jit_kernels/gemm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index a57d348..a4407e7 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -76,7 +76,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64, 128, 256) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + block_ns = tuple(range(16, 129, 8)) + (144, 160, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -85,7 +85,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Decide block sizes by waves best_block_m, best_block_n = None, None for block_m in block_ms: - for block_n in block_ns: + # NOTES: the block sizes can not be too large, so at least one dim less than 128 + for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: From 4c0cc290c7114725bceaac9d2fd1518cde5324ed Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 10:50:44 +0800 Subject: [PATCH 5/9] Refactor M repetition with loops --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 191 +++++++++-------------- 1 file changed, 75 insertions(+), 116 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 22a36dd..5c73198 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -21,10 +21,14 @@ enum class Layout { ColMajor }; +__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { + return block_m == 64 ? 1 : 2; +} + template __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; + return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; } template @@ -257,7 +261,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum*2] = {0}; + constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { @@ -285,85 +291,55 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait TMA arrivals full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } - if constexpr (BLOCK_M == 256) { - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_2 = ld_shared(smem_scales_a[s] + r_0 + 2 * WGMMA::M), scale_a_3 = ld_shared(smem_scales_a[s] + r_1 + 2 * WGMMA::M); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_2_0 = scale_a_2 * scale_b_0, scale_3_0 = scale_a_3 * scale_b_0; - float scale_2_1, scale_3_1; - if constexpr (not kMustUseUniformedScaleB) - scale_2_1 = scale_a_2 * scale_b_1, scale_3_1 = scale_a_3 * scale_b_1; - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K + 2 * WGMMA::M * BLOCK_K , 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or (i + WGMMA::kNumAccum / 4) < num_former_iters; - final_accum[i * 4 + 0 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 3]; - } - } } // Wait unaligned cases @@ -377,43 +353,26 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (BLOCK_M == 256) { - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0 + WGMMA::kNumAccum], final_accum[i * 8 + 1 + WGMMA::kNumAccum]}), - __float22bfloat162_rn({final_accum[i * 8 + 2 + WGMMA::kNumAccum], final_accum[i * 8 + 3 + WGMMA::kNumAccum]}), - __float22bfloat162_rn({final_accum[i * 8 + 4 + WGMMA::kNumAccum], final_accum[i * 8 + 5 + WGMMA::kNumAccum]}), - __float22bfloat162_rn({final_accum[i * 8 + 6 + WGMMA::kNumAccum], final_accum[i * 8 + 7 + WGMMA::kNumAccum]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + BLOCK_M / 2 * (BLOCK_N + BLOCK_N_PADDING) - ); - } - } - if constexpr (BLOCK_M == 256) { - if constexpr (WGMMA::kNumAccum * 2 % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 0], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 2], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum * 2 / 8 * 16 - ); - } - } else { - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 - ); - } + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 + ); + } } cute::tma_store_fence(); cutlass::arch::NamedBarrier(kNumMathThreads).sync(); From bdca8b06242c19278d4844301a997602397958fe Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 10:59:07 +0800 Subject: [PATCH 6/9] Fix indent --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 44 ++++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 5c73198..2e2e87d 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -304,7 +304,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Commit WGMMA instructions #pragma unroll for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + warpgroup_fence_operand(accum[i]); warpgroup_arrive(); #pragma unroll for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { @@ -315,29 +315,29 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, warpgroup_commit_batch(); #pragma unroll for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + warpgroup_fence_operand(accum[i]); warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); + empty_barrier_arrive(s); // Promote with scales // NOTES: making it as predicates is very important for performance, comparing to two loops float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; float scale_0_1, scale_1_1; if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; } } } @@ -356,22 +356,22 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { auto m_offset = local_idx * WAVE_BLOCK_M; auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll + #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}), - smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) - ); + __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + ); } if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 - ); + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 + ); } } cute::tma_store_fence(); From 5a80e4bb96167b62bbb6527157e353df07a87caa Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 11:00:10 +0800 Subject: [PATCH 7/9] Fix indent x2 --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 2e2e87d..2523435 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -308,9 +308,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, warpgroup_arrive(); #pragma unroll for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); } warpgroup_commit_batch(); #pragma unroll @@ -358,7 +358,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( + SM90_U32x4_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), From a9967bc27cdcead54e8d11e1736644eb51d49b33 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 11:14:45 +0800 Subject: [PATCH 8/9] Update README --- README.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a55311e..366cf69 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. +## News + +- 2025.04.09: DeepGEMM now achieves up to **1520 TFLOPS** on H800! See #74, #78, and #81 for details. + ## Performance We test all shapes potentially used in DeepSeek-V3/R1 inference (including both prefilling and decoding, but without tensor parallelism) on H800 SXM5 with NVCC 12.8. All speedup metrics are calculated in comparison to our internally and carefully optimized implementation based on CUTLASS 3.6. @@ -28,11 +32,11 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed | 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x | | 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x | | 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x | -| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x | -| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x | -| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x | -| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x | -| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x | +| 4096 | 2112 | 7168 | 1127 TFLOPS | 562 GB/s | 1.2x | +| 4096 | 24576 | 1536 | 1212 TFLOPS | 962 GB/s | 1.2x | +| 4096 | 32768 | 512 | 775 TFLOPS | 1620 GB/s | 1.2x | +| 4096 | 7168 | 16384 | 1520 TFLOPS | 384 GB/s | 1.4x | +| 4096 | 4096 | 7168 | 1410 TFLOPS | 541 GB/s | 1.3x | | 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x | ### Grouped GEMMs for MoE models (contiguous layout) From 989c9e3694638dd02aa36b4d8cb42339305caa85 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 11:17:47 +0800 Subject: [PATCH 9/9] Update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 366cf69..7be96a7 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## News -- 2025.04.09: DeepGEMM now achieves up to **1520 TFLOPS** on H800! See #74, #78, and #81 for details. +- 2025.04.09: DeepGEMM now achieves up to **1520 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), and [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81) for details. ## Performance @@ -164,6 +164,8 @@ The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups +- Larger block sizes +- Less bank conflicts via 3D TMA 🐳 - Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳 #### A unified and optimized block scheduler