From dff6bb6f0bff01c841b44e615c52e7445f8d8ee9 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Mon, 3 Mar 2025 11:35:52 +0800 Subject: [PATCH] Add some notes --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 4a0c79e..4b0cebf 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -110,6 +110,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned #pragma unroll for (int i = 0; i < kNumStages; ++ i) { full_barriers[i]->init(1); @@ -159,6 +161,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll for (uint32_t s = 0; s < kNumInnerStages; ++ s) { // Wait consumer release