add notes

This commit is contained in:
yukuai 2025-04-22 15:23:32 +08:00
parent d2369e8f30
commit bfb2bcc04d
2 changed files with 3 additions and 1 deletion

View File

@ -226,7 +226,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Issue TMA B
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
} else {

View File

@ -45,6 +45,9 @@ struct Scheduler {
}
__device__ __forceinline__ bool is_block_in_complete_cluster(const uint32_t& m_block_idx, uint32_t& n_block_idx) {
// NOTES: For the case where the total number is an odd number of blocks, the last block requires special barrier processing.
// Here, we need each cluster to have exactly two blocks.
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
if (num_blocks_in_group == 1 and n_block_idx == kNumNBlocks - 1 and m_block_idx == num_aligned_m_blocks - 1
and num_aligned_m_blocks % 2 == 1 and kNumNBlocks % 2 == 1)
return false;