Support block size 160

This commit is contained in:
Chenggang Zhao 2025-03-25 13:37:59 +08:00
parent 46eb0d08fb
commit b922e64cb2
4 changed files with 188 additions and 879 deletions

View File

@ -27,7 +27,7 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
@ -41,13 +41,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Currently, only BLOCK_N size of 160 is classified as a largeBlockTile configuration in our optimization framework.
constexpr bool largeBlockTile = (BLOCK_N == 160);
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
if constexpr(!largeBlockTile) {
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
}
// TODO: check `ceil_div(BLOCK_N, BLOCK_K) == 1` or `gcd(BLOCK_K, BLOCK_N) == BLOCK_N - BLOCK_K`
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@ -150,760 +146,205 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
if constexpr (largeBlockTile) {
auto scheduler = SchedulerLargeBlockTile<kGemmType, SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(SHAPE_N, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
// Issue TMA A without broadcasting
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx));
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B with broadcasting
tma_copy<kNumTMAMulticast>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// To ensure optimal performance, conditional checks must never be placed inside the loop body of kNumInnerStages.
switch (n_block_idx % 4) {
case 0:
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
// In performance-critical code paths, we rigorously avoid branching logic and unnecessary instructions.
// Here, we adopt a brute-force approach by explicitly unrolling computations for all possible scenarios.
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
}
#pragma unroll
for (int i = WGMMA::kNumAccum / 4 / 5 * 4; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
break;
case 1:
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 3; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
}
#pragma unroll
for (int i = WGMMA::kNumAccum / 4 / 5 * 3; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
break;
case 2:
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 2; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
}
#pragma unroll
for (int i = WGMMA::kNumAccum / 4 / 5 * 2; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
break;
case 3:
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 1; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
}
#pragma unroll
for (int i = WGMMA::kNumAccum / 4 / 5 * 1; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
break;
default:
break;
}
}
}
} else {
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
// Issue TMA A with broadcasting
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Issue TMA A with broadcasting
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast,
@ -926,7 +367,7 @@ public:
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_M, SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

View File

@ -98,98 +98,7 @@ struct Scheduler {
return true;
}
};
template <GemmType kGemmType,
uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
uint32_t kNumMBlocksPerGroup = 16>
struct SchedulerLargeBlockTile {
int current_iter = -1;
uint32_t num_aligned_n_blocks;
// For normal GEMM
// Maybe not used in the masked grouped GEMM
uint32_t num_blocks;
// For grouped GEMM
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
__device__ __forceinline__ explicit SchedulerLargeBlockTile(const uint32_t shape_n,
int* grouped_layout = nullptr) {
num_aligned_n_blocks = ceil_div(shape_n, BLOCK_N);
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_aligned_n_blocks * kNumMBlocks;
} else if (kGemmType == GemmType::GroupedContiguous) {
num_blocks = num_aligned_n_blocks * kNumMBlocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::GroupedMasked) {
curr_group_idx = curr_cumsum = 0;
this->grouped_layout = grouped_layout;
}
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_n_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumMBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
auto num_blocks_per_group = num_n_blocks * kNumMBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_m_block_idx = group_idx * kNumMBlocksPerGroup;
auto num_m_blocks_in_group = min(kNumMBlocksPerGroup, kNumMBlocks - first_m_block_idx);
// auto num_m_blocks_in_group = kNumMBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
n_block_idx = in_group_idx / num_m_blocks_in_group;
m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
// if (threadIdx.x == 256) {
// printf("blockIdx.x: %d group_idx: %d in_group_idx: %d m_block_idx: %d n_block_idx: %d\n", blockIdx.x, group_idx, in_group_idx, m_block_idx, n_block_idx);
// }
}
template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
return offset * shape_dim + block_idx * block_size;
} else if (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
if constexpr (kGemmType == GemmType::GroupedMasked) {
uint32_t num_m_blocks;
while (true) {
// End of the task
if (curr_group_idx == kNumGroups)
return false;
// Within current group
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumMBlocks)
break;
// Move to check the next group
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumMBlocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;
get_swizzled_block_idx(num_aligned_n_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@ -10,14 +10,14 @@ template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N}, K = {K};
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated GEMM
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@ -62,14 +62,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (64 if m <= 64 else 128, )
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
# Current optimizations target large-scale Normal GEMMs for dense models and
# Grouped GEMMs for MoE models (contiguous memory layout), with a potential
# block_n tile size of 160 to enhance data reuse in block tiling.
if m >= 4096 and num_groups == 1:
block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160))
else:
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8)) + (160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@ -84,16 +77,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
# The value 0.02 is currently an empirically estimated threshold to
# filter out cases unsuitable for large block tile configurations,
# with optimizations planned for later stages to address excluded scenarios.
if block_n == 160 and \
(num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02:
success = True
elif block_n == 160:
success = False
else:
success = True
success = True
elif num_waves == best_num_waves:
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
@ -105,44 +89,25 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
if best_block_n != 160:
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
else:
# NOTES: This is done to reduce the code footprint after unrolling.
# Additionally, if k does not meet the following conditions, a slight performance penalty will occur.
num_stages = 4
assert k / 128 % num_stages == 0
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
assert best_smem_size <= sm90_capacity
best_num_stages = num_stages
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
assert best_num_stages is not None
# Decide the number of TMA multicast
best_num_tma_multicast = 1
# When using large block tiling, broadcasting B is required to achieve maximum performance gains.
if best_block_n != 160:
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
else:
if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
if best_block_n != 160:
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
else:
assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms)
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
@ -198,7 +163,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
space=(),
includes=includes,
@ -212,5 +177,3 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size

View File

@ -11,14 +11,14 @@ template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N}, K = {K};
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated grouped GEMM
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
@ -106,8 +106,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -171,7 +169,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
@ -186,5 +184,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size