Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

Commit: Add TMA D descriptor
parent 6078b25424 · commit 93c92c2c89

This commit threads a new `kSwizzleDMode` template parameter and a TMA descriptor for the output matrix D (`tensor_map_d` on the device side, `tma_d_desc` on the host side) through the FP8 GEMM kernel, its host launcher, the two JIT code-generation templates, and the indexing test.
@@ -45,6 +45,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t BLOCK_N_PADDING,
+          uint32_t kSwizzleDMode,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
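The new `kSwizzleDMode` parameter is the TMA swizzle atom size in bytes for the output D; the descriptor builder later in this commit handles the values 0 (no swizzling, i.e. padding mode), 32, 64, and 128. A minimal illustrative guard, not part of the commit, that a wrapper could add:

    // Illustrative only: the values this commit supports for kSwizzleDMode.
    static_assert(kSwizzleDMode == 0 or kSwizzleDMode == 32 or
                  kSwizzleDMode == 64 or kSwizzleDMode == 128,
                  "kSwizzleDMode must be 0 (padding mode) or a 32/64/128-byte swizzle atom");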
@@ -54,7 +55,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 uint32_t shape_m,
                 const __grid_constant__ CUtensorMap tensor_map_a,
                 const __grid_constant__ CUtensorMap tensor_map_b,
-                const __grid_constant__ CUtensorMap tensor_map_scales_a) {
+                const __grid_constant__ CUtensorMap tensor_map_scales_a,
+                const __grid_constant__ CUtensorMap tensor_map_d) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
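For background (a general CUDA requirement, not something introduced by this commit): a `CUtensorMap` consumed by device-side TMA instructions must arrive as a const, by-value kernel parameter annotated `__grid_constant__`, which is why `tensor_map_d` follows the exact pattern of the three existing descriptors. A hypothetical minimal kernel showing the required parameter form:

    #include <cuda.h>  // CUtensorMap

    __global__ void demo_kernel(const __grid_constant__ CUtensorMap tensor_map_d) {
        // TMA instructions reference the descriptor by address, e.g. &tensor_map_d.
    }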
@@ -86,6 +88,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
+
+        // `tensor_map_d` is only used in swizzling mode
+        // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
+        if constexpr (kSwizzleDMode > 0)
+            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
     }
     __syncwarp();
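The closing brace and `__syncwarp()` in this hunk suggest the prefetches run under a single-thread guard, with the warp reconverging before the descriptors are consumed. A sketch of that pattern; the guard condition is an assumption about code outside this hunk:

    if (threadIdx.x == 0) {  // assumed guard; the real predicate is outside this hunk
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
        // ... likewise for b, scales_a, and (when kSwizzleDMode > 0) d ...
    }
    __syncwarp();  // reconverge the warp before any lane issues TMA operations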
@@ -411,6 +418,7 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
+                    const CUtensorMap& tma_d_desc,
                     cudaStream_t stream,
                     int num_sms, uint32_t smem_size) {
         // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
@@ -419,6 +427,7 @@ public:
         auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
                                       BLOCK_M, BLOCK_N, BLOCK_K,
                                       BLOCK_N_PADDING,
+                                      kSwizzleDMode,
                                       kNumGroups, kNumStages,
                                       kNumTMAThreads, kNumMathThreadsPerGroup,
                                       kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
@@ -443,7 +452,7 @@ public:
         auto status = cudaLaunchKernelEx(&config, kernel,
                                          gmem_d, scales_b, grouped_layout,
                                          shape_m,
-                                         tma_a_desc, tma_b_desc, tma_scales_a_desc);
+                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
         DG_HOST_ASSERT(status == cudaSuccess);
     }
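`cudaLaunchKernelEx` type-checks its variadic tail against the kernel's signature, so adding `tensor_map_d` to the kernel forces this one-line change at the launch site. A minimal sketch of the launch-config plumbing the call assumes; the field values are illustrative, and the real `config` is built earlier in `run`:

    cudaLaunchConfig_t config{};
    config.gridDim = dim3(num_sms);       // one thread block per SM
    config.blockDim = dim3(num_threads);  // illustrative: TMA warps + math warps
    config.dynamicSmemBytes = smem_size;
    config.stream = stream;
    // The argument list after `kernel` must now match its signature,
    // including the new `tma_d_desc` descriptor.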
@@ -459,6 +468,21 @@ public:
                                 SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
     }
 
+    template <typename T>
+    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
+        auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
+        if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
+        if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
+        if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
+
+        // Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode` bytes,
+        // so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
+        return make_2d_tma_desc(global_address, Layout::RowMajor,
+                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
+                                min(BLOCK_M, shape_m), kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
+                                swizzle_mode);
+    }
+
     template <typename T>
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
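To make the comment in `make_2d_tma_d_desc` concrete, a worked example with illustrative numbers (D is `__nv_bfloat16`, 2 bytes per element):

    #include <cuda_bf16.h>

    // With kSwizzleDMode = 128 and BLOCK_N = 128:
    constexpr unsigned kSwizzle   = 128;                                     // bytes per swizzle atom
    constexpr unsigned inner_box  = kSwizzle / sizeof(__nv_bfloat16);        // 128 B / 2 B = 64 elements
    constexpr unsigned num_stores = 128 * sizeof(__nv_bfloat16) / kSwizzle;  // 256 B / 128 B = 2 TMA stores
    // With kSwizzleDMode = 0 the inner box is simply BLOCK_N elements (padding mode).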
@@ -30,9 +30,10 @@ using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDM
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
 gemm_t::run(out, rhs_scales, nullptr,
             m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
             stream, num_sms, smem_size);
 """
@@ -29,9 +29,10 @@ using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDM
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
 gemm_t::run(out, rhs_scales, grouped_layout,
             m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
             stream, num_sms, smem_size);
 """
@@ -11,18 +11,20 @@ int main() {
     constexpr int BLOCK_N = 128;
     constexpr int BLOCK_K = 128;
     constexpr int BLOCK_N_PADDING = 0;
+    constexpr int kSwizzleDMode = 0;
     constexpr int kNumGroups = 1;
     constexpr int kNumStages = 5;
     constexpr int kNumTMAMulticast = 1;
     constexpr bool kIsTMAMulticastOnA = false;
 
-    using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
+    using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
     auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m);
     auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
+    auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast<nv_bfloat16*>(0), m);
     auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
     gemm_t::run(nullptr, nullptr, nullptr,
                 m,
-                tma_a_desc, tma_b_desc, tma_scales_a_desc,
+                tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
                 nullptr, 132, 0);
     return 0;
 }
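This test keeps `kSwizzleDMode = 0`, which together with `BLOCK_N_PADDING = 0` is the case the kernel comment describes as padding mode. An illustrative variant that would instead exercise the swizzled-D path, reusing the test's null-pointer trick for building descriptors without real allocations:

    // Illustrative: swizzled-D instantiation (128-byte atoms) of the same test.
    constexpr int kSwizzleDMode128 = 128;
    using gemm_swizzled_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING,
                                 kSwizzleDMode128, kNumGroups, kNumStages,
                                 kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
    auto tma_d_desc_sw = gemm_swizzled_t::make_2d_tma_d_desc(reinterpret_cast<nv_bfloat16*>(0), m);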