diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 820080d..154073a 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -379,7 +379,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, constexpr int kNumBankGroupBytes = 16; auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP; auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); - smem_ptr = reinterpret_cast(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes; + smem_ptr = reinterpret_cast(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode; // Calculate the index of the bank group to be written in the atom auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); @@ -410,7 +410,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP; auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N; - auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset; + auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP; cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_block_idx * BLOCK_N + in_block_n_offset, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);