Remove unaligned predicates

This commit is contained in:
Chenggang Zhao 2025-03-25 16:32:40 +08:00
parent 3497428a5e
commit 7768319ffe
2 changed files with 43 additions and 20 deletions

View File

@ -27,6 +27,17 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
} }
template <int kNumFormerIters, int kGap, int kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
if (num_former_iters == kNumFormerIters) {
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
return;
}
if constexpr (kNumFormerIters + kGap <= kEnd)
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K, template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages, uint32_t kNumGroups, uint32_t kNumStages,
@ -129,15 +140,19 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// For pipeline unrolling // For pipeline unrolling
struct DivisibleK {}; struct DivisibleK {};
struct NotDivisibleK {}; struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) { auto launch_k_iterations = [](const auto& func, int num_former_iters) {
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = BLOCK_K / 8;
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) { if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{}); func(k_iter, DivisibleK{}, num_former_iters_type);
} else { } else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}); func(k_iter, DivisibleK{}, num_former_iters_type);
func(kNumIterations - 1, NotDivisibleK{}); func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
} }
}, func, num_former_iters);
}; };
// Register reconfigurations // Register reconfigurations
@ -156,7 +171,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if (threadIdx.x == kNumMathThreads) { if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks // Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) { while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) { launch_k_iterations([&](int k_iter, auto type, auto _) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>; constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@ -193,7 +208,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive(); full_barriers[s]->arrive();
} }
}); }, 0);
} }
// To safely deconstruct distributed shared barriers, we need another round of empty waits // To safely deconstruct distributed shared barriers, we need another round of empty waits
@ -246,7 +261,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
}; };
// Launch MMAs // Launch MMAs
launch_k_iterations([&](int k_iter, auto type) { launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>; constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@ -292,13 +307,21 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
float scale_0_1, scale_1_1; float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB) if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
#pragma unroll #pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { for (int i = 0; i < kNumFormerIters; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < num_former_iters; final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; }
#pragma unroll
for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
} }
} }
@ -308,7 +331,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s); empty_barrier_arrive(s);
} }
}); }, num_former_iters);
// Write back to shared memory using STSM // Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");

View File

@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=177,174,940'] '--diag-suppress=177,174,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi'] cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()] include_dirs = [get_jit_include_dir()]