diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 4b0cebf..ee6e4a4 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -283,6 +283,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, empty_barrier_arrive(s); // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; float scale_0_1, scale_1_1; if constexpr (not kMustUseUniformedScaleB)