mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-05-14 16:45:59 +00:00
Fix synchronization issues
This commit is contained in:
parent
70b9468520
commit
01a27728e6
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ dist/
|
|||||||
*perf.csv
|
*perf.csv
|
||||||
*.png
|
*.png
|
||||||
/.vscode
|
/.vscode
|
||||||
|
compile_commands.json
|
||||||
|
@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
|||||||
cudaGridDependencySynchronize();
|
cudaGridDependencySynchronize();
|
||||||
|
|
||||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
||||||
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
// We don't use __ldg here; otherwise NVCC (ptxas, in particular) will reorder instructions and place the __ldg (LDG.E.128.CONSTANT in SASS) before cudaGridDependencySynchronize() (ACQBULK in SASS), leading to a data race.
|
||||||
|
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
||||||
int begin_idx = tile_scheduler_metadata.x;
|
int begin_idx = tile_scheduler_metadata.x;
|
||||||
int begin_seqlen = tile_scheduler_metadata.y;
|
int begin_seqlen = tile_scheduler_metadata.y;
|
||||||
int end_idx = tile_scheduler_metadata.z;
|
int end_idx = tile_scheduler_metadata.z;
|
||||||
int end_seqlen = tile_scheduler_metadata.w;
|
int end_seqlen = tile_scheduler_metadata.w;
|
||||||
if (begin_idx >= params.b) return;
|
if (begin_idx >= params.b) return;
|
||||||
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
|
int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
|
||||||
|
|
||||||
// Copy the first Q
|
// Copy the first Q
|
||||||
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
|
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
|
||||||
@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
|||||||
|
|
||||||
// Issue P0 = Q @ K0^T, wait
|
// Issue P0 = Q @ K0^T, wait
|
||||||
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
|
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
|
||||||
|
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
|
||||||
|
NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);
|
||||||
cute::warpgroup_wait<0>();
|
cute::warpgroup_wait<0>();
|
||||||
|
|
||||||
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
|
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
|
||||||
@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
|
|||||||
|
|
||||||
cute::tma_store_wait<0>();
|
cute::tma_store_wait<0>();
|
||||||
} else {
|
} else {
|
||||||
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
|
// Don't use __ldg here: with PDL (programmatic dependent launch), instruction reordering can move the load ahead of cudaGridDependencySynchronize()
|
||||||
|
int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx;
|
||||||
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
|
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
|
||||||
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
|
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
|
||||||
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
|
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
|
||||||
|
@ -102,5 +102,6 @@ enum NamedBarriers : int {
|
|||||||
sScale0Ready = 0,
|
sScale0Ready = 0,
|
||||||
sScale1Ready = 1,
|
sScale1Ready = 1,
|
||||||
sP0Ready = 2,
|
sP0Ready = 2,
|
||||||
rO1sP0sV0RIssued = 3
|
rO1sP0sV0RIssued = 3,
|
||||||
|
sMInitialized = 4,
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user