mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Optimize swizzle performance
This commit is contained in:
parent
5ff0eb24b5
commit
9406e2a3a1
@ -353,7 +353,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
}, num_former_iters);
|
}, num_former_iters);
|
||||||
|
|
||||||
// TMA checks
|
// TMA checks
|
||||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16);
|
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||||
|
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
|
||||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||||
if constexpr (kSwizzleDMode > 0) {
|
if constexpr (kSwizzleDMode > 0) {
|
||||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||||
@ -373,13 +374,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||||
// Swizzle or padding into the correct address
|
// Swizzle or padding into the correct address
|
||||||
uint8_t* smem_ptr;
|
uint8_t* smem_ptr = nullptr;
|
||||||
if constexpr (kSwizzleDMode > 0) {
|
if constexpr (kSwizzleDMode > 0) {
|
||||||
// Calculate the base pointer of the swizzling atom
|
// Calculate the swizzling atom offset and in-atom offset
|
||||||
constexpr int kNumBankGroupBytes = 16;
|
constexpr int kNumBankGroupBytes = 16;
|
||||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
|
||||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
|
|
||||||
|
|
||||||
// Calculate the index of the bank group to be written in the atom
|
// Calculate the index of the bank group to be written in the atom
|
||||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||||
@ -391,7 +390,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
col ^= row % (kSwizzleDMode / 16);
|
col ^= row % (kSwizzleDMode / 16);
|
||||||
|
|
||||||
// Add back into the base pointer
|
// Add back into the base pointer
|
||||||
smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes;
|
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||||
|
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||||
|
warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
|
||||||
|
m_offset * (BLOCK_N * kNumElemBytes) + // Wave offset
|
||||||
|
atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||||
|
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||||
} else {
|
} else {
|
||||||
// No swizzling, just padding
|
// No swizzling, just padding
|
||||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
||||||
|
Loading…
Reference in New Issue
Block a user