mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 21:35:26 +00:00
Fix unaligned cases
This commit is contained in:
parent
f4b205bfa3
commit
d7c068d467
@ -279,7 +279,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
if (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -24,6 +24,7 @@ struct Scheduler {
|
||||
// Maybe not used in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
uint32_t num_blocks_in_group;
|
||||
bool is_peer_cta_alive = true;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
@ -133,6 +134,8 @@ struct Scheduler {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
|
||||
is_peer_cta_alive = (next_block_idx ^ 1) < num_blocks;
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
|
Loading…
Reference in New Issue
Block a user