refactor the loop if/else check

Keeps DivisibleK and NotDivisibleK logic explicit
Ensures clear separation between the bulk iterations and the last special case
Easier to modify later if conditions change
This commit is contained in:
A-transformer 2025-02-27 22:23:53 +04:00 committed by GitHub
parent a2e0d68eed
commit b0b9e03345
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -129,14 +129,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
constexpr bool is_divisible = (SHAPE_K % kFullKOfAllStages == 0);
constexpr int last_iter = kNumIterations - 1;
static_assert(kNumIterations > 0, "At least one iteration required");
// Handle all but the last iteration with DivisibleK
for (int k_iter = 0; k_iter < last_iter; ++k_iter) {
func(k_iter, DivisibleK{});
}
// Last iteration: DivisibleK if SHAPE_K is divisible, NotDivisibleK otherwise
func(last_iter, is_divisible ? DivisibleK{} : NotDivisibleK{});
};
// Register reconfigurations