mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-03 07:10:57 +00:00
Overlap TMA store
This commit is contained in:
parent
4499c4ccbb
commit
340d9880f4
@ -364,6 +364,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
|
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
|
||||||
"Swizzling and padding are not compatible");
|
"Swizzling and padding are not compatible");
|
||||||
|
|
||||||
|
// Wait last TMA store to be finished
|
||||||
|
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||||
|
cute::tma_store_wait<0>();
|
||||||
|
|
||||||
// Write back to shared memory using STSM and issue TMA stores
|
// Write back to shared memory using STSM and issue TMA stores
|
||||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -424,10 +428,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||||
|
|
||||||
// Wait TMA to be finished
|
|
||||||
cute::tma_store_arrive();
|
cute::tma_store_arrive();
|
||||||
cute::tma_store_wait<0>();
|
|
||||||
}
|
}
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user