From a77009cb149b132eebe3d8f0a11032c2ad2db8fb Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Thu, 10 Apr 2025 18:07:25 +0800
Subject: [PATCH 1/4] Make partition pipelined

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 56 ++++++++++-------------
 deep_gemm/include/deep_gemm/mma_utils.cuh | 10 ++++
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 9ba930f..de6ba38 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -350,48 +350,38 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             }
         }, num_former_iters);
 
-        // Write back to shared memory using STSM
+        // Write back to shared memory using STSM and issue TMA stores
         DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
         #pragma unroll
         for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
             auto m_offset = local_idx * WAVE_BLOCK_M;
             auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
             #pragma unroll
-            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
-                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
-                    __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
-                    __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
-                    __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
-                    __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
-                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
-                );
-            }
-            if constexpr (WGMMA::kNumAccum % 8 != 0) {
-                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
-                );
-            }
-        }
-        cute::tma_store_fence();
-        cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+            for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
+                // Store into shared memory
+                #pragma unroll
+                for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                    auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
+                                                         shifted_accum[i * 4 + partition_idx * 2 + 1]});
+                    auto smem_ptr = smem_d + i * 8;
+                    smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                    SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
+                }
 
-        // Use TMA store to write back to global memory
-        if (threadIdx.x == 0) {
-            if (n_block_idx < SHAPE_N / BLOCK_N) {
-                // Except the last unaligned block
-                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
-                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-            } else {
-                // The last unaligned block
-                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
-                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+                // Issue TMA store
+                cute::tma_store_fence();
+                if (lane_idx < 8) {
+                    auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                    auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+                    cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+                }
+                __syncwarp();
             }
-            cute::tma_store_arrive();
-            cute::tma_store_wait<0>();
         }
-        __syncwarp();
+
+        // Wait TMA to be finished
+        cute::tma_store_arrive();
+        cute::tma_store_wait<0>();
     }
 }
 #else
diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index 0cc554a..bb6284e 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -867,6 +867,16 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
+template <typename dtype_t>
+struct SM90_U32x1_STSM_N {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, void* smem_dst) {
+        const uint32_t src[1] = {*reinterpret_cast<uint32_t*>(&src_0)};
+        asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
+                     :: "l"(smem_dst), "r"(src[0]));
+    }
+};
+
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
     __device__ __forceinline__ static void
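Note: in the write-back introduced above, each math warp owns 16 rows of the output tile and handles them as two 8-row partitions; lanes 0-7 each hand one BLOCK_N-wide row to the bulk-copy engine, so the store of one partition can overlap the STSM of the next. A minimal host-side sketch of that lane-to-row mapping (BLOCK_M = WAVE_BLOCK_M = 64 and a single warpgroup of 4 math warps are assumed example values, not taken from the patch):

    #include <cassert>
    #include <vector>

    int main() {
        constexpr int BLOCK_M = 64, WAVE_BLOCK_M = 64;  // assumed example tile sizes
        std::vector<int> hits(BLOCK_M, 0);
        for (int local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx)
            for (int warp_idx = 0; warp_idx < 4; ++ warp_idx)              // 4 math warps, 16 rows each
                for (int partition_idx = 0; partition_idx < 2; ++ partition_idx)
                    for (int lane_idx = 0; lane_idx < 8; ++ lane_idx)      // only lanes 0-7 issue the row copy
                        ++ hits[local_idx * WAVE_BLOCK_M + warp_idx * 16 + partition_idx * 8 + lane_idx];
        for (int row = 0; row < BLOCK_M; ++ row)
            assert(hits[row] == 1);                                        // every row of the tile is stored exactly once
        return 0;
    }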
From 8041ed71642e58e8a36d63d3a522e6e7f218eeeb Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 11 Apr 2025 10:42:01 +0800
Subject: [PATCH 2/4] Use 1D TMA store

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 79 ++++++++---------------
 deep_gemm/include/deep_gemm/tma_utils.cuh | 21 ------
 deep_gemm/jit_kernels/gemm.py             |  3 +-
 deep_gemm/jit_kernels/m_grouped_gemm.py   |  3 +-
 indexing/main.cu                          |  3 +-
 5 files changed, 31 insertions(+), 78 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index de6ba38..89e9bef 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -54,8 +54,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 uint32_t shape_m,
                 const __grid_constant__ CUtensorMap tensor_map_a,
                 const __grid_constant__ CUtensorMap tensor_map_b,
-                const __grid_constant__ CUtensorMap tensor_map_scales_a,
-                const __grid_constant__ std::pair<CUtensorMap, CUtensorMap> tensor_map_d) {
+                const __grid_constant__ CUtensorMap tensor_map_scales_a) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -87,10 +86,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
-        if constexpr (SHAPE_N >= BLOCK_N)
-            cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d.first));
-        if constexpr (SHAPE_N % BLOCK_N != 0)
-            cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d.second));
     }
     __syncwarp();
@@ -354,29 +349,35 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
         #pragma unroll
         for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
-            auto m_offset = local_idx * WAVE_BLOCK_M; 
+            auto m_offset = local_idx * WAVE_BLOCK_M;
             auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
             #pragma unroll
-            for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
-                // Store into shared memory
-                #pragma unroll
-                for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                    auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
-                                                         shifted_accum[i * 4 + partition_idx * 2 + 1]});
-                    auto smem_ptr = smem_d + i * 8;
-                    smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                    SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
-                }
-
-                // Issue TMA store
-                cute::tma_store_fence();
-                if (lane_idx < 8) {
-                    auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                    auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
-                    cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
-                }
-                __syncwarp();
+            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
+                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
+                    __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
+                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
+                );
             }
+            if constexpr (WGMMA::kNumAccum % 8 != 0) {
+                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
+                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                    __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                    smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
+                );
+            }
+
+            // Issue TMA store
+            cute::tma_store_fence();
+            if (lane_idx < 16) {
+                uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
+                auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+                cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+            }
+            __syncwarp();
         }
 
         // Wait TMA to be finished
@@ -408,7 +409,6 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
-                    const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
                     cudaStream_t stream,
                     int num_sms, uint32_t smem_size) {
         // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
@@ -441,7 +441,7 @@ public:
         auto status = cudaLaunchKernelEx(&config, kernel,
                                          gmem_d, scales_b, grouped_layout, shape_m,
-                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
+                                         tma_a_desc, tma_b_desc, tma_scales_a_desc);
         DG_HOST_ASSERT(status == cudaSuccess);
     }
 
@@ -457,29 +457,6 @@ public:
                                      SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
     }
 
-    template <typename T>
-    static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
-        // NOTES: must be row-major
-        auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
-        uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
-        uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
-
-        // `SHAPE_N % BLOCK_N` maybe not zero, dividing them into two parts
-        CUtensorMap aligned;
-        if constexpr (SHAPE_N >= BLOCK_N) {
-            uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
-            aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
-        }
-
-        CUtensorMap unaligned;
-        if constexpr (SHAPE_N % BLOCK_N != 0) {
-            uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
-            unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
-                                              gmem_dim, gmem_strides, smem_dim);
-        }
-        return {aligned, unaligned};
-    }
-
     template <typename T>
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh
index e7bc241..22731a6 100644
--- a/deep_gemm/include/deep_gemm/tma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/tma_utils.cuh
@@ -80,27 +80,6 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }
 
-template <typename T>
-CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
-                                  uint64_t gmem_strides[2], uint32_t smem_dim[3],
-                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
-    CUtensorMap tensor_map = {};
-    uint32_t elem_strides[3] = {1, 1, 1};
-
-    if (encode_func == nullptr)
-        encode_func = get_cuTensorMapEncodeTiled();
-
-    auto result = encode_func(
-            &tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), 3,
-            global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
-            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
-            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
-            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
-            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
-    DG_HOST_ASSERT(result == CUDA_SUCCESS);
-    return tensor_map;
-}
-
 template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void
 tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index a4407e7..4b5db70 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -28,10 +28,9 @@
 using gemm_t = Gemm(0), m);
 auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
 auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
-auto tma_d_desc = gemm_t::make_3d_tma_d_desc(reinterpret_cast(0), m);
 gemm_t::run(nullptr, nullptr, nullptr, m,
-            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc,
             nullptr, 132, 0);
 return 0;
 }
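Note: the 1D store path above goes through cute::SM90_BULK_COPY_S2G plus cute::tma_store_arrive() / cute::tma_store_wait<0>(). A rough sketch of equivalent hand-written inline PTX follows; bulk_copy_s2g, bulk_store_commit and bulk_store_wait are illustrative names, not DeepGEMM or CUTLASS APIs, and both addresses and the byte count must be multiples of 16 for the bulk copy:

    #include <cstdint>

    __device__ __forceinline__ void bulk_copy_s2g(const void* smem_src, void* gmem_dst, int32_t bytes) {
        // Hand one contiguous shared-memory row to the bulk-copy engine
        const auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
        asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
                     :: "l"(gmem_dst), "r"(smem_addr), "r"(bytes) : "memory");
    }

    __device__ __forceinline__ void bulk_store_commit() {
        // Batch all bulk stores issued so far into one group (analogous to cute::tma_store_arrive())
        asm volatile("cp.async.bulk.commit_group;\n" ::: "memory");
    }

    template <int kRemaining>
    __device__ __forceinline__ void bulk_store_wait() {
        // Block until at most kRemaining committed groups are still reading shared memory
        asm volatile("cp.async.bulk.wait_group.read %0;\n" :: "n"(kRemaining) : "memory");
    }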
From 99eb6ec563807cb8879633281e01017897a1c372 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 11 Apr 2025 10:45:36 +0800
Subject: [PATCH 3/4] Remove useless STSM

---
 deep_gemm/include/deep_gemm/mma_utils.cuh | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index bb6284e..0cc554a 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -867,16 +867,6 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
-template <typename dtype_t>
-struct SM90_U32x1_STSM_N {
-    __device__ __forceinline__ static void
-    copy(dtype_t src_0, void* smem_dst) {
-        const uint32_t src[1] = {*reinterpret_cast<uint32_t*>(&src_0)};
-        asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
-                     :: "l"(smem_dst), "r"(src[0]));
-    }
-};
-
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
     __device__ __forceinline__ static void
From b0d64817a77d93fe07ec22aee9cd877b3ce538cc Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 11 Apr 2025 11:00:47 +0800
Subject: [PATCH 4/4] OOB bugs fixed

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 89e9bef..a43ee2c 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -375,7 +375,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             if (lane_idx < 16) {
                 uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
                 auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
                 auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
-                cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+                auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
+                cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
             }
             __syncwarp();
         }
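Note: the fix above clamps the width of the last N-block so the per-row bulk store never runs past SHAPE_N columns. A small host-side check of that tail arithmetic (SHAPE_N = 2112 and BLOCK_N = 128 are assumed example sizes, not values from the patch):

    #include <cassert>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        constexpr int SHAPE_N = 2112, BLOCK_N = 128;              // assumed example sizes
        constexpr int num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);  // 17 blocks, the last one partial
        int total_cols = 0;
        for (int n_block_idx = 0; n_block_idx < num_n_blocks; ++ n_block_idx) {
            const int num_valid_cols = (n_block_idx == num_n_blocks - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
            total_cols += num_valid_cols;                         // the tail block copies only 64 columns
        }
        assert(total_cols == SHAPE_N);                            // all columns covered, none out of bounds
        return 0;
    }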