From 93c92c2c89a22c012129efc85cb5e6b5c3314bc7 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 11 Apr 2025 14:10:11 +0800
Subject: [PATCH] Add TMA D descriptor

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh | 28 ++++++++++++++++++++++--
 deep_gemm/jit_kernels/gemm.py            |  3 ++-
 deep_gemm/jit_kernels/m_grouped_gemm.py  |  3 ++-
 indexing/main.cu                         |  6 +++--
 4 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index f879b95..ad181d7 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -45,6 +45,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
 template <...>
 fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 ...,
+                const __grid_constant__ CUtensorMap tensor_map_d) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -86,6 +88,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
+
+        // `tensor_map_d` is only used in swizzling mode
+        // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
+        if constexpr (kSwizzleDMode > 0)
+            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
     }
     __syncwarp();
@@ -411,6 +418,7 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
+                    const CUtensorMap& tma_d_desc,
                     cudaStream_t stream, int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
@@ -419,6 +427,7 @@ public:
         auto kernel = fp8_gemm_kernel<...>;
@@ -443,7 +452,7 @@ public:
         auto status = cudaLaunchKernelEx(&config, kernel,
                                          gmem_d, scales_b, grouped_layout, shape_m,
-                                         tma_a_desc, tma_b_desc, tma_scales_a_desc);
+                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
         DG_HOST_ASSERT(status == cudaSuccess);
     }
@@ -459,6 +468,21 @@ public:
                                 SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1),
                                 BLOCK_K, BLOCK_N);
     }
+    template <typename T>
+    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
+        auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
+        if constexpr (kSwizzleDMode == 32)  swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
+        if constexpr (kSwizzleDMode == 64)  swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
+        if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
+
+        // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
+        // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
+        return make_2d_tma_desc(global_address, Layout::RowMajor,
+                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
+                                min(BLOCK_M, shape_m), kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
+                                swizzle_mode);
+    }
+
     template <typename T>
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index ab64abd..eab5442 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -30,9 +30,10 @@
-using gemm_t = Gemm<...>;
+using gemm_t = Gemm<...>;
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
+auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast<__nv_bfloat16*>(0), m);
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
 gemm_t::run(nullptr, nullptr, nullptr, m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
             nullptr, 132, 0);
 return 0;
}
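
For readers unfamiliar with TMA descriptor encoding, the sketch below shows roughly what a descriptor such as `tma_d_desc` boils down to when built directly with the CUDA driver API instead of DeepGEMM's `make_2d_tma_desc` helper. It is a minimal illustration, not the repository's implementation: the function name `make_d_tensor_map` and the parameters `rows`, `cols`, `block_m`, `block_n` and `swizzle_bytes` are assumptions made for this example; only the swizzle-mode selection and the inner-box-width constraint mirror `make_2d_tma_d_desc` above. It assumes a CUDA 12+ toolkit and an initialized driver context.

#include <cuda.h>        // Driver API: cuTensorMapEncodeTiled (CUDA 12+)
#include <cuda_bf16.h>
#include <cassert>

// Illustrative only: encode a row-major 2D TMA descriptor for a BF16 output D
static CUtensorMap make_d_tensor_map(__nv_bfloat16* d, cuuint64_t rows, cuuint64_t cols,
                                     cuuint32_t block_m, cuuint32_t block_n,
                                     cuuint32_t swizzle_bytes) {
    // Pick the swizzle mode from the requested byte width, as the patch does
    auto swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
    if (swizzle_bytes == 32)  swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
    if (swizzle_bytes == 64)  swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
    if (swizzle_bytes == 128) swizzle = CU_TENSOR_MAP_SWIZZLE_128B;

    // TMA dimensions are innermost-first: dim 0 is the contiguous N axis
    cuuint64_t global_dim[2] = {cols, rows};
    // One byte stride per non-innermost dimension (the row pitch here)
    cuuint64_t global_stride[1] = {cols * sizeof(__nv_bfloat16)};

    // With swizzling enabled the inner box may span at most `swizzle_bytes` bytes,
    // so a BLOCK_N-wide tile takes BLOCK_N * sizeof(T) / swizzle_bytes TMA stores;
    // with swizzling disabled (padding mode) the box covers the whole BLOCK_N
    cuuint32_t inner_box = (swizzle_bytes == 0)
        ? block_n
        : swizzle_bytes / static_cast<cuuint32_t>(sizeof(__nv_bfloat16));
    cuuint32_t box_dim[2] = {inner_box, block_m};
    cuuint32_t elem_stride[2] = {1, 1};

    CUtensorMap tensor_map{};
    auto result = cuTensorMapEncodeTiled(
        &tensor_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        /*tensorRank=*/2, d, global_dim, global_stride, box_dim, elem_stride,
        CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    assert(result == CUDA_SUCCESS);
    return tensor_map;
}

With 128B swizzling and a BF16 output, for instance, the inner box is 64 elements, so a 128-wide BLOCK_N tile is written back with two TMA stores per row block, which is exactly the trade-off the comment in `make_2d_tma_d_desc` describes.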