mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Restore previous indent
This commit is contained in:
parent
e7e38ed222
commit
2c5ab83c6c
@ -271,16 +271,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!scheduler.is_m_valid(math_wg_idx * WGMMA::M, m_block_idx)) {
|
|
||||||
// Skip useless computation for unaligned Ms
|
|
||||||
launch_k_iterations([&](int k_iter, auto type, auto _) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t s = 0; s < kNumStages; ++ s) {
|
|
||||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
|
||||||
empty_barrier_arrive(s);
|
|
||||||
}
|
|
||||||
}, num_former_iters);
|
|
||||||
} else {
|
|
||||||
// Launch MMAs
|
// Launch MMAs
|
||||||
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
|
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
|
||||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||||
@ -298,6 +288,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
// Wait TMA arrivals
|
// Wait TMA arrivals
|
||||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||||
|
|
||||||
|
// TODO: remove some useless computation for unaligned Ms
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||||
@ -355,7 +346,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
empty_barrier_arrive(s);
|
empty_barrier_arrive(s);
|
||||||
}
|
}
|
||||||
}, num_former_iters);
|
}, num_former_iters);
|
||||||
}
|
|
||||||
|
|
||||||
// TMA checks
|
// TMA checks
|
||||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user