mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-06 07:24:22 +00:00
Minor fixes
This commit is contained in:
parent
07ef809d82
commit
55d1d01c43
@ -226,8 +226,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
// Wait unaligned cases
|
// Wait unaligned cases
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
|
||||||
full_barriers[s]->arrive();
|
full_barriers[s]->arrive();
|
||||||
|
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||||
}
|
}
|
||||||
}, 0);
|
}, 0);
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ struct Scheduler {
|
|||||||
if (curr_group_idx == kNumGroups)
|
if (curr_group_idx == kNumGroups)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Within current group
|
// Within the current group
|
||||||
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||||
|
Loading…
Reference in New Issue
Block a user