mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 15:54:21 +00:00
Swizzling draft
This commit is contained in:
parent
76804c096d
commit
4c111418a2
@ -352,6 +352,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
}
|
||||
}, num_former_iters);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
|
||||
"Swizzling and padding are not compatible");
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
@ -360,19 +370,52 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the base pointer of the swizzling atom
|
||||
constexpr int kNumBankGroupBytes = 16;
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes;
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
auto row = bank_group_index / 8, col = bank_group_index % 8;
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
|
||||
// Add back into the base pointer
|
||||
smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes;
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
||||
smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
|
||||
// Issue TMA store
|
||||
cute::tma_store_fence();
|
||||
if (lane_idx < 16) {
|
||||
if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
|
||||
} else if (kSwizzleDMode == 0 and lane_idx < WGMMA_M_PER_WARP) {
|
||||
uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
|
||||
auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
|
||||
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
|
||||
auto smem_ptr = smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
|
||||
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
|
||||
auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
|
||||
}
|
||||
@ -468,9 +511,11 @@ public:
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
|
||||
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
// TODO: try `FP8MMASelector<BLOCK_N>::type::M` (warpgroup-level TMA)
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
|
||||
min(BLOCK_M, shape_m), kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
FP8MMASelector<BLOCK_N>::type::M / 4,
|
||||
kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user