mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 06:44:21 +00:00
OOB bugs fixed
This commit is contained in:
parent
99eb6ec563
commit
b0d64817a7
@ -375,7 +375,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
|
||||
auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
|
||||
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
|
||||
auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user