diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index f482300..87c942b 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -129,14 +129,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     struct DivisibleK {};
     struct NotDivisibleK {};
     auto launch_k_iterations = [](const auto& func) {
-        if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-            for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                func(k_iter, DivisibleK{});
-        } else {
-            for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                func(k_iter, DivisibleK{});
-            func(kNumIterations - 1, NotDivisibleK{});
+        // Calls `func(k_iter, tag)` for k_iter = 0 .. kNumIterations - 1; the tag is a
+        // compile-time type: `DivisibleK` for every iteration except possibly the last,
+        // which gets `NotDivisibleK` when SHAPE_K is not a multiple of kFullKOfAllStages.
+        constexpr bool is_divisible = (SHAPE_K % kFullKOfAllStages == 0);
+        constexpr int last_iter = kNumIterations - 1;
+        static_assert(kNumIterations > 0, "At least one iteration required");
+
+        // Handle all but the last iteration with DivisibleK
+        for (int k_iter = 0; k_iter < last_iter; ++k_iter) {
+            func(k_iter, DivisibleK{});
         }
+
+        // Last iteration. `DivisibleK` and `NotDivisibleK` are unrelated types, so a
+        // conditional expression (`?:`) over them is ill-formed; select at compile time.
+        if constexpr (is_divisible) {
+            func(last_iter, DivisibleK{});
+        } else {
+            func(last_iter, NotDivisibleK{});
+        }
     };
 
     // Register reconfigurations