diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 0611c5c..ba1b90c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -308,20 +308,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if constexpr (not kMustUseUniformedScaleB) scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value; #pragma unroll - for (int i = 0; i < kNumFormerIters; ++ i) { - final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; - } - #pragma unroll - for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) { - final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; } }