mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-31 18:48:16 +00:00
Optimize compilation speed
This commit is contained in:
parent
612dd57001
commit
9c4f6f53f5
@ -27,15 +27,15 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <int kNumFormerIters, int kGap, int kEnd>
|
||||
template <bool kMustUseUniformedScaleB, int kNumFormerIters, int kGap, int kEnd>
|
||||
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
|
||||
if (num_former_iters == kNumFormerIters) {
|
||||
if (kMustUseUniformedScaleB or num_former_iters == kNumFormerIters) {
|
||||
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (kNumFormerIters + kGap <= kEnd)
|
||||
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
|
||||
if constexpr (kNumFormerIters + kGap <= kEnd and not kMustUseUniformedScaleB)
|
||||
outer_launch_k_iterations<kMustUseUniformedScaleB, kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
@ -143,7 +143,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
|
||||
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
|
||||
constexpr int kEnd = BLOCK_K / 8;
|
||||
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
|
||||
outer_launch_k_iterations<kMustUseUniformedScaleB, 0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, num_former_iters_type);
|
||||
@ -324,7 +324,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
}, num_former_iters);
|
||||
}, kMustUseUniformedScaleB ? 0 : num_former_iters);
|
||||
|
||||
// Write back to shared memory using STSM
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
|
Loading…
Reference in New Issue
Block a user