diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 9ba930f..de6ba38 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -350,48 +350,38 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 }
             }, num_former_iters);
 
-            // Write back to shared memory using STSM
+            // Write back to shared memory using STSM and issue TMA stores
             DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
             #pragma unroll
             for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                 auto m_offset = local_idx * WAVE_BLOCK_M;
                 auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                 #pragma unroll
-                for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
-                    SM90_U32x4_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
-                    );
-                }
-                if constexpr (WGMMA::kNumAccum % 8 != 0) {
-                    SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
-                    );
-                }
-            }
-            cute::tma_store_fence();
-            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+                for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
+                    // Store into shared memory
+                    #pragma unroll
+                    for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
+                                                             shifted_accum[i * 4 + partition_idx * 2 + 1]});
+                        auto smem_ptr = smem_d + i * 8;
+                        smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
+                    }
 
-            // Use TMA store to write back to global memory
-            if (threadIdx.x == 0) {
-                if (n_block_idx < SHAPE_N / BLOCK_N) {
-                    // Except the last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                } else {
-                    // The last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+                    // Issue TMA store
+                    cute::tma_store_fence();
+                    if (lane_idx < 8) {
+                        auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+                        cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+                    }
+                    __syncwarp();
                 }
-                cute::tma_store_arrive();
-                cute::tma_store_wait<0>();
             }
-            __syncwarp();
+
+            // Wait TMA to be finished
+            cute::tma_store_arrive();
+            cute::tma_store_wait<0>();
         }
     }
 #else
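For readers who want to exercise the new write-back path outside the kernel, here is a minimal single-warp sketch. It is not DeepGEMM code: the tile sizes, buffer names and file name are invented for the demo, and the raw PTX is only an approximation of what the cute::tma_store_fence(), cute::SM90_BULK_COPY_S2G::copy(), cute::tma_store_arrive() and cute::tma_store_wait<0>() calls in the hunk above expand to. It needs an sm_90a build.

// bulk_store_sketch.cu -- illustrative sketch only, not part of the patch.
// One warp fills a shared-memory tile, fences, lets lanes 0-7 each issue a
// 1D bulk copy of "their" row to global memory, then commits and drains the
// bulk group. Build with: nvcc -arch=sm_90a bulk_store_sketch.cu
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

constexpr int kRows = 8;    // one row per lane for lane_idx < 8, as in the patch
constexpr int kCols = 64;   // stand-in for BLOCK_N; keeps every row 16-byte aligned

__global__ void bulk_store_demo(nv_bfloat16* gmem_out) {
    __shared__ alignas(128) nv_bfloat16 smem[kRows * kCols];
    const uint32_t lane_idx = threadIdx.x % 32;

    // Fill shared memory with recognizable values (the real kernel does this with STSM).
    for (uint32_t i = lane_idx; i < kRows * kCols; i += 32)
        smem[i] = __float2bfloat16(static_cast<float>(i));
    __syncwarp();

    // Make the generic-proxy writes visible to the async (bulk-copy) proxy.
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");

    // Lanes 0-7 each issue one shared->global bulk copy of a whole row.
    if (lane_idx < kRows) {
        const uint32_t src = static_cast<uint32_t>(__cvta_generic_to_shared(smem + lane_idx * kCols));
        const uint32_t bytes = kCols * sizeof(nv_bfloat16);
        nv_bfloat16* dst = gmem_out + lane_idx * kCols;
        asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
                     :: "l"(dst), "r"(src), "r"(bytes) : "memory");
    }

    // Commit the pending bulk stores and wait until they have finished reading smem.
    asm volatile("cp.async.bulk.commit_group;\n" ::: "memory");
    asm volatile("cp.async.bulk.wait_group.read 0;\n" ::: "memory");
}

int main() {
    nv_bfloat16* out;
    cudaMalloc(&out, kRows * kCols * sizeof(nv_bfloat16));
    bulk_store_demo<<<1, 32>>>(out);
    cudaDeviceSynchronize();

    nv_bfloat16 host[kRows * kCols];
    cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
    for (int r = 0; r < kRows; ++ r)   // each row should start with the value r * kCols
        printf("row %d starts with %.0f\n", r, __bfloat162float(host[r * kCols]));
    cudaFree(out);
    return 0;
}

The point to notice is the addressing: each of lanes 0-7 owns one output row and issues one bulk store of BLOCK_N bf16 values, which is why the hunk above guards the copy with lane_idx < 8.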
diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index 0cc554a..bb6284e 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -867,6 +867,16 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
+template <typename dtype_t>
+struct SM90_U32x1_STSM_N {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, void* smem_dst) {
+        const uint32_t src[1] = {*reinterpret_cast<uint32_t*>(&src_0)};
+        asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
+                     :: "l"(smem_dst), "r"(src[0]));
+    }
+};
+
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
     __device__ __forceinline__ static void
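As a companion sketch, the snippet below drives the same stmatrix.sync.aligned.x1.m8n8.shared.b16 instruction that the new SM90_U32x1_STSM_N wrapper emits, on a single warp and a single 8x8 bf16 tile, so the lane-to-fragment mapping can be inspected. Everything around the inline PTX (tile size, names, the plain read-back loop) is an assumption made for the demo rather than DeepGEMM code; build with -arch=sm_90a.

// stsm_x1_sketch.cu -- illustrative only; mirrors the new SM90_U32x1_STSM_N wrapper.
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void stsm_x1_demo(float* gmem_out) {
    // One 8x8 tile of bf16, 16 bytes per row (the granularity stmatrix works on).
    __shared__ alignas(128) nv_bfloat16 tile[8 * 8];
    const uint32_t lane_idx = threadIdx.x % 32;

    // Each lane packs two consecutive values into one 32-bit register, the same
    // __float22bfloat162_rn packing the patch applies to the WGMMA accumulators.
    nv_bfloat162 frag = __float22bfloat162_rn({2.f * lane_idx, 2.f * lane_idx + 1.f});
    const uint32_t src = *reinterpret_cast<uint32_t*>(&frag);

    // In the .x1 form only lanes 0-7 provide row addresses; the rest are ignored,
    // so clamping with % 8 keeps the demo in bounds (the patch computes real row offsets).
    const uint32_t dst = static_cast<uint32_t>(
        __cvta_generic_to_shared(tile + (lane_idx % 8) * 8));
    asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
                 :: "r"(dst), "r"(src));
    __syncwarp();

    // Dump the tile with ordinary stores so the resulting layout can be printed.
    for (uint32_t i = lane_idx; i < 8 * 8; i += 32)
        gmem_out[i] = __bfloat162float(tile[i]);
}

int main() {
    float* out;
    cudaMalloc(&out, 64 * sizeof(float));
    stsm_x1_demo<<<1, 32>>>(out);
    cudaDeviceSynchronize();

    float host[64];
    cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
    for (int r = 0; r < 8; ++ r) {
        for (int c = 0; c < 8; ++ c)
            printf("%5.1f ", host[r * 8 + c]);
        printf("\n");
    }
    cudaFree(out);
    return 0;
}

Because only lanes 0-7 contribute row addresses in the x1 form, the kernel hunk can let every lane compute a smem_ptr without masking: the addresses produced by lanes 8-31 are simply ignored by the instruction.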