Finish a draft version

This commit is contained in:
Chenggang Zhao
2025-04-28 09:49:45 +08:00
parent 86afd0c212
commit 1e23215eb6
2 changed files with 101 additions and 125 deletions

View File

@@ -18,8 +18,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [ ] Skip useless computation on M
- [ ] Share pipeline stages between scheduled blocks
- [ ] TMA store pipeline
- [x] Share pipeline stages between scheduled blocks
- [ ] NVRTC as a faster compiler
- [ ] Sanitizer for testing
- [ ] Weight gradient kernels for dense models

View File

@@ -42,6 +42,17 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
}
template <uint32_t kNumStages, uint32_t kStageIdx = 0>
__device__ __host__ void dispatch_stage_idx(const auto& func, uint32_t k_iter, uint32_t stage_idx, const auto& num_former_iters_type) {
if (stage_idx == kStageIdx) {
func(k_iter, kStageIdx, num_former_iters_type);
return;
}
if constexpr (kStageIdx + 1 < kNumStages)
dispatch_stage_idx<kNumStages, kStageIdx + 1>(func, k_iter, stage_idx, num_former_iters_type);
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
@@ -77,14 +88,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
@@ -148,24 +157,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
DG_STATIC_ASSERT(SHAPE_K % BLOCK_K == 0, "Invalid shape of the K dim");
constexpr uint32_t kNumStagesPerBlock = SHAPE_K / BLOCK_K;
auto launch_k_iterations = [&](const auto& func, int num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{}, num_former_iters_type);
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}, num_former_iters_type);
func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
outer_launch_k_iterations<0, kGap, kEnd>([&](const auto& func, auto num_former_iters_type) {
#pragma unroll
for (uint32_t k_iter = 0; k_iter < kNumStagesPerBlock; ++ k_iter) {
uint32_t stage_idx = (scheduler.current_iter * kNumStagesPerBlock + k_iter) % kNumStages;
dispatch_stage_idx<kNumStages>(func, k_iter, stage_idx, num_former_iters_type);
}
}, func, kShouldOptimize ? num_former_iters : 0);
};
@@ -174,10 +184,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
@@ -186,11 +192,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type, auto _) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
@@ -198,46 +200,30 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Wait consumer release
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
empty_barriers[stage_idx]->wait(phase_idx ^ 1);
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
num_tma_multicast_a);
// Issue TMA A
auto& full_barrier = *full_barriers[stage_idx];
uint32_t k_idx = k_iter * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[stage_idx], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[stage_idx], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}, 0);
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
// TODO: to safely deconstruct distributed shared barriers, we need another round of empty waits
}
} else {
// Math warp-groups for WGMMA
@@ -285,78 +271,65 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter + SHAPE_K_SCALES);
// Wait TMA arrivals
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
full_barriers[stage_idx]->wait(phase_idx);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[stage_idx] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[stage_idx] + r_1 + m_offset);
// TODO: remove some useless computation for unaligned Ms
// Commit WGMMA instructions
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(stage_idx);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}, num_former_iters);
@@ -441,6 +414,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
__syncwarp();
}
}
// TODO: refactor here
if constexpr (kNumTMAMulticast > 1)
cute::cluster_sync();
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");