diff --git a/README.md b/README.md index a55311e..7be96a7 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. +## News + +- 2025.04.09: DeepGEMM now achieves up to **1520 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), and [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81) for details. + ## Performance We test all shapes potentially used in DeepSeek-V3/R1 inference (including both prefilling and decoding, but without tensor parallelism) on H800 SXM5 with NVCC 12.8. All speedup metrics are calculated in comparison to our internally and carefully optimized implementation based on CUTLASS 3.6. @@ -28,11 +32,11 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed | 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x | | 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x | | 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x | -| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x | -| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x | -| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x | -| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x | -| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x | +| 4096 | 2112 | 7168 | 1127 TFLOPS | 562 GB/s | 1.2x | +| 4096 | 24576 | 1536 | 1212 TFLOPS | 962 GB/s | 1.2x | +| 4096 | 32768 | 512 | 775 TFLOPS | 1620 GB/s | 1.2x | +| 4096 | 7168 | 16384 | 1520 TFLOPS | 384 GB/s | 1.4x | +| 4096 | 4096 | 7168 | 1410 TFLOPS | 541 GB/s | 1.3x | | 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x | ### Grouped GEMMs for MoE models (contiguous layout) @@ -160,6 +164,8 @@ The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups +- Larger block sizes +- Less bank conflicts via 3D TMA 🐳 - Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳 #### A unified and optimized block scheduler diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 41e563e..2523435 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -21,10 +21,14 @@ enum class Layout { ColMajor }; +__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { + return block_m == 64 ? 1 : 2; +} + template __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; + return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; } template @@ -257,7 +261,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { @@ -285,45 +291,54 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait TMA arrivals full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions + // TODO: remove some useless computation for unaligned Ms #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; - // Notify barrier arrival - empty_barrier_arrive(s); + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } } @@ -338,21 +353,26 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 - ); + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), + __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 + ); + } } cute::tma_store_fence(); cutlass::arch::NamedBarrier(kNumMathThreads).sync(); diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 3f031c1..a4407e7 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -73,8 +73,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - block_ms = (64 if m <= 64 else 128, ) + block_ms = (64, 128, 256) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) + (144, 160, ) @@ -86,7 +85,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Decide block sizes by waves best_block_m, best_block_n = None, None for block_m in block_ms: - for block_n in block_ns: + # NOTES: the block sizes can not be too large, so at least one dim less than 128 + for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: @@ -97,7 +97,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) + success = util > best_util + if util == best_util: + # Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n + # Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m + # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better + success |= block_m != best_block_m and block_n > best_block_n best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None