Fix synchronization issues

This commit is contained in:
ljss 2025-04-28 18:53:04 +08:00
parent 70b9468520
commit 01a27728e6
3 changed files with 10 additions and 4 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ dist/
*perf.csv
*.png
/.vscode
compile_commands.json

View File

@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
cudaGridDependencySynchronize();
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
// We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race.
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
// Copy the first Q
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
// Issue P0 = Q @ K0^T, wait
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);
cute::warpgroup_wait<0>();
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
cute::tma_store_wait<0>();
} else {
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
// Don't use __ldg because of PDL and instruction reordering
int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<

View File

@ -102,5 +102,6 @@ enum NamedBarriers : int {
sScale0Ready = 0,
sScale1Ready = 1,
sP0Ready = 2,
rO1sP0sV0RIssued = 3
rO1sP0sV0RIssued = 3,
sMInitialized = 4,
};