Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-02 17:10:56 +00:00

Commit 7768319ffe: Remove unaligned predicates
Parent commit: 3497428a5e
@@ -27,6 +27,17 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }
 
+template <int kNumFormerIters, int kGap, int kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+    if (num_former_iters == kNumFormerIters) {
+        inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
+        return;
+    }
+
+    if constexpr (kNumFormerIters + kGap <= kEnd)
+        outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
+}
+
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
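Context for the hunk above: outer_launch_k_iterations is the classic recursive-template dispatcher for lifting a runtime integer into a compile-time constant. It walks the candidates kNumFormerIters, kNumFormerIters + kGap, ... up to kEnd, and on the first match with the runtime num_former_iters it re-enters the inner launcher with the value wrapped in cute::Int<>, so downstream code receives a true constant expression. A minimal standalone sketch of the same trick, with illustrative names and bounds that are not from the commit:

    #include <cstdio>

    // Recursively instantiated dispatcher: compares the runtime value against
    // kValue, kValue + kStep, ... and stops at the first compile-time match.
    template <int kValue, int kStep, int kEnd>
    void dispatch(const auto& func, int runtime_value) {
        if (runtime_value == kValue) {
            func.template operator()<kValue>();  // hand over a true constant
            return;
        }
        if constexpr (kValue + kStep <= kEnd)
            dispatch<kValue + kStep, kStep, kEnd>(func, runtime_value);
    }

    int main() {
        int n = 8;  // known only at runtime
        dispatch<0, 4, 16>([]<int kN>() {
            // kN is usable wherever a constant expression is required,
            // e.g. as a loop bound under #pragma unroll.
            std::printf("dispatched to kN = %d\n", kN);
        }, n);
    }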
@@ -129,15 +140,19 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func) {
+    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+        constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr int kEnd = BLOCK_K / 8;
+        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
         if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
             for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                func(k_iter, DivisibleK{});
+                func(k_iter, DivisibleK{}, num_former_iters_type);
         } else {
             for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                func(k_iter, DivisibleK{});
+                func(k_iter, DivisibleK{}, num_former_iters_type);
-            func(kNumIterations - 1, NotDivisibleK{});
+            func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
         }
+        }, func, num_former_iters);
     };
 
     // Register reconfigurations
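A note on the surrounding machinery: launch_k_iterations peels the final K iteration whenever SHAPE_K is not a multiple of kFullKOfAllStages (the K range covered by one full pipeline round), and signals the peel through the DivisibleK / NotDivisibleK tag types, so each instantiation of the functor computes its inner-stage count as a constexpr. This commit threads one more compile-time value, num_former_iters_type, through the same mechanism. A host-side sketch of the tag-dispatch pattern, with made-up shapes:

    #include <cstdio>
    #include <type_traits>

    struct DivisibleK {};
    struct NotDivisibleK {};

    // Peel the last iteration when kShapeK is not a multiple of the pipeline's
    // full K coverage; the tag type tells the callee which case it is in.
    template <int kShapeK, int kFullK>
    void run_k_iterations(const auto& func) {
        constexpr int kNumIterations = (kShapeK + kFullK - 1) / kFullK;
        if constexpr (kShapeK % kFullK == 0) {
            for (int k = 0; k < kNumIterations; ++ k)
                func(k, DivisibleK{});
        } else {
            for (int k = 0; k < kNumIterations - 1; ++ k)
                func(k, DivisibleK{});
            func(kNumIterations - 1, NotDivisibleK{});  // peeled tail
        }
    }

    int main() {
        run_k_iterations<896, 512>([](int k, auto type) {
            constexpr bool kFull = std::is_same_v<decltype(type), DivisibleK>;
            std::printf("iter %d: %s\n", k, kFull ? "full" : "tail");
        });
    }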
@@ -156,7 +171,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     if (threadIdx.x == kNumMathThreads) {
         // Persistently schedule over blocks
         while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-            launch_k_iterations([&](int k_iter, auto type) {
+            launch_k_iterations([&](int k_iter, auto type, auto _) {
                 constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                 constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                 DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@@ -193,7 +208,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                     full_barriers[s]->arrive();
                 }
-            });
+            }, 0);
         }
 
         // To safely deconstruct distributed shared barriers, we need another round of empty waits
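The `0` in this hunk is on the producer (TMA) side: the data-loading warps never touch the scale predicates, which is why their lambda above accepts the extra compile-time parameter as an ignored `auto _`, and any constant satisfies the dispatcher; 0 is presumably chosen so it matches the dispatcher's first candidate and generates no extra recursion.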
@@ -246,7 +261,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         };
 
         // Launch MMAs
-        launch_k_iterations([&](int k_iter, auto type) {
+        launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
             constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
             constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
             DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@@ -292,13 +307,21 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 float scale_0_1, scale_1_1;
                 if constexpr (not kMustUseUniformedScaleB)
                     scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
 
+                constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                    bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
-                    final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
-                    final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
-                    final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
-                    final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
+                for (int i = 0; i < kNumFormerIters; ++ i) {
+                    final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
+                    final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
+                    final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
+                    final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
+                }
+                #pragma unroll
+                for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
+                    final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
+                    final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
+                    final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
+                    final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
                 }
             }
+
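This hunk is the point of the commit. Before, every iteration of the hot accumulation loop evaluated `predicate = kMustUseUniformedScaleB or i < num_former_iters` against a runtime value and selected between two scale pairs per FMA. With num_former_iters now available as the compile-time kNumFormerIters, the loop is split at that bound into two fully unrolled, branch-free halves, each applying its scale pair unconditionally. A CPU-side sketch of the before/after transformation, with illustrative sizes and names:

    #include <cstdio>

    constexpr int kNumAccum = 32;  // stands in for WGMMA::kNumAccum / 4

    // Before: one loop with a runtime select per element.
    void scale_with_predicate(float* out, const float* in, int num_former_iters,
                              float scale_former, float scale_latter) {
        for (int i = 0; i < kNumAccum; ++ i) {
            bool predicate = i < num_former_iters;  // runtime per-iteration test
            out[i] += (predicate ? scale_former : scale_latter) * in[i];
        }
    }

    // After: two loops split at a compile-time bound; no select remains.
    template <int kNumFormerIters>
    void scale_split(float* out, const float* in,
                     float scale_former, float scale_latter) {
        for (int i = 0; i < kNumFormerIters; ++ i)
            out[i] += scale_former * in[i];
        for (int i = kNumFormerIters; i < kNumAccum; ++ i)
            out[i] += scale_latter * in[i];
    }

    int main() {
        float in[kNumAccum], a[kNumAccum] = {}, b[kNumAccum] = {};
        for (int i = 0; i < kNumAccum; ++ i) in[i] = 1.0f;
        scale_with_predicate(a, in, 24, 2.0f, 3.0f);
        scale_split<24>(b, in, 2.0f, 3.0f);
        std::printf("outputs match: %d\n", a[0] == b[0] && a[31] == b[31]);
    }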
@@ -308,7 +331,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                 empty_barrier_arrive(s);
             }
-        });
+        }, num_former_iters);
 
         // Write back to shared memory using STSM
         DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
@@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
                   '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                   # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                   '--diag-suppress=177,174,940']
-    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
+    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
     flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
     include_dirs = [get_jit_include_dir()]
 
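The Python-side change is a consequence of the CUDA change: `const auto&` function parameters, as used by the new outer_launch_k_iterations, are C++20 abbreviated-template syntax, and when the host half of the translation unit is compiled by g++ in a pre-C++20 dialect, GCC accepts that syntax only with -fconcepts (which enables its Concepts TS support, including `auto` parameters). A one-line illustration, hypothetical code rather than anything from the repo:

    // Compiles under -std=c++20, or under g++ -std=c++17 -fconcepts:
    // an `auto` parameter makes this an abbreviated function template.
    int twice(const auto& x) { return 2 * x; }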