From 9c4f6f53f57d9c00f6a32784bff719edca5039ab Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Tue, 25 Mar 2025 16:51:21 +0800
Subject: [PATCH] Optimize compilation speed

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index ba1b90c..625e178 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }

-template <int kNumFormerIters, int kGap, int kEnd>
+template <bool kMustUseUniformedScaleB, int kNumFormerIters, int kGap, int kEnd>
 __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
-    if (num_former_iters == kNumFormerIters) {
+    if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
     }

-    if constexpr (kNumFormerIters + kGap <= kEnd)
-        outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
+    if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB)
+        outer_launch_k_iterations<kMustUseUniformedScaleB, kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
 }

 template <...>
@@ ... @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
-    outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
+    outer_launch_k_iterations<kMustUseUniformedScaleB, 0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
         if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
             for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
                 func(k_iter, DivisibleK{}, num_former_iters_type);
@@ -324,7 +324,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
-            }, num_former_iters);
+            }, kMustUseUniformedScaleB ? 0 : num_former_iters);

     // Write back to shared memory using STSM
     DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
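
For context on why this speeds up compilation: outer_launch_k_iterations recursively instantiates itself, and the functor it dispatches to, once per candidate num_former_iters value. When B's scaling factors are known to be uniform, only the kNumFormerIters == 0 instantiation can ever be reached, so hoisting that fact into the kMustUseUniformedScaleB template parameter lets the if constexpr prune the whole recursive instantiation chain. Below is a minimal, self-contained C++20 sketch of the same dispatch pattern; dispatch and Int are placeholder names used only for illustration, not code from the repository.

    #include <cstdio>
    #include <type_traits>

    template <int N>
    using Int = std::integral_constant<int, N>;

    // Same shape as outer_launch_k_iterations: walk candidate values at
    // compile time and hand the matching one to `func` as a constant.
    template <bool kMustUseUniformedScaleB, int kNumFormerIters, int kGap, int kEnd>
    void dispatch(const auto& func, int num_former_iters) {
        if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) {
            func(Int<kNumFormerIters>{});
            return;
        }
        // The recursion (and every further instantiation of `func`) is only
        // generated when the uniform-scale case is not known at compile time.
        if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB)
            dispatch<kMustUseUniformedScaleB, kNumFormerIters + kGap, kGap, kEnd>(func, num_former_iters);
    }

    int main() {
        auto func = [](auto n) { std::printf("num_former_iters = %d\n", decltype(n)::value); };

        // Generic path: instantiates `dispatch` (and `func`) for 0, 2, 4, ..., 16.
        dispatch<false, 0, 2, 16>(func, 6);

        // Uniform-scale path: along this path only the kNumFormerIters == 0
        // instantiation is generated, which is what the patch exploits.
        dispatch<true, 0, 2, 16>(func, 0);
        return 0;
    }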