fix: Update named barrier thread count to match actual participating threads

- Changed kNThreads (256) to 128 in NamedBarrier::arrive calls to match the actual number of threads in warp group
- Fixed potential deadlock issue where barrier was waiting for more threads than would arrive
- Updated both SReady and SoftmaxReady barrier synchronizations
This commit is contained in:
IshanaSabrish 2025-03-01 21:18:05 +05:30
parent 480405ada9
commit 927eebc10f

View File

@ -327,7 +327,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::copy(rP, tPsP);
cute::copy(scale_o, tScale_osScale_o);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
cutlass::arch::NamedBarrier::arrive(128, static_cast<int>(NamedBarriers::SReady));
flash::rescale_o(tOrO, scale_o);
@ -342,7 +342,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::copy(softmax.row_max, tRow_maxsRow_max);
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
cutlass::arch::NamedBarrier::arrive(128, static_cast<int>(NamedBarriers::SoftmaxReady));
} else {
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);
@ -411,7 +411,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::cp_async_fence();
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
cutlass::arch::NamedBarrier::sync(128, static_cast<int>(NamedBarriers::SReady));
if (n_block - 2 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 2]);
@ -434,7 +434,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
cutlass::arch::NamedBarrier::sync(128, static_cast<int>(NamedBarriers::SoftmaxReady));
cute::copy(tRow_maxsRow_max, softmax.row_max);
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
}