mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Fix bugs
This commit is contained in:
@@ -379,7 +379,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 constexpr int kNumBankGroupBytes = 16;
 auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
 auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
-smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes;
+smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
 
 // Calculate the index of the bank group to be written in the atom
 auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
@@ -410,7 +410,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
 auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
 auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
-auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset;
+auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
 cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
 n_block_idx * BLOCK_N + in_block_n_offset,
 scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
Reference in New Issue
Block a user