mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Optimize swizzle performance
This commit is contained in:
parent
5ff0eb24b5
commit
9406e2a3a1
@ -353,7 +353,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
}, num_former_iters);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16);
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||
@ -373,13 +374,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr;
|
||||
uint8_t* smem_ptr = nullptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the base pointer of the swizzling atom
|
||||
// Calculate the swizzling atom offset and in-atom offset
|
||||
constexpr int kNumBankGroupBytes = 16;
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
@ -391,7 +390,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
|
||||
// Add back into the base pointer
|
||||
smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes;
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
|
||||
m_offset * (BLOCK_N * kNumElemBytes) + // Wave offset
|
||||
atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
||||
|
Loading…
Reference in New Issue
Block a user