diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 4029aa8..ef1df89 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -271,91 +271,81 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, } }; - if (!scheduler.is_m_valid(math_wg_idx * WGMMA::M, m_block_idx)) { - // Skip useless computation for unaligned Ms - launch_k_iterations([&](int k_iter, auto type, auto _) { + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // TODO: remove some useless computation for unaligned Ms #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } - }, num_former_iters); - } else { - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + } - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } - } - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }, num_former_iters); - } + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }, num_former_iters); // TMA checks constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); @@ -446,4 +436,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, }; // namespace deep_gemm -#pragma clang diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file