This commit is contained in:
yukuai
2025-04-22 15:03:39 +08:00
parent ee4204ad98
commit 69477036c0
2 changed files with 1 addition and 2 deletions

View File

@@ -292,7 +292,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta_idx = scheduler.is_tma_multicast_valid(m_block_idx) ? lane_idx : cute::block_rank_in_cluster();
auto target_cta_idx = scheduler.num_blocks_in_group != 1 ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta_idx) : void();
}
};

View File

@@ -45,7 +45,6 @@ struct Scheduler {
}
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
return true;
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {