From b0d64817a77d93fe07ec22aee9cd877b3ce538cc Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 11 Apr 2025 11:00:47 +0800 Subject: [PATCH] OOB bugs fixed --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 89e9bef..a43ee2c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -375,7 +375,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING); auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N; - cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16)); + auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N; + cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16)); } __syncwarp(); }