This commit is contained in:
Chenggang Zhao 2025-04-11 18:27:56 +08:00
parent 23d4289365
commit 5ff0eb24b5

View File

@ -379,7 +379,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
constexpr int kNumBankGroupBytes = 16;
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes;
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
@ -410,7 +410,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset;
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);