Fix unaligned cases

This commit is contained in:
Chenggang Zhao 2025-04-23 14:09:21 +08:00
parent f4b205bfa3
commit d7c068d467
2 changed files with 5 additions and 1 deletion

View File

@@ -279,7 +279,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};

View File

@@ -24,6 +24,7 @@ struct Scheduler {
// Maybe not used in the masked grouped GEMM
uint32_t num_blocks;
uint32_t num_blocks_in_group;
bool is_peer_cta_alive = true;
// For grouped GEMM
int* grouped_layout;
@@ -133,6 +134,8 @@ struct Scheduler {
if (next_block_idx >= num_blocks)
return false;
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
is_peer_cta_alive = (next_block_idx ^ 1) < num_blocks;
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;