Add some notes

Chenggang Zhao 2025-03-03 11:35:52 +08:00
parent 6c5da03ba9
commit dff6bb6f0b


@@ -110,6 +110,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, so that the behavior stays aligned
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
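The first note describes the arrival convention for the cluster barriers: lane `lane_idx` always arrives on the barrier owned by the `lane_idx`-th CTA in the cluster, so the same code path runs whether or not TMA multicast is enabled. A minimal CUDA sketch of that idea follows, using a hypothetical `Barrier::arrive(cta_id)` wrapper rather than DeepGEMM's actual barrier class:

// Minimal sketch, hypothetical names: each lane below the multicast width
// signals the barrier of the matching CTA in the cluster. With multicast
// disabled (kNumTMAMulticast == 1) only lane 0 participates, but the code
// path and the arrival count per barrier stay the same.
template <uint32_t kNumTMAMulticast, typename Barrier>
__device__ void arrive_for_cluster(Barrier* full_barrier, uint32_t lane_idx) {
    if (lane_idx < kNumTMAMulticast) {
        // `arrive(cta_id)` is assumed to wrap a cluster-wide mbarrier arrival
        // on the shared memory of CTA `lane_idx`; DeepGEMM's real class differs.
        full_barrier->arrive(lane_idx);
    }
}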
@@ -159,6 +161,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// NOTES: unrolling and `kNumInnerStages` are vital for performance; NVCC will try to eliminate all
// shared memory pointers, e.g. the `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
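The second note concerns how the barrier pointers end up being handled: when the trip count `kNumInnerStages` is a compile-time constant and the loop is fully unrolled, every `full_barriers[s]` access uses a constant index, so NVCC can resolve the shared-memory addresses at compile time instead of keeping live pointer registers. A minimal sketch of that pattern, with hypothetical names:

// Minimal sketch, hypothetical names: a constexpr trip count plus full
// unrolling turns every index `s` into a literal, letting the compiler fold
// the shared-memory address computation for `barriers[s]` away entirely.
template <uint32_t kNumInnerStages, typename Barrier>
__device__ void wait_all_stages(Barrier** barriers, uint32_t phase) {
    #pragma unroll
    for (uint32_t s = 0; s < kNumInnerStages; ++ s)
        barriers[s]->wait(phase);  // `s` is a constant after unrolling

    // By contrast, a runtime-only bound keeps `s` non-constant and forces the
    // pointers to stay live in registers:
    // for (uint32_t s = 0; s < runtime_bound; ++ s) barriers[s]->wait(phase);
}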