Simplify code

This commit is contained in:
Chenggang Zhao 2025-03-25 16:45:20 +08:00
parent 046fab64b7
commit 612dd57001

View File

@ -308,20 +308,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
#pragma unroll
for (int i = 0; i < kNumFormerIters; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
}
#pragma unroll
for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}