mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-05-14 00:31:40 +00:00
Fix synchronization issues
This commit is contained in:
parent
70b9468520
commit
01a27728e6
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ dist/
|
||||
*perf.csv
|
||||
*.png
|
||||
/.vscode
|
||||
compile_commands.json
|
||||
|
@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
||||
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
||||
// We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race.
|
||||
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
||||
int begin_idx = tile_scheduler_metadata.x;
|
||||
int begin_seqlen = tile_scheduler_metadata.y;
|
||||
int end_idx = tile_scheduler_metadata.z;
|
||||
int end_seqlen = tile_scheduler_metadata.w;
|
||||
if (begin_idx >= params.b) return;
|
||||
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
|
||||
int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
|
||||
|
||||
// Copy the first Q
|
||||
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
|
||||
@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
||||
|
||||
// Issue P0 = Q @ K0^T, wait
|
||||
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
|
||||
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
|
||||
NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);
|
||||
cute::warpgroup_wait<0>();
|
||||
|
||||
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
|
||||
@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
||||
|
||||
cute::tma_store_wait<0>();
|
||||
} else {
|
||||
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
|
||||
// Don't use __ldg because of PDL and instruction reordering
|
||||
int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx;
|
||||
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
|
||||
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
|
||||
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
|
||||
|
@ -102,5 +102,6 @@ enum NamedBarriers : int {
|
||||
sScale0Ready = 0,
|
||||
sScale1Ready = 1,
|
||||
sP0Ready = 2,
|
||||
rO1sP0sV0RIssued = 3
|
||||
rO1sP0sV0RIssued = 3,
|
||||
sMInitialized = 4,
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user