diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index fdcf5a1..d9ab480 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -314,6 +314,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant bool predicate = kMustUseUniformedScaleB or i < num_former_iters; final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];