From 46eb0d08fb41d86e2c0c2dae4d896ad078fc82d8 Mon Sep 17 00:00:00 2001 From: sazc Date: Tue, 25 Mar 2025 10:44:57 +0800 Subject: [PATCH 01/14] Performance: Larger BlockTile optimizations enable 1470+ TFLOPS FP8 performance on the H800-SXM platform --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 895 ++++++++++++++++++---- deep_gemm/include/deep_gemm/mma_utils.cuh | 131 ++++ deep_gemm/include/deep_gemm/scheduler.cuh | 92 +++ deep_gemm/jit_kernels/gemm.py | 62 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 12 +- 5 files changed, 1008 insertions(+), 184 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index ee6e4a4..85e13a9 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,7 +27,7 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } -template = 900)) or defined(__CLION_IDE__) + // Currently, only BLOCK_N size of 160 is classified as a largeBlockTile configuration in our optimization framework. + constexpr bool largeBlockTile = (BLOCK_N == 160); // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + if constexpr(!largeBlockTile) { + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + } // Types using WGMMA = typename FP8MMASelector::type; @@ -146,205 +150,760 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); + if constexpr (largeBlockTile) { + auto scheduler = SchedulerLargeBlockTile(SHAPE_N, grouped_layout); + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + // Issue TMA A without broadcasting + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B with broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + // To ensure optimal performance, conditional checks must never be placed inside the loop body of kNumInnerStages. 
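+                // A worked instance of the case split below, using the constants fixed by this
+                // largeBlockTile configuration (BLOCK_N == 160, BLOCK_K == 128, WGMMA::M == 64,
+                // hence WGMMA::kNumAccum == 64 * 160 / 128 == 80, i.e. 20 accumulator quads):
+                //   n_block_idx % 4 == 0 -> (n_block_idx * BLOCK_N) % BLOCK_K == 0  -> num_former_iters == 16
+                //   n_block_idx % 4 == 1 -> offset == 32                            -> num_former_iters == 12
+                //   n_block_idx % 4 == 2 -> offset == 64                            -> num_former_iters == 8
+                //   n_block_idx % 4 == 3 -> offset == 96                            -> num_former_iters == 4
+                // Each case therefore hard-codes its own split point (kNumAccum / 4 / 5 * {4, 3, 2, 1})
+                // between the first and second B-scale row, so no `i < num_former_iters` predicate has
+                // to be evaluated inside the kNumInnerStages loop.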
+ switch (n_block_idx % 4) { + case 0: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + // In performance-critical code paths, we rigorously avoid branching logic and unnecessary instructions. + // Here, we adopt a brute-force approach by explicitly unrolling computations for all possible scenarios. 
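+                            // Case 0: num_former_iters == 16 of the 20 accumulator quads, so the first
+                            // kNumAccum / 4 / 5 * 4 == 16 quads take the first B-scale row below and the
+                            // remaining 4 take the second.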
+ #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 4; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 1: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 3; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 3; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + 
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 2: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 2; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; 
+ } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 2; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 3: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 1; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 1; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + 
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + default: + + break; + } + } + } + } else { + auto scheduler = Scheduler(shape_m, grouped_layout); + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + // Issue TMA A with broadcasting + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; 
+ if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs launch_k_iterations([&](int k_iter, auto type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - // Issue TMA A with broadcasting - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + 
warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); } }); - } - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 
1 : 2); - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); } - }; - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - #pragma unroll - for (int i = 0; i < 
WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); + __syncwarp(); } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); } } + + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); #endif } -template ; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index b44bf95..b242261 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -665,6 +665,135 @@ struct SM90_64x128x32_F32E4M3E4M3_SS { static constexpr int kNumAccum = M * N / 128; }; + +struct SM90_64x144x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, 
float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}, " + " %72," + " %73," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 144; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + +struct SM90_64x160x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& 
d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}, " + " %80," + " %81," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 160; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + struct SM90_64x192x32_F32E4M3E4M3_SS { __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, @@ -876,6 +1005,8 @@ struct FP8MMASelector { if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); if constexpr (N == 128) return 
SM90_64x128x32_F32E4M3E4M3_SS(); + if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); + if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 329fbb0..6393713 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -98,6 +98,98 @@ struct Scheduler { return true; } }; +template +struct SchedulerLargeBlockTile { + int current_iter = -1; + uint32_t num_aligned_n_blocks; + + // For normal GEMM + // Maybe not used in the masked grouped GEMM + uint32_t num_blocks; + + // For grouped GEMM + int* grouped_layout; + // Only used for masked layout + uint32_t curr_group_idx, curr_cumsum; + + __device__ __forceinline__ explicit SchedulerLargeBlockTile(const uint32_t shape_n, + int* grouped_layout = nullptr) { + num_aligned_n_blocks = ceil_div(shape_n, BLOCK_N); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_aligned_n_blocks * kNumMBlocks; + } else if (kGemmType == GemmType::GroupedContiguous) { + num_blocks = num_aligned_n_blocks * kNumMBlocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::GroupedMasked) { + curr_group_idx = curr_cumsum = 0; + this->grouped_layout = grouped_layout; + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_n_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNumMBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_n_blocks * kNumMBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_m_block_idx = group_idx * kNumMBlocksPerGroup; + auto num_m_blocks_in_group = min(kNumMBlocksPerGroup, kNumMBlocks - first_m_block_idx); + // auto num_m_blocks_in_group = kNumMBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + n_block_idx = in_group_idx / num_m_blocks_in_group; + m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; + // if (threadIdx.x == 256) { + // printf("blockIdx.x: %d group_idx: %d in_group_idx: %d m_block_idx: %d n_block_idx: %d\n", blockIdx.x, group_idx, in_group_idx, m_block_idx, n_block_idx); + // } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx=0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if (kGemmType == GemmType::GroupedContiguous) { + auto offset = kIgnoreGroupedForGroupedContiguous ? 
0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); + return offset * shape_dim + block_idx * block_size; + } else if (kGemmType == GemmType::GroupedMasked) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; + + if constexpr (kGemmType == GemmType::GroupedMasked) { + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumMBlocks) + break; + + // Move to check the next group + curr_group_idx ++, curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumMBlocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + + get_swizzled_block_idx(num_aligned_n_blocks, next_block_idx, m_block_idx, n_block_idx); + } + return true; + } +}; #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index d97a615..903c1a0 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -10,14 +10,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; +constexpr auto M = {M}, N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -62,7 +62,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + # Current optimizations target large-scale Normal GEMMs for dense models and + # Grouped GEMMs for MoE models (contiguous memory layout), with a potential + # block_n tile size of 160 to enhance data reuse in block tiling. + if m >= 4096 and num_groups == 1: + block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160)) + else: + block_ns = tuple(range(16, 129, 8)) + fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -77,7 +84,16 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: - success = True + # The value 0.02 is currently an empirically estimated threshold to + # filter out cases unsuitable for large block tile configurations, + # with optimizations planned for later stages to address excluded scenarios. 
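+            # Concretely: a 160-wide tile that saves a wave is only taken when its padded tile work,
+            # num_waves * block_m * block_n, is at most ~2% larger than the incumbent best config's.
+            # Illustrative check (hypothetical wave counts; block_m == 128 as selected for m >= 4096):
+            #   going from 3 waves of 128x128 tiles to 2 waves of 128x160 tiles gives
+            #   (2*128*160 - 3*128*128) / (3*128*128) ~= -0.17 < 0.02, so block_n == 160 is accepted.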
+ if block_n == 160 and \ + (num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02: + success = True + elif block_n == 160: + success = False + else: + success = True elif num_waves == best_num_waves: # Check last wave utilization util = get_last_wave_util(block_m, block_n) @@ -89,24 +105,44 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + if best_block_n != 160: + for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) + if best_smem_size <= sm90_capacity: + best_num_stages = num_stages + break + else: + # NOTES: This is done to reduce the code footprint after unrolling. + # Additionally, if k does not meet the following conditions, a slight performance penalty will occur. + num_stages = 4 + assert k / 128 % num_stages == 0 best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: - best_num_stages = num_stages - break + assert best_smem_size <= sm90_capacity + best_num_stages = num_stages + assert best_num_stages is not None # Decide the number of TMA multicast best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 + # When using large block tiling, broadcasting B is required to achieve maximum performance gains. + if best_block_n != 160: + if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + else: + if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast - assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + if best_block_n != 160: + assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + else: + assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms) + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size @@ -162,7 +198,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, space=(), includes=includes, @@ -176,3 +212,5 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size diff 
--git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 415fc67..c3b26a7 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -11,14 +11,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; +constexpr auto M = {M}, N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated grouped GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, space=(), includes=includes, @@ -106,6 +106,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -169,7 +171,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, space=(), includes=includes, @@ -184,3 +186,5 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size From b922e64cb284bf258e71ae2cd86a6fb6ba00759e Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 13:37:59 +0800 Subject: [PATCH 02/14] Support block size 160 --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 903 +++++----------------- deep_gemm/include/deep_gemm/scheduler.cuh | 91 --- deep_gemm/jit_kernels/gemm.py | 61 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 12 +- 4 files changed, 188 insertions(+), 879 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 85e13a9..c2572a6 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,7 +27,7 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } -template = 900)) or defined(__CLION_IDE__) - // Currently, only BLOCK_N size of 160 is classified as a largeBlockTile configuration in our optimization framework. 
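// Why BLOCK_N = 160 needed its own path in the first place (a standalone restatement, not part of
// the patch): with BLOCK_K = 128, a 160-wide output tile straddles two 128-wide B-scale columns,
// which is exactly what the original per-block B-scale assertion rejected.
constexpr int kIllustrativeBlockN = 160, kIllustrativeBlockK = 128;
static_assert((kIllustrativeBlockN + kIllustrativeBlockK - 1) / kIllustrativeBlockK == 2,
              "a 160-wide tile overlaps two B-scale columns, so two B scales per k-step may be needed");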
- constexpr bool largeBlockTile = (BLOCK_N == 160); // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - if constexpr(!largeBlockTile) { - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); - } + // TODO: check `ceil_div(BLOCK_N, BLOCK_K) == 1` or `gcd(BLOCK_K, BLOCK_N) == BLOCK_N - BLOCK_K` // Types using WGMMA = typename FP8MMASelector::type; @@ -150,760 +146,205 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, grouped_layout); - if constexpr (largeBlockTile) { - auto scheduler = SchedulerLargeBlockTile(SHAPE_N, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - - // Issue TMA A without broadcasting - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - - // Issue TMA B with broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); - - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { // 
Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); - // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); - } - }; - // To ensure optimal performance, conditional checks must never be placed inside the loop body of kNumInnerStages. - switch (n_block_idx % 4) { - case 0: - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - // In performance-critical code paths, we rigorously avoid branching logic and unnecessary instructions. - // Here, we adopt a brute-force approach by explicitly unrolling computations for all possible scenarios. 
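// Where the 4/5, 3/5, 2/5, 1/5 split points in the four cases come from (a standalone sketch, not
// part of the patch, assuming kNumAccum == 80 for the m64n160k32 WGMMA): the column offset of an
// n-block inside the 128-wide B-scale grid is (n_block_idx * 160) % 128, which cycles through
// 0, 32, 64, 96 with period 4. Each group of 4 accumulators covers 8 output columns, so the number
// of groups promoted with the first B scale is (128 - offset) / 8, i.e. 16, 12, 8, 4 -- exactly
// kNumAccum / 4 / 5 * {4, 3, 2, 1}.
constexpr int kSketchNumAccum = 80, kSketchBlockN = 160, kSketchBlockK = 128;
constexpr int former_groups(int n_block_idx) {
    return (kSketchBlockK - (n_block_idx * kSketchBlockN) % kSketchBlockK) / 8;
}
static_assert(former_groups(0) == kSketchNumAccum / 4 / 5 * 4, "");
static_assert(former_groups(1) == kSketchNumAccum / 4 / 5 * 3, "");
static_assert(former_groups(2) == kSketchNumAccum / 4 / 5 * 2, "");
static_assert(former_groups(3) == kSketchNumAccum / 4 / 5 * 1, "");
// Hard-coding the four cases keeps the split point a compile-time constant inside the unrolled
// loop, which is what lets the compiler avoid per-iteration branching in the promotion step.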
- #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; - } - #pragma unroll - for (int i = WGMMA::kNumAccum / 4 / 5 * 4; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; - } - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); - break; - case 1: - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 3; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; - } - #pragma unroll - for (int i = WGMMA::kNumAccum / 4 / 5 * 3; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; - } - } - - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - 
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); - break; - case 2: - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 2; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; 
- } - #pragma unroll - for (int i = WGMMA::kNumAccum / 4 / 5 * 2; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; - } - } - - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); - break; - case 3: - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 1; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; - } - #pragma unroll - for (int i = WGMMA::kNumAccum / 4 / 5 * 1; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; - } - } - - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - 
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); - break; - default: - - break; - } - } - } - } else { - auto scheduler = Scheduler(shape_m, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - - // Issue TMA A with broadcasting - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); - - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; 
- if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); - - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); - } - }; - - // Launch MMAs launch_k_iterations([&](int k_iter, auto type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, 
desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); } // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); } }); + } - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); } } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 
kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + } } - - #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); #endif } -template ; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 6393713..c339b53 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -98,98 +98,7 @@ struct Scheduler { return true; } }; -template -struct SchedulerLargeBlockTile { - int current_iter = -1; - uint32_t num_aligned_n_blocks; - // For normal GEMM - // Maybe not used in the masked grouped GEMM - uint32_t num_blocks; - - // For grouped GEMM - int* grouped_layout; - // Only used for masked layout - uint32_t curr_group_idx, curr_cumsum; - - __device__ __forceinline__ explicit SchedulerLargeBlockTile(const uint32_t shape_n, - int* grouped_layout = nullptr) { - num_aligned_n_blocks = ceil_div(shape_n, BLOCK_N); - if constexpr (kGemmType == GemmType::Normal) { - num_blocks = num_aligned_n_blocks * kNumMBlocks; - } else if (kGemmType == GemmType::GroupedContiguous) { - num_blocks = num_aligned_n_blocks * kNumMBlocks; - this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::GroupedMasked) { - curr_group_idx = curr_cumsum = 0; - this->grouped_layout = grouped_layout; - } - } - - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_n_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNumMBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); - - // Swizzle for better L2 usages - auto num_blocks_per_group = num_n_blocks * kNumMBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_m_block_idx = group_idx * kNumMBlocksPerGroup; - auto num_m_blocks_in_group = 
min(kNumMBlocksPerGroup, kNumMBlocks - first_m_block_idx); - // auto num_m_blocks_in_group = kNumMBlocksPerGroup; - auto in_group_idx = block_idx % num_blocks_per_group; - n_block_idx = in_group_idx / num_m_blocks_in_group; - m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; - // if (threadIdx.x == 256) { - // printf("blockIdx.x: %d group_idx: %d in_group_idx: %d m_block_idx: %d n_block_idx: %d\n", blockIdx.x, group_idx, in_group_idx, m_block_idx, n_block_idx); - // } - } - - template - __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, - const uint32_t& block_idx, const uint32_t& m_block_idx=0) { - if constexpr (kGemmType == GemmType::Normal) { - return block_idx * block_size; - } else if (kGemmType == GemmType::GroupedContiguous) { - auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); - return offset * shape_dim + block_idx * block_size; - } else if (kGemmType == GemmType::GroupedMasked) { - return curr_group_idx * shape_dim + block_idx * block_size; - } - } - - __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; - - if constexpr (kGemmType == GemmType::GroupedMasked) { - uint32_t num_m_blocks; - while (true) { - // End of the task - if (curr_group_idx == kNumGroups) - return false; - - // Within current group - num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); - auto current_m_block_cumsum = curr_cumsum + num_m_blocks; - if (next_block_idx < current_m_block_cumsum * kNumMBlocks) - break; - - // Move to check the next group - curr_group_idx ++, curr_cumsum = current_m_block_cumsum; - } - - get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumMBlocks, m_block_idx, n_block_idx); - } else { - if (next_block_idx >= num_blocks) - return false; - - get_swizzled_block_idx(num_aligned_n_blocks, next_block_idx, m_block_idx, n_block_idx); - } - return true; - } -}; #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 903c1a0..6852d5e 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -10,14 +10,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto M = {M}, N = {N}, K = {K}; +constexpr auto N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -62,14 +62,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - # Current optimizations target large-scale Normal GEMMs for dense models and - # Grouped GEMMs for MoE models (contiguous memory layout), with a potential - # block_n tile size of 160 to enhance data reuse in block tiling. 
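# A rough sense of the "data reuse" claim above (an illustrative back-of-envelope sketch, not part
# of the patch): per 128-deep k-step, a BLOCK_M x BLOCK_N tile loads BLOCK_M*128 + BLOCK_N*128 FP8
# bytes of A and B to perform BLOCK_M * BLOCK_N * 128 MACs, ignoring scales and the output write.
def macs_per_byte(block_m, block_n, block_k=128):
    return (block_m * block_n * block_k) / (block_m * block_k + block_n * block_k)

print(macs_per_byte(128, 128))  # 64.0
print(macs_per_byte(128, 160))  # ~71.1, i.e. ~11% more math per byte of operand traffic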
- if m >= 4096 and num_groups == 1: - block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160)) - else: - block_ns = tuple(range(16, 129, 8)) - + block_ns = tuple(range(16, 129, 8)) + (160, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -84,16 +77,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: - # The value 0.02 is currently an empirically estimated threshold to - # filter out cases unsuitable for large block tile configurations, - # with optimizations planned for later stages to address excluded scenarios. - if block_n == 160 and \ - (num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02: - success = True - elif block_n == 160: - success = False - else: - success = True + success = True elif num_waves == best_num_waves: # Check last wave utilization util = get_last_wave_util(block_m, block_n) @@ -105,44 +89,25 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - if best_block_n != 160: - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): - best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: - best_num_stages = num_stages - break - else: - # NOTES: This is done to reduce the code footprint after unrolling. - # Additionally, if k does not meet the following conditions, a slight performance penalty will occur. - num_stages = 4 - assert k / 128 % num_stages == 0 + for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - assert best_smem_size <= sm90_capacity - best_num_stages = num_stages - + if best_smem_size <= sm90_capacity: + best_num_stages = num_stages + break assert best_num_stages is not None # Decide the number of TMA multicast best_num_tma_multicast = 1 # When using large block tiling, broadcasting B is required to achieve maximum performance gains. 
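# Intuition for the comment above (an illustrative sketch, not part of the patch): with a 2-SM TMA
# cluster and a 128 x 160 tile, B is the larger operand tile, so sharing it saves more bytes.
# Per 128-deep k-step and per cluster (FP8, 1 byte per element):
BM, BN, BK = 128, 160, 128
bytes_sharing_b = 2 * BM * BK + BN * BK  # two m-blocks share one broadcast B tile: 53,248 bytes
bytes_sharing_a = BM * BK + 2 * BN * BK  # two n-blocks share one broadcast A tile: 57,344 bytes
print(bytes_sharing_b, bytes_sharing_a)
# This is also why the legality check below switches from (n, best_block_n) to (m, best_block_m):
# the 2-wide cluster now spans two m-blocks of the same n-block.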
- if best_block_n != 160: - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 - else: - if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 - + if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast - if best_block_n != 160: - assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) - else: - assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms) - + assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size @@ -198,7 +163,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', - keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, space=(), includes=includes, @@ -212,5 +177,3 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Run the kernel runtime(*args) - # For debug - return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index c3b26a7..415fc67 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -11,14 +11,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto M = {M}, N = {N}, K = {K}; +constexpr auto N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated grouped GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, space=(), includes=includes, @@ -106,8 +106,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Run the kernel runtime(*args) - # For debug - return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -171,7 +169,7 @@ def 
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, space=(), includes=includes, @@ -186,5 +184,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Run the kernel runtime(*args) - # For debug - return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size From 742fb1c8a589e8f5ceffdcb3e71080dba380f516 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 13:41:28 +0800 Subject: [PATCH 03/14] Compilation-time GCD --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 2 +- deep_gemm/include/deep_gemm/utils.cuh | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index c2572a6..a041d40 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - // TODO: check `ceil_div(BLOCK_N, BLOCK_K) == 1` or `gcd(BLOCK_K, BLOCK_N) == BLOCK_N - BLOCK_K` + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types using WGMMA = typename FP8MMASelector::type; diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 0005907..fe2c016 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -46,3 +46,8 @@ template __device__ __host__ constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; } + +template +__device__ __host__ constexpr T gcd(T a, T b) { + return b == 0 ? 
a : gcd(b, a % b); +} From 7ffb118e5453898ef8daf6a9acf37be575db5156 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 14:56:42 +0800 Subject: [PATCH 04/14] Support multicasting on B --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 30 ++++++++++-------- deep_gemm/include/deep_gemm/scheduler.cuh | 32 +++++++++++++------ deep_gemm/jit/template.py | 5 ++- deep_gemm/jit_kernels/gemm.py | 38 ++++++++++++++--------- 4 files changed, 67 insertions(+), 38 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index a041d40..065bd88 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -31,7 +31,7 @@ template __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, @@ -146,7 +146,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); + auto scheduler = Scheduler(shape_m, grouped_layout); if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data @@ -161,6 +161,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + // Assign TMA multicast number into A and B + constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1; + constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast; + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant #pragma unroll @@ -168,18 +172,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait consumer release empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - // Issue TMA A with broadcasting + // Issue TMA A auto& full_barrier = *full_barriers[s]; int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); } @@ -347,7 +351,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, template class Gemm { private: @@ -369,7 +373,7 @@ public: constexpr uint32_t kNumMathThreadsPerGroup = 128; auto kernel = fp8_gemm_kernel; + kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); // Cluster launch diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index c339b53..6e3cb52 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -12,9 +12,10 @@ enum class GemmType { #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" template + uint32_t kNum1DBlocksPerGroup = 16> struct Scheduler { int current_iter = -1; uint32_t num_aligned_m_blocks; @@ -43,16 +44,27 @@ struct Scheduler { } __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); // Swizzle for better L2 usages - auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; - auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); - auto in_group_idx = block_idx % num_blocks_per_group; - m_block_idx = in_group_idx / num_n_blocks_in_group; - n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + // TODO: unify these 2 branches + if constexpr (kIsTMAMulticastOnA) { + auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_n_blocks_in_group = 
min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + } else { + auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; + n_block_idx = in_group_idx / num_m_blocks_in_group; + } } template diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index cdca4c4..ead37f5 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str: # We don't use `str.format` because it's not safe for C++ {} braces new_template = copy.deepcopy(template) for key, value in keys.items(): - new_template = new_template.replace(f'{{{key}}}', f'{value}') + value_str = str(value) + if isinstance(value, bool): + value_str = value_str.lower() + new_template = new_template.replace(f'{{{key}}}', f'{value_str}') return new_template diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 6852d5e..d0fd8f5 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -15,9 +15,10 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -31,10 +32,10 @@ GemmType::run(out, rhs_scales, nullptr, """ -def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: +def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool: if num_tma_multicast == 1: return True - return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: @@ -56,7 +57,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]: + is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: # TODO: for some cases, smaller M block is better, add them into tuning space block_ms = (64 if m <= 64 else 128, ) @@ -96,20 +97,27 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, break assert best_num_stages is not None - # Decide the number of TMA multicast - best_num_tma_multicast = 1 - # When using large block tiling, broadcasting B is required to achieve maximum performance gains. 
- if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 + # Decide the number of TMA multicast and whether broadcast on A + best_tma_multicast_config = (1, True) + + # Try to multicast on the larger block side first + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms), + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if m >= 1024 and is_multicast_legal[i] and num_groups == 1: + best_tma_multicast_config = (2, i == 'A') + break # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast - assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms - return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -159,12 +167,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms) args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), From 3497428a5e509644e476f19c52f836b426861aa1 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 15:16:26 +0800 Subject: [PATCH 05/14] Minor fix --- deep_gemm/jit_kernels/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index d0fd8f5..ff20b45 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -63,7 +63,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + (160, ) + block_ns = tuple(range(16, 129, 8)) + (144, 160, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) From 7768319ffe68eaad45c0b507703f6a1a5a9607a4 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 16:32:40 +0800 Subject: [PATCH 06/14] Remove unaligned predicates --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 
61 ++++++++++++++++-------- deep_gemm/jit/compiler.py | 2 +- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 065bd88..0611c5c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,6 +27,17 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } +template +__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { + if (num_former_iters == kNumFormerIters) { + inner_launch_k_iterations(func, cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); +} + template ([](const auto& func, auto num_former_iters_type) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type); + } + }, func, num_former_iters); }; // Register reconfigurations @@ -156,7 +171,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if (threadIdx.x == kNumMathThreads) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto _) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -193,7 +208,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); full_barriers[s]->arrive(); } - }); + }, 0); } // To safely deconstruct distributed shared barriers, we need another round of empty waits @@ -246,7 +261,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, }; // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -292,13 +307,21 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, float scale_0_1, scale_1_1; if constexpr (not kMustUseUniformedScaleB) scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value; #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + for (int i = 0; i < kNumFormerIters; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; } } @@ -308,7 +331,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); empty_barrier_arrive(s); } - }); + }, num_former_iters); // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 0f099d8..aad8939 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime: '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases '--diag-suppress=177,174,940'] - cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi'] + cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] include_dirs = [get_jit_include_dir()] From 046fab64b775966cd027f9aca774484c2241ba99 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 16:41:44 +0800 Subject: [PATCH 07/14] Fix grouped GEMM cases --- deep_gemm/jit_kernels/m_grouped_gemm.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 415fc67..bffe137 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated grouped GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, - is_grouped_contiguous=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) args = (lhs, lhs_scales, rhs, rhs_scales, out, m_indices, m, num_groups, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, + 'NUM_STAGES': num_stages, + 
'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': 'GroupedContiguous'}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), @@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) # Extra checks for TMA store if num_groups > 1 and m > block_m: @@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': 'GroupedMasked'}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), From 612dd57001878793ad8818797c6f84383f967e6c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 16:45:20 +0800 Subject: [PATCH 08/14] Simplify code --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 0611c5c..ba1b90c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -308,20 +308,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if constexpr (not kMustUseUniformedScaleB) scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value; #pragma unroll - for (int i = 0; i < kNumFormerIters; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; - } - #pragma unroll - for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; } } From 9c4f6f53f57d9c00f6a32784bff719edca5039ab Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 16:51:21 +0800 Subject: [PATCH 09/14] Optimize compilation speed --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index ba1b90c..625e178 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } -template +template __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { - if (num_former_iters == kNumFormerIters) { + if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) { inner_launch_k_iterations(func, cute::Int{}); return; } - if constexpr (kNumFormerIters + kGap <= kEnd) - outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); + if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); } template ([](const auto& func, auto num_former_iters_type) { + outer_launch_k_iterations([](const auto& func, auto num_former_iters_type) { if constexpr (SHAPE_K % kFullKOfAllStages == 0) { for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) func(k_iter, DivisibleK{}, num_former_iters_type); @@ -324,7 +324,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); empty_barrier_arrive(s); } - }, num_former_iters); + }, kMustUseUniformedScaleB ? 0 : num_former_iters); // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); From ddccb230ca773179d3899abab54ff0fe54b36c26 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 17:12:51 +0800 Subject: [PATCH 10/14] Fix NVCC branch divergence --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 625e178..fdcf5a1 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } -template +template __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { - if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) { + if (num_former_iters == kNumFormerIters) { inner_launch_k_iterations(func, cute::Int{}); return; } - if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB) - outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); } template ([](const auto& func, auto num_former_iters_type) { + constexpr int kEnd = kShouldOptimize ? 
BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) { if constexpr (SHAPE_K % kFullKOfAllStages == 0) { for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) func(k_iter, DivisibleK{}, num_former_iters_type); @@ -152,7 +156,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, func(k_iter, DivisibleK{}, num_former_iters_type); func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type); } - }, func, num_former_iters); + }, func, kShouldOptimize ? num_former_iters : 0); }; // Register reconfigurations @@ -310,7 +314,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value; + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; @@ -324,7 +328,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); empty_barrier_arrive(s); } - }, kMustUseUniformedScaleB ? 0 : num_former_iters); + }, num_former_iters); // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); From 1999d553e512740b0bf99e2564199c18839683a3 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 17:18:53 +0800 Subject: [PATCH 11/14] Lower TMA requirement --- deep_gemm/jit_kernels/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index ff20b45..31d1a2e 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -106,7 +106,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms), } for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): - if m >= 1024 and is_multicast_legal[i] and num_groups == 1: + if m >= 512 and is_multicast_legal[i] and num_groups == 1: best_tma_multicast_config = (2, i == 'A') break From 25db8de3454800bd40b93860b99c303570d09632 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 17:34:06 +0800 Subject: [PATCH 12/14] Better performance --- deep_gemm/jit_kernels/gemm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 31d1a2e..65b44ff 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -1,3 +1,4 @@ +import math import torch from typing import Tuple @@ -90,7 +91,11 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + stage_candidates = (8, 7, 6, 5, 4) + if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: + # Unrolling both stages 
and `num_former_iters` will cause large code size + stage_candidates = (4, ) + for num_stages in stage_candidates: best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) if best_smem_size <= sm90_capacity: best_num_stages = num_stages From 09d097f84d823b3b19f0280d868dbdbf705f230d Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 17:41:49 +0800 Subject: [PATCH 13/14] Add some notes --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index fdcf5a1..d9ab480 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -314,6 +314,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant bool predicate = kMustUseUniformedScaleB or i < num_former_iters; final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; From 55ab91f72f76e8a6db0e86282c3169a88d93df1b Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 18:06:47 +0800 Subject: [PATCH 14/14] Update performance --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index d3cbbdf..d1b0817 100644 --- a/README.md +++ b/README.md @@ -28,21 +28,21 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed | 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x | | 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x | | 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x | -| 4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x | -| 4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x | -| 4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x | -| 4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x | -| 4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x | -| 4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x | +| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x | +| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x | +| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x | +| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x | +| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x | +| 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x | ### Grouped GEMMs for MoE models (contiguous layout) | #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup | |:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:| -| 4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x | -| 4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x | -| 8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x | -| 8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x | +| 4 | 8192 | 4096 | 7168 | 1346 TFLOPS | 434 GB/s | 1.3x | +| 4 | 8192 | 7168 | 2048 | 1214 TFLOPS | 752 GB/s | 1.3x | +| 8 | 4096 | 4096 | 7168 | 1346 TFLOPS | 516 GB/s | 1.3x | +| 8 | 4096 | 7168 | 2048 | 1214 TFLOPS | 826 GB/s | 1.2x | ### Grouped GEMMs for MoE models (masked layout)
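
For reference, the TMA multicast heuristic introduced in "Support multicasting on B" (and relaxed to m >= 512 in "Lower TMA requirement") can be restated outside the diff context. The sketch below mirrors `is_tma_multicast_legal` and the selection loop in `get_best_configs`; the wrapper name `choose_multicast_config` is only for illustration and is not part of the DeepGEMM API.

def is_tma_multicast_legal(shape_dim: int, block_dim: int,
                           num_tma_multicast: int, num_sms: int) -> bool:
    # Multicast-1 is always legal; otherwise the multicast group must tile the
    # dimension evenly and divide the SM count.
    if num_tma_multicast == 1:
        return True
    return shape_dim % (block_dim * num_tma_multicast) == 0 and num_sms % num_tma_multicast == 0

def choose_multicast_config(m: int, n: int, block_m: int, block_n: int,
                            num_sms: int, num_groups: int):
    # Returns (num_tma_multicast, is_multicast_on_A); the default is no multicast.
    config = (1, True)
    legal = {
        'A': is_tma_multicast_legal(n, block_n, 2, num_sms),  # A tile reused across N blocks
        'B': is_tma_multicast_legal(m, block_m, 2, num_sms),  # B tile reused across M blocks
    }
    # Try the side with the larger block first; require a single group and m >= 512.
    for side in (('A', 'B') if block_m > block_n else ('B', 'A')):
        if m >= 512 and legal[side] and num_groups == 1:
            config = (2, side == 'A')
            break
    return config

With BLOCK_M = 128 and BLOCK_N = 160, BLOCK_M is not the larger block side, so the B side is tried first; whenever m is a multiple of 256, the SM count is even, and there is a single group, the tuner picks 2-way multicast on B, matching the earlier note that large block tiles want B broadcast.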
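
The compile-time `gcd` added in "Compilation-time GCD" backs the relaxed B-scale assertion, and the condition is easy to sanity-check numerically. Below is a small Python restatement of `ceil_div`, `gcd`, and the assert condition; `b_scales_ok` is a hypothetical helper name used only for this check.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def gcd(a: int, b: int) -> int:
    return a if b == 0 else gcd(b, a % b)

def b_scales_ok(block_n: int, block_k: int = 128) -> bool:
    # Mirrors the static assert: either the block fits in a single per-128-channel
    # B scale, or the gcd condition guarantees it never needs more than two.
    return ceil_div(block_n, block_k) == 1 or gcd(block_n, block_k) == block_n - block_k

assert b_scales_ok(112)       # <= 128: one B scale per block
assert b_scales_ok(144)       # gcd(144, 128) == 16 == 144 - 128
assert b_scales_ok(160)       # gcd(160, 128) == 32 == 160 - 128
assert not b_scales_ok(176)   # gcd(176, 128) == 16 != 176 - 128, rejected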
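
Similarly, the stage-candidate restriction from "Better performance" can be exercised on its own. The sketch below reproduces the candidate logic from `get_best_configs` (the function name `get_stage_candidates` is illustrative): BLOCK_N = 160 limits the tuner to a 4-stage pipeline, while BLOCK_N = 144 keeps the full candidate list.

import math

def get_stage_candidates(block_n: int):
    # Deeper pipelines are preferred by default.
    candidates = (8, 7, 6, 5, 4)
    if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
        # Unrolling both the stage loop and `num_former_iters` would blow up code
        # size, so only a 4-stage pipeline is considered.
        candidates = (4, )
    return candidates

print(get_stage_candidates(128))  # (8, 7, 6, 5, 4): 128 % 128 == 0
print(get_stage_candidates(144))  # (8, 7, 6, 5, 4): 128 // gcd(128, 144) == 8
print(get_stage_candidates(160))  # (4,):            128 // gcd(128, 160) == 4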