Restore previous indent

This commit is contained in:
Chenggang Zhao 2025-05-27 11:11:46 +08:00
parent e7e38ed222
commit 2c5ab83c6c

View File

@ -271,16 +271,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
} }
}; };
if (!scheduler.is_m_valid(math_wg_idx * WGMMA::M, m_block_idx)) {
// Skip useless computation for unaligned Ms
launch_k_iterations([&](int k_iter, auto type, auto _) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
}, num_former_iters);
} else {
// Launch MMAs // Launch MMAs
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>; constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
@ -298,6 +288,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Wait TMA arrivals // Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll #pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M; auto m_offset = local_idx * WAVE_BLOCK_M;
@ -355,7 +346,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
empty_barrier_arrive(s); empty_barrier_arrive(s);
} }
}, num_former_iters); }, num_former_iters);
}
// TMA checks // TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);