Refactor launch-related structures

This commit is contained in:
Chenggang Zhao
2025-05-15 16:14:21 +08:00
parent e2d6a107ef
commit 816b39053a
9 changed files with 199 additions and 396 deletions

View File

@@ -18,7 +18,7 @@ namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kLastStages,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
@@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [&](const auto& func) {
if constexpr (kLastStages == 0) {
if constexpr (kNumLastStages == 0) {
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
@@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B
@@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll