mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-23 06:44:17 +00:00
refactor the loop if/else check
Keeps DivisibleK and NotDivisibleK logic explicit Ensures clear separation between the bulk iterations and the last special case Easier to modify later if conditions change
This commit is contained in:
parent
a2e0d68eed
commit
b0b9e03345
@ -129,14 +129,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
struct DivisibleK {};
|
struct DivisibleK {};
|
||||||
struct NotDivisibleK {};
|
struct NotDivisibleK {};
|
||||||
auto launch_k_iterations = [](const auto& func) {
|
auto launch_k_iterations = [](const auto& func) {
|
||||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
constexpr bool is_divisible = (SHAPE_K % kFullKOfAllStages == 0);
|
||||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
constexpr int last_iter = kNumIterations - 1;
|
||||||
|
static_assert(kNumIterations > 0, "At least one iteration required");
|
||||||
|
|
||||||
|
// Handle all but the last iteration with DivisibleK
|
||||||
|
for (int k_iter = 0; k_iter < last_iter; ++k_iter) {
|
||||||
func(k_iter, DivisibleK{});
|
func(k_iter, DivisibleK{});
|
||||||
} else {
|
|
||||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
|
||||||
func(k_iter, DivisibleK{});
|
|
||||||
func(kNumIterations - 1, NotDivisibleK{});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Last iteration: DivisibleK if SHAPE_K is divisible, NotDivisibleK otherwise
|
||||||
|
func(last_iter, is_divisible ? DivisibleK{} : NotDivisibleK{});
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register reconfigurations
|
// Register reconfigurations
|
||||||
|
Loading…
Reference in New Issue
Block a user