diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 2e379d2..9743871 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -135,7 +135,9 @@ struct Scheduler { return false; // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned - is_peer_cta_alive = (next_block_idx ^ 1) < num_blocks; + is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) + num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); } return true;