Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)
add notes
commit bfb2bcc04d (parent d2369e8f30)
@@ -226,7 +226,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
    // Issue TMA B
    if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
        DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
        tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                      smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
    } else {
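The hunk above gates the TMA load of B: the tile is multicast across the cluster only when more than one multicast destination is configured and the scheduler reports the current m-block as multicast-valid; otherwise the kernel falls back to a plain per-CTA copy. A minimal host-side C++ sketch of that gating follows. It is illustrative only: the free function is_tma_multicast_valid is a hypothetical stand-in for the scheduler method of the same name, and the validity rule it encodes is an assumption based on the note in the second hunk (an odd block count leaves the last block without a complete 2-CTA cluster).

#include <cstdint>
#include <cstdio>

constexpr uint32_t kNumTMAMulticastOnB = 2;   // hypothetical compile-time configuration

// Hypothetical stand-in for Scheduler::is_tma_multicast_valid(m_block_idx).
static bool is_tma_multicast_valid(uint32_t m_block_idx, uint32_t num_m_blocks) {
    // Assumption: the last m-block of an odd-sized grid cannot form a complete 2-CTA cluster.
    return not (num_m_blocks % 2 == 1 and m_block_idx == num_m_blocks - 1);
}

static void issue_tma_b(uint32_t m_block_idx, uint32_t num_m_blocks) {
    if (kNumTMAMulticastOnB > 1 and is_tma_multicast_valid(m_block_idx, num_m_blocks))
        std::printf("m-block %u: multicast B tile to %u CTAs\n",
                    (unsigned) m_block_idx, (unsigned) kNumTMAMulticastOnB);
    else
        std::printf("m-block %u: per-CTA copy of B tile\n", (unsigned) m_block_idx);
}

int main() {
    for (uint32_t m = 0; m < 5; ++m)
        issue_tma_b(m, 5);   // odd grid: only the final block falls back to a plain copy
}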
@@ -45,6 +45,9 @@ struct Scheduler {
    }

    __device__ __forceinline__ bool is_block_in_complete_cluster(const uint32_t& m_block_idx, uint32_t& n_block_idx) {
        // NOTES: when the total number of blocks is odd, the last block requires special barrier handling,
        // since each cluster is expected to contain exactly two blocks.
        DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
        if (num_blocks_in_group == 1 and n_block_idx == kNumNBlocks - 1 and m_block_idx == num_aligned_m_blocks - 1
            and num_aligned_m_blocks % 2 == 1 and kNumNBlocks % 2 == 1)
            return false;
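The note added in this hunk is the whole point of the predicate: with 2-block clusters, the very last block of the launch has no partner exactly when the final group holds a single m-block and both block counts are odd, so it must skip the cluster/multicast barrier path. A standalone C++ sketch of that condition, with the member variables passed in as plain parameters purely for illustration:

#include <cstdint>
#include <cstdio>

// Same condition as the device function above, restated host-side with explicit parameters.
static bool is_block_in_complete_cluster(uint32_t m_block_idx, uint32_t n_block_idx,
                                         uint32_t num_aligned_m_blocks, uint32_t kNumNBlocks,
                                         uint32_t num_blocks_in_group) {
    if (num_blocks_in_group == 1 and n_block_idx == kNumNBlocks - 1 and m_block_idx == num_aligned_m_blocks - 1
        and num_aligned_m_blocks % 2 == 1 and kNumNBlocks % 2 == 1)
        return false;
    return true;
}

int main() {
    // 3 x 3 blocks (both counts odd): the final block (2, 2) of a single-block group is the odd one out.
    std::printf("last block complete?    %d\n", is_block_in_complete_cluster(2, 2, 3, 3, 1));   // prints 0
    std::printf("earlier block complete? %d\n", is_block_in_complete_cluster(1, 2, 3, 3, 1));   // prints 1
}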