Fix NVCC branch divergence

This commit is contained in:
Chenggang Zhao 2025-03-25 17:12:51 +08:00
parent 9c4f6f53f5
commit ddccb230ca

View File

@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
template <bool kMustUseUniformedScaleB, int kNumFormerIters, int kGap, int kEnd>
template <int kNumFormerIters, int kGap, int kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) {
if (num_former_iters == kNumFormerIters) {
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
return;
}
if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB)
outer_launch_k_iterations<kMustUseUniformedScaleB, kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
if constexpr (kNumFormerIters + kGap <= kEnd)
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
@ -141,9 +141,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = BLOCK_K / 8;
outer_launch_k_iterations<kMustUseUniformedScaleB, 0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{}, num_former_iters_type);
@ -152,7 +156,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
func(k_iter, DivisibleK{}, num_former_iters_type);
func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
}
}, func, num_former_iters);
}, func, kShouldOptimize ? num_former_iters : 0);
};
// Register reconfigurations
@ -310,7 +314,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value;
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
@ -324,7 +328,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
}, kMustUseUniformedScaleB ? 0 : num_former_iters);
}, num_former_iters);
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");