mirror of https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-07 18:54:23 +00:00

Optimize TMA issues

parent 6366f5ad1a
commit f8797d3c12

@@ -354,7 +354,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     // TMA checks
     constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
-    constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
+    constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
     constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
     if constexpr (kSwizzleDMode > 0) {
         DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
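
Note: the old expression evaluated to zero when kSwizzleDMode == 0; the new ternary falls back to the full BLOCK_N, so the TMA inner box covers the whole tile width when swizzling is disabled. A minimal standalone sketch of the new rule (the helper name and the numbers are mine, not repo code):

    #include <cstdint>

    constexpr uint32_t kNumElemBytes = 2;  // sizeof(nv_bfloat16)

    // Same ternary as the new line above, lifted into a testable helper
    constexpr uint32_t tma_d_block_n(uint32_t swizzle_bytes, uint32_t block_n) {
        return swizzle_bytes == 0 ? block_n : (swizzle_bytes / kNumElemBytes);
    }

    static_assert(tma_d_block_n(128, 128) == 64, "128B swizzle -> 64 BF16 columns per TMA box");
    static_assert(tma_d_block_n(64, 128) == 32, "64B swizzle -> 32 BF16 columns per TMA box");
    static_assert(tma_d_block_n(0, 128) == 128, "no swizzle -> one box covers BLOCK_N");

    int main() { return 0; }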
@@ -384,8 +384,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);

                 // Reshape the atom in another view and swizzle
-                // - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)`
-                // - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
+                // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
+                // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
                 constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
                 auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
                 auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
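
Note: kHasShortcut is an algebraic shortcut, not a behavioral change. When the atom row is exactly 8 bank groups wide, bank_group_index / 8 and bank_group_index % 8 collapse to lane_idx and in_atom_offset. A standalone check of that equivalence (assuming 16-byte bank groups and in_atom_offset staying within one row, which is my reading, not stated in the diff):

    #include <cassert>
    #include <cstdint>

    int main() {
        constexpr uint32_t kSwizzleDMode = 128, kNumBankGroupBytes = 16;
        constexpr uint32_t kGroupsPerRow = kSwizzleDMode / kNumBankGroupBytes;  // == 8, shortcut applies
        for (uint32_t lane_idx = 0; lane_idx < 32; ++ lane_idx)
            for (uint32_t in_atom_offset = 0; in_atom_offset < kGroupsPerRow; ++ in_atom_offset) {
                uint32_t bank_group_index = in_atom_offset + lane_idx * kGroupsPerRow;
                // Generic path: reshape the linear index into an `(rows, 8)` view
                uint32_t row_generic = bank_group_index / 8, col_generic = bank_group_index % 8;
                // Shortcut path from the diff
                uint32_t row_shortcut = in_atom_offset / 8 + lane_idx, col_shortcut = in_atom_offset;
                assert(row_generic == row_shortcut && col_generic == col_shortcut);
            }
        return 0;
    }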
@@ -393,13 +393,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,

                 // Add back into the base pointer
                 // NOTES: think twice before modifying this, as changes may affect the number of instructions
                 smem_ptr = reinterpret_cast<uint8_t*>(smem_d) +                      // Base pointer
-                           warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
-                           m_offset * (BLOCK_N * kNumElemBytes) +                    // Wave offset
-                           atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode +          // Swizzle atom offset (constants)
+                           warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) +           // Warp offset
+                           m_offset * kSwizzleDMode +                                // Wave offset
+                           atom_offset * BLOCK_M * kSwizzleDMode +                   // Swizzle atom offset (constants)
                            row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
             } else {
                 // No swizzling, just padding
+                // NOTES: padding must be zero for BF16 output
+                DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
                 smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
             }

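
Note: the rewritten offsets switch smem_d from a row-major layout (BLOCK_N * kNumElemBytes bytes per row) to an atom-major one, where each swizzle atom holds BLOCK_M rows of kSwizzleDMode bytes, so one TMA box can later cover one whole atom. A rough standalone model of that layout (my own reconstruction with example tile sizes; it ignores the in-atom bank swizzling for brevity):

    #include <cassert>
    #include <cstdint>

    constexpr uint32_t kNumElemBytes = 2;  // sizeof(nv_bfloat16)
    constexpr uint32_t BLOCK_M = 64, BLOCK_N = 128, kSwizzleDMode = 128;
    constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;

    // Byte offset of element (m, n) in the atom-major layout
    uint32_t smem_d_offset(uint32_t m, uint32_t n) {
        return (n / TMA_D_BLOCK_N) * BLOCK_M * kSwizzleDMode  // swizzle atom offset
             + m * kSwizzleDMode                              // row offset inside the atom
             + (n % TMA_D_BLOCK_N) * kNumElemBytes;           // in-atom column offset
    }

    int main() {
        // Every (m, n) maps to a unique, in-range byte offset:
        // the atoms tile the BLOCK_M x BLOCK_N buffer exactly
        static bool seen[BLOCK_M * BLOCK_N * kNumElemBytes] = {};
        for (uint32_t m = 0; m < BLOCK_M; ++ m)
            for (uint32_t n = 0; n < BLOCK_N; ++ n) {
                uint32_t off = smem_d_offset(m, n);
                assert(off < sizeof(seen) && !seen[off]);
                seen[off] = true;
            }
        return 0;
    }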
@@ -410,29 +412,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     smem_ptr
                 );
             }

-            // Issue TMA store
-            cute::tma_store_fence();
-            if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
-                auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
-                auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
-                auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
-                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
-                                              n_block_idx * BLOCK_N + in_block_n_offset,
-                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
-            } else if (kSwizzleDMode == 0 and lane_idx < WGMMA_M_PER_WARP) {
-                uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
-                auto smem_ptr = smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
-                auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
-                cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
-            }
-            __syncwarp();
         }
+        cute::tma_store_fence();
+        cutlass::arch::NamedBarrier(kNumMathThreads).sync();

-        // Wait TMA to be finished
-        cute::tma_store_arrive();
-        cute::tma_store_wait<0>();
+        // Use TMA store to write back to global memory
+        // TODO: compatible with FP32 output
+        DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
+        if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
+            auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
+            auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
+            cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
+                                          n_block_idx * BLOCK_N + in_block_n_offset,
+                                          scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+
+            // Wait TMA to be finished
+            cute::tma_store_arrive();
+            cute::tma_store_wait<0>();
+        }
+        __syncwarp();
     }
 }
 #else
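
Note: the restructured epilogue replaces per-wave, per-warp TMA and bulk-copy stores (and their per-iteration fences) with one fence, one barrier across all math threads, and a single round of 2D TMA stores, one per swizzle atom. A condensed sketch of that ordering (simplified from the new code above; the function name, includes, and coordinate types are my additions):

    #include <cuda.h>
    #include <cuda_bf16.h>
    #include <cstdint>
    #include <cute/arch/copy_sm90_tma.hpp>
    #include <cutlass/arch/barrier.h>

    template <uint32_t kNumMathThreads, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t TMA_D_BLOCK_N>
    __device__ void store_d_tile(const CUtensorMap& tensor_map_d, nv_bfloat16* smem_d,
                                 int32_t global_m_idx, int32_t global_n_idx) {
        // 1. Make this thread's smem writes visible to the async (TMA) proxy
        cute::tma_store_fence();
        // 2. Wait until every math thread has flushed its part of the tile
        cutlass::arch::NamedBarrier(kNumMathThreads).sync();
        // 3. One thread per swizzle atom issues a single 2D TMA store;
        //    atoms are contiguous in smem (BLOCK_M rows x TMA_D_BLOCK_N columns)
        if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
            auto n_offset = threadIdx.x * TMA_D_BLOCK_N;
            cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d + n_offset * BLOCK_M,
                                          global_n_idx + n_offset, global_m_idx);
            // 4. Commit the store group and drain it before smem can be reused
            cute::tma_store_arrive();
            cute::tma_store_wait<0>();
        }
        __syncwarp();
    }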
@@ -519,11 +517,9 @@ public:

         // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
         // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
-        // TODO: try `FP8MMASelector<BLOCK_N>::type::M` (warpgroup-level TMA)
         return make_2d_tma_desc(global_address, Layout::RowMajor,
                                 shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
-                                FP8MMASelector<BLOCK_N>::type::M / 4,
-                                kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
+                                BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
                                 swizzle_mode);
     }
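
Note: the D-tile descriptor's shared-memory box is now BLOCK_M rows tall, matching the atom-major smem layout above, rather than one warp's WGMMA M slice. Worked numbers under assumed typical tile sizes (my arithmetic, not repo code):

    #include <cstdint>

    constexpr uint32_t BLOCK_M = 64, BLOCK_N = 128, kSwizzleDMode = 128;
    constexpr uint32_t kElemBytes = 2;  // sizeof(T) for BF16 output

    // Inner box dim, per the descriptor arguments above
    constexpr uint32_t box_n = kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / kElemBytes;
    static_assert(box_n == 64, "128B swizzle over BF16 -> a BLOCK_M x 64 box");
    static_assert(BLOCK_N * kElemBytes / kSwizzleDMode == 2,
                  "two TMA stores cover one BLOCK_N tile, matching the comment in the diff");

    int main() { return 0; }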