From cf640558afbf6c0d506f652388b923807350b3de Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Thu, 13 Mar 2025 21:02:52 +0800
Subject: [PATCH] Update fp8_gemm.cuh

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh | 141 ++++++++++++++++++-----
 1 file changed, 112 insertions(+), 29 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index ee6e4a4..533140a 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -146,6 +146,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
+    uint32_t prev_global_idx, n_block_idx_prev;
     auto scheduler = Scheduler(shape_m, grouped_layout);
 
     if (threadIdx.x >= kNumMathThreads) {
@@ -207,6 +208,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
         const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
 
+        // NOTE MODIFIED
+        float final_accum[WGMMA::kNumAccum] = {0};
+
         // Persistently schedule over blocks
         while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
             // Decide the number of scales B to load
@@ -230,7 +234,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
             // Accumulation for WGMMA or CUDA promotion
-            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
+//            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
+            float accum[WGMMA::kNumAccum];
+
+//            if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                printf("hi 0 final_accum[0]=%f\n", final_accum[0]);
+//            }
 
             // Empty barrier arrival
             auto empty_barrier_arrive = [&](int s) {
@@ -249,6 +258,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
                 #pragma unroll
                 for (int s = 0; s < kNumInnerStages; ++ s) {
+//                    if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                        printf("hi 1 final_accum[0]=%f k_iter=%d s=%d\n", final_accum[0], k_iter, s);
+//                    }
+
                     // Read B scales
                     float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                     // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
@@ -274,6 +287,55 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         WGMMA::wgmma(desc_a, desc_b, accum, k);
                     }
                     warpgroup_commit_batch();
+
+                    // ------------------------------------------------------------------------------------
+                    if ((scheduler.current_iter > 0) and (k_iter == 0) and (s == 0)) {
+//                        if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                            printf("hi 2 final_accum[0]=%f k_iter=%d s=%d\n", final_accum[0], k_iter, s);
+//                        }
+
+                        // Write back to shared memory using STSM
+                        DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
+                        #pragma unroll
+                        for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
+                            SM90_U32x4_STSM_N::copy(
+                                __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
+                                __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
+                                __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
+                                __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
+                                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
+                            );
+                        }
+                        if constexpr (WGMMA::kNumAccum % 8 != 0) {
+                            SM90_U32x2_STSM_N::copy(
+                                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
+                            );
+                        }
+                        cute::tma_store_fence();
+                        cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+                        // Use TMA store to write back to global memory
+                        if (threadIdx.x == 0) {
+                            cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx_prev * BLOCK_N,
+                                                          prev_global_idx);
+                            cute::tma_store_arrive();
+                            cute::tma_store_wait<0>();
+                        }
+                        __syncwarp();
+
+                        #pragma unroll
+                        for (int i = 0; i < WGMMA::kNumAccum; ++ i) {
+                            final_accum[i] = 0;
+                        }
+                    }
+                    // ------------------------------------------------------------------------------------
+
+//                    if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                        printf("hi 3 final_accum[0]=%f k_iter=%d s=%d\n", final_accum[0], k_iter, s);
+//                    }
+
                     #pragma unroll
                     for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                         warpgroup_fence_operand(accum[i]);
@@ -296,6 +358,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
                         final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                     }
+
+//                    if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                        printf("hi 4 final_accum[0]=%f k_iter=%d s=%d\n", final_accum[0], k_iter, s);
+//                    }
                 }
 
                 // Wait unaligned cases
@@ -304,38 +370,55 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
+
+//                if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                    printf("hi 5 final_accum[0]=%f k_iter=%d\n", final_accum[0], k_iter);
+//                }
             });
 
-            // Write back to shared memory using STSM
-            DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
-            #pragma unroll
-            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
-                SM90_U32x4_STSM_N::copy(
-                    __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
-                    __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
-                    __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
-                    __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
-                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
-                );
-            }
-            if constexpr (WGMMA::kNumAccum % 8 != 0) {
-                SM90_U32x2_STSM_N::copy(
-                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
-                );
-            }
-            cute::tma_store_fence();
-            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+//            if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                printf("hi 6 final_accum[0]=%f\n", final_accum[0]);
+//            }
+//            if (threadIdx.x == 0 and blockIdx.x == 0) {
+//                printf("hi current_iter=%d num_blocks=%d\n", scheduler.current_iter, scheduler.num_blocks);
+//            }
+            if ((scheduler.current_iter+1) * gridDim.x + blockIdx.x < scheduler.num_blocks) {
+//                if (threadIdx.x == 0 and blockIdx.x == 0) { printf("hi branch a\n"); }
+                prev_global_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
+                n_block_idx_prev = n_block_idx;
+            } else {
+//                if (threadIdx.x == 0 and blockIdx.x == 0) { printf("hi branch b\n"); }
+                // Write back to shared memory using STSM
+                DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
+                #pragma unroll
+                for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
+                    SM90_U32x4_STSM_N::copy(
+                        __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
+                        __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
+                        __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
+                        __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
+                        smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
+                    );
+                }
+                if constexpr (WGMMA::kNumAccum % 8 != 0) {
+                    SM90_U32x2_STSM_N::copy(
+                        __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                        __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                        smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
+                    );
+                }
+                cute::tma_store_fence();
+                cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
-            // Use TMA store to write back to global memory
-            if (threadIdx.x == 0) {
-                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
-                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                cute::tma_store_arrive();
-                cute::tma_store_wait<0>();
+                // Use TMA store to write back to global memory
+                if (threadIdx.x == 0) {
+                    cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
+                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+                    cute::tma_store_arrive();
+                    cute::tma_store_wait<0>();
+                }
+                __syncwarp();
             }
-            __syncwarp();
         }
     }
 #else
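What the patch does, in brief: instead of always writing a tile's final_accum back (STSM into shared memory, then a TMA store) right after launch_k_iterations, the math warp-groups keep the result in registers, record prev_global_idx / n_block_idx_prev, and issue the previous tile's store during the first inner iteration (k_iter == 0, s == 0) of the next tile, so the write-back overlaps with fresh WGMMA work; only the last tile scheduled on a block is still stored in place. The sketch below is a minimal, self-contained CUDA illustration of that control flow only, not DeepGEMM code: the kernel name persistent_tile_sum, the sizes TILE/K, and the plain global-memory loads and stores are hypothetical stand-ins, and the real kernel's shared-memory pipeline, STSM/TMA stores, and named barriers are omitted.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int TILE = 256;  // one thread per output element of a tile (illustrative size)
constexpr int K    = 64;   // inner "k-iterations" per tile (illustrative size)

// Persistent kernel: block b handles tiles b, b + gridDim.x, b + 2*gridDim.x, ...
// Each tile's result stays in registers; its store is issued during the first
// inner iteration of the NEXT tile handled by the same block, so the write-back
// overlaps with fresh compute. The last tile per block is stored immediately.
__global__ void persistent_tile_sum(const float* __restrict__ in,
                                    float* __restrict__ out,
                                    int num_tiles) {
    float prev_acc = 0.0f;  // previous tile's accumulator, not yet written back
    int prev_tile  = -1;    // -1 means "no store pending"

    for (int tile = blockIdx.x; tile < num_tiles; tile += gridDim.x) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
            // Counterpart of the patch's (current_iter > 0 && k_iter == 0 && s == 0)
            // branch: flush the pending result while new work is being issued.
            if (k == 0 && prev_tile >= 0) {
                out[prev_tile * TILE + threadIdx.x] = prev_acc;
                prev_tile = -1;
            }
            acc += in[(tile * K + k) * TILE + threadIdx.x];
        }
        if (tile + gridDim.x < num_tiles) {
            // Another tile will run on this block: defer the store, like the
            // patch's "(current_iter+1) * gridDim.x + blockIdx.x < num_blocks" test.
            prev_acc  = acc;
            prev_tile = tile;
        } else {
            // Last tile for this block: nothing later will flush it, so store now.
            out[tile * TILE + threadIdx.x] = acc;
        }
    }
}

int main() {
    const int num_tiles = 37, num_blocks = 8;
    std::vector<float> h_in(size_t(num_tiles) * K * TILE, 1.0f);
    std::vector<float> h_out(size_t(num_tiles) * TILE, 0.0f);

    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, h_in.size() * sizeof(float));
    cudaMalloc(&d_out, h_out.size() * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

    persistent_tile_sum<<<num_blocks, TILE>>>(d_in, d_out, num_tiles);
    cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // Every output element is a sum of K ones, so it should equal K.
    bool ok = true;
    for (float v : h_out) ok = ok && (v == float(K));
    printf("%s\n", ok ? "OK" : "MISMATCH");

    cudaFree(d_in);
    cudaFree(d_out);
    return ok ? 0 : 1;
}

Note that "tile + gridDim.x < num_tiles" stands in for the patch's "(scheduler.current_iter+1) * gridDim.x + blockIdx.x < scheduler.num_blocks" check, which relies on the scheduler handing block b the work items b, b + gridDim.x, b + 2*gridDim.x, and so on.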