Restore previous indent

This commit is contained in:
Chenggang Zhao 2025-05-27 11:11:46 +08:00
parent e7e38ed222
commit 2c5ab83c6c

View File

@ -271,91 +271,81 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
}
};
if (!scheduler.is_m_valid(math_wg_idx * WGMMA::M, m_block_idx)) {
// Skip useless computation for unaligned Ms
launch_k_iterations([&](int k_iter, auto type, auto _) {
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}, num_former_iters);
} else {
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
}
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
}, num_former_iters);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
}, num_former_iters);
// TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);