Add SkipComputation types

Chenggang Zhao 2025-05-27 11:57:25 +08:00
parent 1169f83c36
commit a5373e4bbd
3 changed files with 47 additions and 42 deletions

@@ -17,8 +17,8 @@
 namespace deep_gemm {
 
-template <int kNumFormerIters, int kGap, int kEnd>
-__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
     if (num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
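
The hunk above only changes index types, but it is worth spelling out what `outer_launch_k_iterations` does: it turns the runtime value `num_former_iters` into a compile-time constant by recursing over the candidates `kNumFormerIters, kNumFormerIters + kGap, ..., kEnd` until one matches, then re-enters `func` with the match wrapped as a `cute::Int`. A minimal host-side sketch of the same pattern, with illustrative names (only the idea is taken from the code above):

    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    // Recurse over candidate values until the runtime value is found, then hand
    // it back as a template constant (std::integral_constant here plays the
    // role of cute::Int in the kernel).
    template <uint32_t kCandidate, uint32_t kGap, uint32_t kEnd, typename Func>
    void dispatch_to_constant(const Func& func, uint32_t runtime_value) {
        if (runtime_value == kCandidate) {
            func(std::integral_constant<uint32_t, kCandidate>{});
            return;
        }
        if constexpr (kCandidate + kGap <= kEnd)
            dispatch_to_constant<kCandidate + kGap, kGap, kEnd>(func, runtime_value);
    }

    int main() {
        dispatch_to_constant<0, 4, 16>([](auto c) {
            constexpr uint32_t kValue = decltype(c)::value;  // compile-time again
            std::printf("dispatched to %u\n", kValue);
        }, 8u);  // runtime 8 -> compile-time 8
    }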
@@ -54,7 +54,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
 
     // Shared memory
-    static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
+    static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
     static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
@@ -101,7 +101,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // Fill shared memory pointers
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
         smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
         smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
@@ -111,7 +111,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         full_barriers[i] = barrier_start_ptr + i;
         empty_barriers[i] = barrier_start_ptr + kNumStages + i;
     }
@@ -122,7 +122,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
         // even with TMA multicast disabled, we want to make the behavior aligned
         #pragma unroll
-        for (int i = 0; i < kNumStages; ++ i) {
+        for (uint32_t i = 0; i < kNumStages; ++ i) {
             full_barriers[i]->init(1);
             empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
         }
@@ -138,28 +138,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+    struct SkipComputation {};
+    struct NotSkipComputation {};
+    auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
         constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
-        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+        constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
 
         // NOTES: for too-many branches (> 5), we disable this optimization
         // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
-        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
-            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
+        outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
+            if (skip_computation) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
+            } else if (SHAPE_K % kFullKOfAllStages == 0) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             } else {
-                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
-                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+                for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             }
         }, func, kShouldOptimize ? num_former_iters : 0);
     };
 
     // Register reconfigurations
-    constexpr int kNumTMARegisters = 40;
-    constexpr int kNumMathRegisters = 232;
+    constexpr uint32_t kNumTMARegisters = 40;
+    constexpr uint32_t kNumMathRegisters = 232;
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
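
This hunk is the heart of the commit. `SkipComputation` / `NotSkipComputation` join `DivisibleK` / `NotDivisibleK` as empty tag types: the runtime flag `skip_computation` is branched on exactly once, here, and from then on the decision travels into `func` as a type, so every consumer can fold its branches away at compile time. A self-contained sketch of the idiom (illustrative names, plain host C++ rather than the kernel):

    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    struct SkipComputation {};
    struct NotSkipComputation {};

    int main() {
        constexpr uint32_t kNumStages = 4;
        auto stage_loop = [](auto skip_type) {
            // The tag's type, not its value, carries the decision.
            constexpr bool kSkip = std::is_same_v<decltype(skip_type), SkipComputation>;
            constexpr uint32_t kNumInnerStages = kSkip ? 0 : kNumStages;
            for (uint32_t s = 0; s < kNumInnerStages; ++ s)  // removed entirely when kSkip
                std::printf("compute stage %u\n", s);
        };
        stage_loop(NotSkipComputation{});  // runs all stages
        stage_loop(SkipComputation{});     // loop bound is a compile-time 0
    }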
@@ -173,10 +178,9 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         if (threadIdx.x == kNumMathThreads) {
             // Persistently schedule over blocks
             while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-                launch_k_iterations([&](int k_iter, auto type, auto _) {
-                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+                launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
+                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                    constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
 
                     // Assign TMA multicast number into A and B
                     // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
@@ -194,7 +198,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                         // Issue TMA A
                         auto& full_barrier = *full_barriers[s];
-                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
+                        uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                         tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                  smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
                                  num_tma_multicast_a);
@@ -216,7 +220,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                         empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                         full_barriers[s]->arrive();
                     }
-                }, 0);
+                }, false, 0);
             }
 
             // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -257,12 +261,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
             // Accumulation for WGMMA or CUDA promotion
-            constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+            constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
             DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
             float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
 
             // Empty barrier arrival
-            auto empty_barrier_arrive = [&](int s) {
+            auto empty_barrier_arrive = [&](uint32_t s) {
                 if constexpr (kNumTMAMulticast == 1) {
                     lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                 } else {
@@ -272,13 +276,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             };
 
             // Launch MMAs
-            launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
-                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+            launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
+                constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
+                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
+                    (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);
 
                 #pragma unroll
-                for (int s = 0; s < kNumInnerStages; ++ s) {
+                for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                     // Read B scales
                     float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                     // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
@@ -300,18 +305,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                         // Commit WGMMA instructions
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                             warpgroup_fence_operand(accum[i]);
                         warpgroup_arrive();
                         #pragma unroll
-                        for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                             auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                             auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                             WGMMA::wgmma(desc_a, desc_b, accum, k);
                         }
                         warpgroup_commit_batch();
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                             warpgroup_fence_operand(accum[i]);
                         warpgroup_wait<0>();
@@ -328,7 +333,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                         auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                             // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                             bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                             shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
@@ -345,7 +350,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
-            }, num_former_iters);
+            }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
 
             // TMA checks
             constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
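
Note how the two new arguments meet: the TMA producer always passes `false` (the `-216` hunk), while each math warpgroup passes `not scheduler.is_computation_valid(...)` for its own rows. When the flag is true, `kNumInnerStages` above becomes a compile-time 0, so the WGMMA stage loop disappears entirely, but the trailing wait loop visible here (which in the full kernel runs from `kNumInnerStages` up to `kNumStages`) still waits on every full barrier and arrives on every empty barrier. A skipped warpgroup therefore keeps draining the pipeline in lockstep with the producer, and no barrier phase desynchronizes.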
@@ -355,7 +360,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                              "Unaligned TMA store or too many TMA store instructions");
             DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
-            DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
+            DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
                              "Swizzling and padding are not compatible");
 
             // Wait last TMA store to be finished
@@ -375,7 +380,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 uint8_t* smem_ptr = nullptr;
                 if constexpr (kSwizzleDMode > 0) {
                     // Calculate the swizzling atom offset and in-atom offset
-                    constexpr int kNumBankGroupBytes = 16;
+                    constexpr uint32_t kNumBankGroupBytes = 16;
                     auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
 
                     // Calculate the index of the bank group to be written in the atom

@@ -49,11 +49,11 @@ struct Scheduler {
     }
 
     // ReSharper disable once CppNotAllPathsReturnValue
-    __device__ __forceinline__ bool is_m_valid(const uint32_t& m_offset, const uint32_t& m_block_idx) const {
+    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
         if constexpr (kGemmType == GemmType::Normal) {
             return true;
         } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) != -1;
+            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
         } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
         }
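
The rename from `is_m_valid` to `is_computation_valid` (with the argument order flipped to match the new call site in the kernel) makes the intent explicit. Reading the two grouped branches: for `GroupedContiguous`, `grouped_layout` appears to hold one group index per output row, with negative entries marking padding rows, hence the `>= 0` test; for `GroupedMasked`, `grouped_layout[curr_group_idx]` holds the valid row count of the current group. A host-side mock under those assumed semantics (not the repository's API):

    #include <cstdint>

    // GroupedContiguous: one group index per row; negative means a padding row.
    bool valid_contiguous(const int* grouped_layout, uint32_t row) {
        return grouped_layout[row] >= 0;
    }

    // GroupedMasked: grouped_layout[group] stores that group's valid row count.
    bool valid_masked(const int* grouped_layout, uint32_t group, uint32_t row_in_group) {
        return row_in_group < static_cast<uint32_t>(grouped_layout[group]);
    }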
@@ -76,7 +76,7 @@ struct Scheduler {
         }
     }
 
-    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
+    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, uint32_t block_idx,
                                                            uint32_t& m_block_idx, uint32_t& n_block_idx) {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

@@ -121,7 +121,7 @@ class Compiler:
                 '--ptxas-options=--register-usage-level=10' +
                 (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
                 # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
-                '--diag-suppress=39,161,174,177,940']
+                '--diag-suppress=39,161,174,177,186,940']
 
     @staticmethod
     def include_dirs() -> List[str]:
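
The single new entry, 186, is presumably required by the kernel changes above: nvcc's diagnostic 186 is "pointless comparison of unsigned integer with zero", and with `uint32_t` loop counters it fires whenever a bound such as `kNumInnerStages` is a compile-time 0 on the `SkipComputation` path. A minimal sketch of the code shape that trips it:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // On the SkipComputation path this bound is a compile-time zero, so the
        // condition below is `s < 0u` -- always false, which nvcc flags as 186.
        constexpr uint32_t kNumInnerStages = 0;
        for (uint32_t s = 0; s < kNumInnerStages; ++ s)
            std::printf("never reached\n");
    }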