mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 17:54:21 +00:00
Fix bugs
This commit is contained in:
parent
23d4289365
commit
5ff0eb24b5
@ -379,7 +379,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
constexpr int kNumBankGroupBytes = 16;
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes;
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
@ -410,7 +410,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset;
|
||||
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
|
||||
|
Loading…
Reference in New Issue
Block a user