diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 4a0c79e..4b0cebf 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -110,6 +110,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     // Initialize barriers
     DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
     if (threadIdx.x == kNumMathThreads) {
+        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster;
+        // even with TMA multicast disabled, we keep the behavior aligned
         #pragma unroll
         for (int i = 0; i < kNumStages; ++ i) {
             full_barriers[i]->init(1);
@@ -159,6 +161,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
         DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
+        // NOTES: unrolling with the compile-time bound `kNumInnerStages` is vital for performance; NVCC can only
+        // eliminate the shared memory pointer registers (e.g. for `full_barriers`) if all access indices are constant
         #pragma unroll
         for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
             // Wait consumer release
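
A minimal, illustrative sketch of the second note above (not part of the patch): with a compile-time trip count and `#pragma unroll`, every index into a per-stage shared-memory array is a constant, so NVCC can resolve each stage's shared-memory address at compile time instead of holding a pointer in registers and offsetting it each iteration. The names below (`unrolled_stages_sketch`, `smem_buffers`) are hypothetical stand-ins for the kernel's real `full_barriers` machinery; the sketch only demonstrates the constant-index pattern, not the actual barrier logic.

#include <cstdio>
#include <cuda_runtime.h>

template <int kNumStages>
__global__ void unrolled_stages_sketch(float* out) {
    // One shared-memory buffer per stage, sized by a compile-time constant.
    __shared__ float smem_buffers[kNumStages][32];

    // Compile-time trip count + unrolling: every `s` below is a constant, so each
    // `smem_buffers[s]` is a fixed shared-memory address rather than a register-held
    // pointer advanced per iteration.
    #pragma unroll
    for (int s = 0; s < kNumStages; ++ s)
        smem_buffers[s][threadIdx.x] = static_cast<float>(s);
    __syncthreads();

    float acc = 0.0f;
    #pragma unroll
    for (int s = 0; s < kNumStages; ++ s)
        acc += smem_buffers[s][threadIdx.x];
    out[threadIdx.x] = acc;
}

int main() {
    float* d_out;
    cudaMalloc(&d_out, 32 * sizeof(float));
    unrolled_stages_sketch<4><<<1, 32>>>(d_out);
    cudaDeviceSynchronize();

    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 0 sum: %.1f\n", h_out[0]);  // expect 6.0 (= 0 + 1 + 2 + 3) for kNumStages = 4
    cudaFree(d_out);
    return 0;
}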