mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Finish a draft version
This commit is contained in:
@@ -18,8 +18,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
- [x] MoE scheduler with TMA multicast compatibility
|
||||
- [x] Fix TMA multicast compatibility for indivisible shapes
|
||||
- [ ] Skip useless computation on M
|
||||
- [ ] Share pipeline stages between scheduled blocks
|
||||
- [ ] TMA store pipeline
|
||||
- [x] Share pipeline stages between scheduled blocks
|
||||
- [ ] NVRTC as a faster compiler
|
||||
- [ ] Sanitizer for testing
|
||||
- [ ] Weight gradient kernels for dense models
|
||||
|
||||
@@ -42,6 +42,17 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
|
||||
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
|
||||
}
|
||||
|
||||
template <uint32_t kNumStages, uint32_t kStageIdx = 0>
|
||||
__device__ __host__ void dispatch_stage_idx(const auto& func, uint32_t k_iter, uint32_t stage_idx, const auto& num_former_iters_type) {
|
||||
if (stage_idx == kStageIdx) {
|
||||
func(k_iter, kStageIdx, num_former_iters_type);
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (kStageIdx + 1 < kNumStages)
|
||||
dispatch_stage_idx<kNumStages, kStageIdx + 1>(func, k_iter, stage_idx, num_former_iters_type);
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t BLOCK_N_PADDING,
|
||||
@@ -77,14 +88,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
@@ -148,24 +157,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
|
||||
DG_STATIC_ASSERT(SHAPE_K % BLOCK_K == 0, "Invalid shape of the K dim");
|
||||
constexpr uint32_t kNumStagesPerBlock = SHAPE_K / BLOCK_K;
|
||||
auto launch_k_iterations = [&](const auto& func, int num_former_iters) {
|
||||
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
||||
constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
||||
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
||||
|
||||
// NOTES: for too-many branches (> 5), we disable this optimization
|
||||
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
|
||||
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, num_former_iters_type);
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, num_former_iters_type);
|
||||
func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
|
||||
outer_launch_k_iterations<0, kGap, kEnd>([&](const auto& func, auto num_former_iters_type) {
|
||||
#pragma unroll
|
||||
for (uint32_t k_iter = 0; k_iter < kNumStagesPerBlock; ++ k_iter) {
|
||||
uint32_t stage_idx = (scheduler.current_iter * kNumStagesPerBlock + k_iter) % kNumStages;
|
||||
dispatch_stage_idx<kNumStages>(func, k_iter, stage_idx, num_former_iters_type);
|
||||
}
|
||||
}, func, kShouldOptimize ? num_former_iters : 0);
|
||||
};
|
||||
@@ -174,10 +184,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
@@ -186,11 +192,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type, auto _) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
@@ -198,46 +200,30 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
|
||||
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
// Wait consumer release
|
||||
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
|
||||
empty_barriers[stage_idx]->wait(phase_idx ^ 1);
|
||||
|
||||
// Issue TMA A
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
|
||||
num_tma_multicast_a);
|
||||
// Issue TMA A
|
||||
auto& full_barrier = *full_barriers[stage_idx];
|
||||
uint32_t k_idx = k_iter * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[stage_idx], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
|
||||
num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[stage_idx], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
}
|
||||
// TODO: to safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
@@ -285,78 +271,65 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter), scale_b_1;
|
||||
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter + SHAPE_K_SCALES);
|
||||
|
||||
// Wait TMA arrivals
|
||||
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
|
||||
full_barriers[stage_idx]->wait(phase_idx);
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[stage_idx] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_scales_a[stage_idx] + r_1 + m_offset);
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(stage_idx);
|
||||
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
}, num_former_iters);
|
||||
|
||||
@@ -441,6 +414,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: refactor here
|
||||
if constexpr (kNumTMAMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
|
||||
Reference in New Issue
Block a user