mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 21:14:21 +00:00
Optimize TMA issues
This commit is contained in:
parent
6366f5ad1a
commit
f8797d3c12
@ -354,7 +354,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||
@ -384,8 +384,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
||||
@ -393,13 +393,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
|
||||
// Add back into the base pointer
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
|
||||
m_offset * (BLOCK_N * kNumElemBytes) + // Wave offset
|
||||
atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
||||
m_offset * kSwizzleDMode + // Wave offset
|
||||
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
// NOTES: padding must be zero for BF16 output
|
||||
DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
||||
}
|
||||
|
||||
@ -410,29 +412,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
|
||||
// Issue TMA store
|
||||
cute::tma_store_fence();
|
||||
if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
|
||||
auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
|
||||
} else if (kSwizzleDMode == 0 and lane_idx < WGMMA_M_PER_WARP) {
|
||||
uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
|
||||
auto smem_ptr = smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
|
||||
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
|
||||
auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Wait TMA to be finished
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
|
||||
// Wait TMA to be finished
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
@ -519,11 +517,9 @@ public:
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
|
||||
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
// TODO: try `FP8MMASelector<BLOCK_N>::type::M` (warpgroup-level TMA)
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
|
||||
FP8MMASelector<BLOCK_N>::type::M / 4,
|
||||
kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user