mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-07 23:14:22 +00:00
Simplify
This commit is contained in:
parent
1e23215eb6
commit
46c7a1ef36
@ -192,7 +192,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
if (threadIdx.x == kNumMathThreads) {
|
if (threadIdx.x == kNumMathThreads) {
|
||||||
// Persistently schedule over blocks
|
// Persistently schedule over blocks
|
||||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||||
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
|
|
||||||
// Assign TMA multicast number into A and B
|
// Assign TMA multicast number into A and B
|
||||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||||
@ -200,6 +199,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||||
|
|
||||||
|
launch_k_iterations([&](uint32_t k_iter, uint32_t stage_idx, auto _) {
|
||||||
// Wait consumer release
|
// Wait consumer release
|
||||||
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
|
auto phase_idx = ((scheduler.current_iter * kNumStagesPerBlock + k_iter) / kNumStages) & 1;
|
||||||
empty_barriers[stage_idx]->wait(phase_idx ^ 1);
|
empty_barriers[stage_idx]->wait(phase_idx ^ 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user