Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-08 17:29:25 +00:00)

Commit 8041ed7164: Use 1D TMA store
Parent: a77009cb14
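Note: taken together, the hunks below drop the pair of 3-D CUtensorMap descriptors for the output matrix D (and the make_3d_tma_d_desc / make_3d_tma_copy_desc helpers behind them); the epilogue still stages results in shared memory via stmatrix, then writes each output row with a 1-D bulk copy (cp.async.bulk) that needs only raw pointers. A minimal sketch of that store pattern follows; it assumes the CuTe headers the kernel already uses, and store_row / smem_row / gmem_row / row_bytes are illustrative names, not repo code.

    // Sketch only: the 1D TMA ("bulk copy") store pattern, under the assumptions above.
    #include <cute/arch/copy_sm90_tma.hpp>   // bulk copy + store fence/arrive (header path may vary by CUTLASS version)

    __device__ __forceinline__ void store_row(const void* smem_row, void* gmem_row, int row_bytes) {
        // Make prior shared-memory writes visible to the async (TMA) proxy
        cute::tma_store_fence();
        // One-dimensional shared->global bulk copy; no CUtensorMap is involved
        cute::SM90_BULK_COPY_S2G::copy(smem_row, gmem_row, row_bytes);
        // Commit the bulk-copy group; a later cute::tma_store_wait<0>() waits for it to drain
        cute::tma_store_arrive();
    }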
@@ -54,8 +54,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 uint32_t shape_m,
                 const __grid_constant__ CUtensorMap tensor_map_a,
                 const __grid_constant__ CUtensorMap tensor_map_b,
-                const __grid_constant__ CUtensorMap tensor_map_scales_a,
-                const __grid_constant__ std::pair<CUtensorMap, CUtensorMap> tensor_map_d) {
+                const __grid_constant__ CUtensorMap tensor_map_scales_a) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -87,10 +86,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
-        if constexpr (SHAPE_N >= BLOCK_N)
-            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.first));
-        if constexpr (SHAPE_N % BLOCK_N != 0)
-            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.second));
     }
     __syncwarp();

@@ -357,27 +352,33 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             auto m_offset = local_idx * WAVE_BLOCK_M;
             auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
             #pragma unroll
-            for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
-                // Store into shared memory
-                #pragma unroll
-                for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                    auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
-                                                         shifted_accum[i * 4 + partition_idx * 2 + 1]});
-                    auto smem_ptr = smem_d + i * 8;
-                    smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                    SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
+            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
+                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
+                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
+                );
+            }
+            if constexpr (WGMMA::kNumAccum % 8 != 0) {
+                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
+                );
             }

             // Issue TMA store
             cute::tma_store_fence();
-            if (lane_idx < 8) {
-                auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+            if (lane_idx < 16) {
+                uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
+                auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
                 cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
             }
             __syncwarp();
         }
-        }

         // Wait TMA to be finished
         cute::tma_store_arrive();
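Note on the epilogue change above: each warp's first 16 lanes now own one output row apiece (the old path used 8 lanes per 8-row partition), the stmatrix stores are widened from SM90_U32x1_STSM_N to SM90_U32x4_STSM_N with an x2 tail, and the global m offset comes from scheduler.get_global_idx so grouped layouts address the right rows. A condensed restatement of the per-lane store, using the kernel's names plus an illustrative local named row:

    // Condensed sketch of the added per-lane 1D store (not standalone code).
    if (lane_idx < 16) {
        const uint32_t row = m_offset + warp_idx * 16 + lane_idx;              // one output row per lane
        auto smem_ptr = smem_d + row * (BLOCK_N + BLOCK_N_PADDING);            // padded row pitch in shared memory
        auto gmem_ptr = gmem_d + (gmem_m_offset + row) * SHAPE_N + n_block_idx * BLOCK_N;
        // BLOCK_N bf16 values = 2 * BLOCK_N bytes per copy; cp.async.bulk sizes must be 16-byte multiples
        cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
    }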
@@ -408,7 +409,6 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
-                    const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
                     cudaStream_t stream,
                     int num_sms, uint32_t smem_size) {
         // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
@@ -441,7 +441,7 @@ public:
         auto status = cudaLaunchKernelEx(&config, kernel,
                                          gmem_d, scales_b, grouped_layout,
                                          shape_m,
-                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
+                                         tma_a_desc, tma_b_desc, tma_scales_a_desc);
         DG_HOST_ASSERT(status == cudaSuccess);
     }

@@ -457,29 +457,6 @@ public:
                                      SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
     }

-    template <typename T>
-    static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
-        // NOTES: must be row-major
-        auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
-        uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
-        uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
-
-        // `SHAPE_N % BLOCK_N` maybe not zero, dividing them into two parts
-        CUtensorMap aligned;
-        if constexpr (SHAPE_N >= BLOCK_N) {
-            uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
-            aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
-        }
-
-        CUtensorMap unaligned;
-        if constexpr (SHAPE_N % BLOCK_N != 0) {
-            uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
-            unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
-                                              gmem_dim, gmem_strides, smem_dim);
-        }
-        return {aligned, unaligned};
-    }
-
     template <typename T>
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
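The helper removed above split D's N extent into an "aligned" descriptor covering the SHAPE_N / BLOCK_N full tiles and an "unaligned" one for the SHAPE_N % BLOCK_N remainder (its own comment notes SHAPE_N may not divide evenly); with the tensor-map store gone, that bookkeeping goes away as well. For reference, the arithmetic of the split, with hypothetical values:

    // Illustration of the aligned/unaligned split make_3d_tma_d_desc used to encode.
    // SHAPE_N and BLOCK_N below are hypothetical example values, not taken from the repo.
    #include <cstdint>
    constexpr uint32_t SHAPE_N = 4160;
    constexpr uint32_t BLOCK_N = 128;
    constexpr uint32_t full_tiles     = SHAPE_N / BLOCK_N;   // 32 full BLOCK_N-wide tiles -> "aligned" descriptor
    constexpr uint32_t remainder_cols = SHAPE_N % BLOCK_N;   // 64 trailing columns -> "unaligned" descriptor
    static_assert(full_tiles * BLOCK_N + remainder_cols == SHAPE_N, "the two parts cover the whole N extent");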
@@ -80,27 +80,6 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }

-template <typename T>
-CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
-                                  uint64_t gmem_strides[2], uint32_t smem_dim[3],
-                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
-    CUtensorMap tensor_map = {};
-    uint32_t elem_strides[3] = {1, 1, 1};
-
-    if (encode_func == nullptr)
-        encode_func = get_cuTensorMapEncodeTiled();
-
-    auto result = encode_func(
-        &tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 3,
-        global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
-        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
-        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
-        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
-        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
-    DG_HOST_ASSERT(result == CUDA_SUCCESS);
-    return tensor_map;
-}
-
 template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void
 tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
@@ -28,10 +28,9 @@ using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, 1, kNumSta
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
 gemm_t::run(out, rhs_scales, nullptr,
             m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc,
             stream, num_sms, smem_size);
 """

@@ -28,10 +28,9 @@ using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
 gemm_t::run(out, rhs_scales, grouped_layout,
             m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc,
             stream, num_sms, smem_size);
 """

@@ -20,10 +20,9 @@ int main() {
     auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m);
     auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
     auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
-    auto tma_d_desc = gemm_t::make_3d_tma_d_desc(reinterpret_cast<nv_bfloat16*>(0), m);
     gemm_t::run(nullptr, nullptr, nullptr,
                 m,
-                tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+                tma_a_desc, tma_b_desc, tma_scales_a_desc,
                 nullptr, 132, 0);
     return 0;
 }