diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 625e178..fdcf5a1 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }
 
-template <bool kMustUseUniformedScaleB, int kNumFormerIters, int kGap, int kEnd>
+template <int kNumFormerIters, int kGap, int kEnd>
 __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
-    if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) {
+    if (num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
     }
 
-    if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB)
-        outer_launch_k_iterations<kMustUseUniformedScaleB, kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
+    if constexpr (kNumFormerIters + kGap <= kEnd)
+        outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
 }
 
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
@@ -143,7 +143,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     auto launch_k_iterations = [](const auto& func, int num_former_iters) {
         constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4;
         constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
-        outer_launch_k_iterations<kMustUseUniformedScaleB, 0, kGap, BLOCK_K / 8>([](const auto& func, auto num_former_iters_type) {
+        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+
+        // NOTES: for too-many branches (> 5), we disable this optimization
+        // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
+        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
             if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
                 for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
                     func(k_iter, DivisibleK{}, num_former_iters_type);
@@ -152,7 +156,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     func(k_iter, DivisibleK{}, num_former_iters_type);
                 func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
             }
-        }, func, num_former_iters);
+        }, func, kShouldOptimize ? num_former_iters : 0);
     };
 
     // Register reconfigurations
@@ -310,7 +314,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
                     #pragma unroll
                     for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                        bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value;
+                        bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                         final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                         final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
                         final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
@@ -324,7 +328,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
-            }, kMustUseUniformedScaleB ? 0 : num_former_iters);
+            }, num_former_iters);
 
             // Write back to shared memory using STSM
             DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
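The core idea behind `outer_launch_k_iterations` is to turn the runtime value `num_former_iters` into a compile-time constant by walking a small ladder of template instantiations (0, kGap, 2*kGap, ..., kEnd), so the inner loop can be specialized per value. Below is a minimal, standalone sketch of that dispatch idiom, not the DeepGEMM code itself: `dispatch`, `Int`, and the ladder values are illustrative, and std::integral_constant stands in for cute::Int.

// Sketch of the runtime-to-compile-time dispatch ladder used by the patch above.
// Hypothetical names; compiles as ordinary C++17.
#include <cstdio>
#include <type_traits>

template <int N>
using Int = std::integral_constant<int, N>;  // stand-in for cute::Int<N>

template <int kCurrent, int kGap, int kEnd, typename Func>
void dispatch(const Func& func, int runtime_value) {
    if (runtime_value == kCurrent) {
        // Matching rung found: pass the value as a type, so it is constexpr inside `func`
        func(Int<kCurrent>{});
        return;
    }
    // Try the next rung; `kEnd` bounds how many specializations the compiler instantiates
    if constexpr (kCurrent + kGap <= kEnd)
        dispatch<kCurrent + kGap, kGap, kEnd>(func, runtime_value);
}

int main() {
    int num_former_iters = 8;  // known only at runtime
    // Ladder 0, 4, 8, 12, 16, mirroring 0 .. BLOCK_K / 8 in steps of kGap
    dispatch<0, 4, 16>([](auto n) {
        constexpr int kIters = decltype(n)::value;  // compile-time inside the callee
        std::printf("specialized for %d former iterations\n", kIters);
    }, num_former_iters);
    return 0;
}

If the runtime value falls outside the ladder, nothing is invoked; this is why the patch clamps the argument with `kShouldOptimize ? num_former_iters : 0`, so that when the ladder would exceed five branches the dispatcher always takes the `0` specialization and the predicate falls back to the runtime comparison `i < num_former_iters`.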