diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 15b22ca..978dd4c 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -3,6 +3,7 @@ import torch from . import jit from .jit_kernels import ( gemm_fp8_fp8_bf16_nt, + gemm_fp8_fp8_bf16_bw_nt, m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked, ceil_div, diff --git a/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh b/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh new file mode 100644 index 0000000..e6fa3d0 --- /dev/null +++ b/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh @@ -0,0 +1,442 @@ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" +#pragma once + +#include +#include + +#include +#include +#include + +#include "mma_utils.cuh" +#include "scheduler.cuh" +#include "tma_utils.cuh" +#include "utils.cuh" +#include "fp8_gemm.cuh" + +namespace deep_gemm { + + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) +fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, + uint32_t shape_m, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_scales_a, + const __grid_constant__ CUtensorMap tensor_map_scales_b, + const __grid_constant__ CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + // static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(BLOCK_N * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); + //static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + //static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b[kNumStages]; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); + smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) + i * SMEM_SCALES_B_SIZE_PER_STAGE); + } + // Fill barriers + // auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE_PER_STAGE); + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b[kNumStages-1]) + SMEM_SCALES_B_SIZE_PER_STAGE); + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + auto launch_k_iterations = [](const auto& func) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + // Only support normal gemm now. @kavioyu + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(0, 1, k_idx / BLOCK_K)); + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + // Only support normal gemm now. 
@kavioyu + tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), + smem_scales_b[s], n_block_idx * BLOCK_N, scheduler.get_global_idx(0, 1, k_idx / BLOCK_K)); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); + // full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const int laneid_div_4 = lane_idx / 4; + const auto r_0 = warp_idx * 16 + laneid_div_4, r_1 = r_0 + 8; + const unsigned int scale_b_idx[2] = {laneid_div_4 * 8 + lane_idx % 4 * 2, (8 + laneid_div_4) * 8 + lane_idx % 4 * 2}; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + // Read B scales + float2 scale_b[2]; + #pragma unroll + for (int i = 0; i < 2; ++i) { + scale_b[i].x = ld_shared(smem_scales_b[s] + scale_b_idx[i]); + scale_b[i].y = ld_shared(smem_scales_b[s] + scale_b_idx[i] + 1); + } + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // WGMMA::kNumAccum = 64, loop 16 steps + int src_lane_id = threadIdx.x % 4 + (i % 8) * 4; + float scale_b_0 = __shfl_sync(0xffffffff, scale_b[i/8].x, src_lane_id); + float scale_b_1 = __shfl_sync(0xffffffff, scale_b[i/8].y, src_lane_id); + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); 
+ cute::tma_store_wait<0>(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +class GemmBW { +private: + using Barrier = cuda::barrier; + +public: + GemmBW() = default; + + static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, + uint32_t shape_m, + const CUtensorMap& tma_a_desc, + const CUtensorMap& tma_b_desc, + const CUtensorMap& tma_scales_a_desc, + const CUtensorMap& tma_scales_b_desc, + const CUtensorMap& tma_d_desc, + cudaStream_t stream, + int num_sms, uint32_t smem_size) { + // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + auto kernel = fp8_gemm_bw_kernel; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm(BLOCK_M); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + // Launch + auto status = cudaLaunchKernelEx(&config, kernel, + gmem_d, scales_b, grouped_layout, + shape_m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); + } + + template + static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); + } + + template + static CUtensorMap make_2d_tma_b_desc(T* global_address) { + return make_2d_tma_desc(global_address, Layout::ColMajor, + SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); + } + + template + static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, + min(BLOCK_M, shape_m), BLOCK_N, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_m = ceil_div(shape_m, kAlignment) * kAlignment; + + return make_2d_tma_desc(global_address, Layout::ColMajor, + shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_scales_b_desc(T* global_address, uint32_t shape_m) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + constexpr uint32_t shape_n = ceil_div(SHAPE_N, kAlignment) * kAlignment; + + return make_2d_tma_desc(global_address, Layout::ColMajor, + shape_n, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? 
kNumGroups : 1), BLOCK_N, 1, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_desc( + T* global_address, Layout layout, + uint32_t gmem_rows, uint32_t gmem_cols, + uint32_t smem_rows, uint32_t smem_cols, + CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {smem_cols, smem_rows}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {smem_rows, smem_cols}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); + } + } +}; + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 329fbb0..b8bd021 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -1,3 +1,4 @@ +#pragma once #include "utils.cuh" namespace deep_gemm { diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py index 2a8624b..e37f3a4 100644 --- a/deep_gemm/jit_kernels/__init__.py +++ b/deep_gemm/jit_kernels/__init__.py @@ -1,4 +1,5 @@ from .gemm import gemm_fp8_fp8_bf16_nt +from .gemm_bw import gemm_fp8_fp8_bf16_bw_nt from .m_grouped_gemm import ( m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked diff --git a/deep_gemm/jit_kernels/gemm_bw.py b/deep_gemm/jit_kernels/gemm_bw.py new file mode 100644 index 0000000..ee65f59 --- /dev/null +++ b/deep_gemm/jit_kernels/gemm_bw.py @@ -0,0 +1,186 @@ +from typing import Tuple + +import torch + +from .tuner import jit_tuner +from .utils import ( + ceil_div, + get_col_major_tma_aligned_tensor, + get_m_alignment_for_contiguous_layout, + get_num_sms, +) + +# C++ code templates +includes = ('"deep_gemm/fp8_gemm_backward_w.cuh"',) +template = """ +using namespace deep_gemm; + +// Templated args from Python JIT call +constexpr auto N = {N}, K = {K}; +constexpr auto BLOCK_M = {BLOCK_M}; +constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto kNumStages = {NUM_STAGES}; +constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; + +// Make a templated GEMM +using GemmType = GemmBW; + +// Launch kernel +auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); +auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); +auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_scales_b_desc = GemmType::make_2d_tma_scales_b_desc(rhs_scales, m); +auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); +GemmType::run(out, rhs_scales, nullptr, + m, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc, + stream, num_sms, smem_size); +""" + + +def is_tma_multicast_legal( + n: int, block_n: int, num_tma_multicast: int, num_sms: int +) -> bool: + if num_tma_multicast == 1: + return True + return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + + +def get_smem_size( + num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128 +) -> int: + smem_d = block_m * block_n * 2 + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = block_m * 4 + smem_scales_b_per_stage = block_n * 4 + smem_b_per_stage = block_n * block_k + # smem_scales_b = ceil_div(k, block_k) * 4 + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + 
smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_a_per_stage + smem_size += num_stages * smem_scales_b_per_stage + smem_size += num_stages * smem_b_per_stage + # smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += smem_barrier + return smem_size + + +def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, + is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]: + if not is_grouped_contiguous: + # TODO: for some cases, smaller M block is better, add them into tuning space + block_ms = (64 if m <= 64 else 128, ) + else: + block_ms = (get_m_alignment_for_contiguous_layout(), ) + block_ns = tuple(range(32, 129, 32)) + + fix_wave_saturate = lambda x: num_sms if x == 0 else x + get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) + get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) + + # Decide block sizes by waves + best_block_m, best_block_n = None, None + for block_m in block_ms: + for block_n in block_ns: + success = False + num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) + if best_block_m is None or best_block_n is None: + success = True + elif num_waves < best_num_waves: + success = True + elif num_waves == best_num_waves: + # Check last wave utilization + util = get_last_wave_util(block_m, block_n) + best_util = get_last_wave_util(best_block_m, best_block_n) + success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) + best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) + assert best_block_m is not None and best_block_n is not None + + # Always pick the longest one + # NOTES: for double B scales, the best number of stages may be reduced + best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 + for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) + if best_smem_size <= sm90_capacity: + best_num_stages = num_stages + break + assert best_num_stages is not None + + # Decide the number of TMA multicast + best_num_tma_multicast = 1 + if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + + return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size + + +def gemm_fp8_fp8_bf16_bw_nt( + lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, +) -> None: + """ + Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 1x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, + this function will do a transposing with a set of slow PyTorch operations. + + Arguments: + lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. 
+             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
+        out: the BF16 output tensor of shape `[m, n]`, representing the result.
+    """
+    lhs, lhs_scales = lhs
+    rhs, rhs_scales = rhs
+    m, k = lhs.shape
+    n, k_ = rhs.shape
+    m_, n_ = out.shape
+
+    assert n % 64 == 0 and k % 128 == 0
+
+    # Type and shape checks
+    assert m == m_ and n == n_ and k == k_
+    assert n > 0 and k > 0
+    assert lhs_scales.shape == (m, (k + 127) // 128)
+    assert rhs_scales.shape == (n, (k + 127) // 128)
+    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+    assert out.dtype == torch.bfloat16
+    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+
+    # Both LHS and RHS scales must be in TMA-aligned (transposed) format
+    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
+    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+    rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
+
+    # Do nothing if `m` is zero
+    if m == 0:
+        return
+
+    # Auto-tuning with compilation
+    global includes, template
+    num_sms = get_num_sms()
+    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    runtime = jit_tuner.compile_and_tune(
+        name='gemm_fp8_fp8_bf16_bw_nt',
+        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
+        space=(),
+        includes=includes,
+        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+                  ('out', torch.bfloat16), ('m', int),
+                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+        template=template,
+        args=args
+    )
+
+    # Run the kernel
+    runtime(*args)
diff --git a/tests/test_core.py b/tests/test_core.py
index 68d9b79..52df4a7 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,9 +1,15 @@
 import random
-import torch
 from typing import Tuple
 
+import torch
+
 import deep_gemm
-from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
+from deep_gemm import (
+    bench_kineto,
+    calc_diff,
+    ceil_div,
+    get_col_major_tma_aligned_tensor,
+)
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -33,8 +39,18 @@ def construct(m: int, k: int, n: int) -> \
     ref_out = x @ y.t()
     x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
 
-    # Transpose earlier so that the testing will not trigger transposing kernels
-    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
+    return x_fp8, y_fp8, out, ref_out
+
+
+def construct_backward_w(m: int, k: int, n: int) -> \
+        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    ref_out = x @ y.t()
+    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)
+    #x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
+    #y_fp8 = (y_fp8[0], get_col_major_tma_aligned_tensor(y_fp8[1]))
     return x_fp8, y_fp8, out, ref_out
 
 
@@ -84,6 +100,30 @@ def test_gemm() -> None:
     print()
 
 
+def test_gemm_backward_w() -> None:
+    print('Testing GEMM Backward W:')
+    for m in (64, 128, 4096):
+        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
+            x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
+            deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
+            diff = calc_diff(out, ref_out)
+            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+            torch.cuda.synchronize()
+            print(diff)
+
+            # noinspection PyShadowingNames
+            def test_func():
+                # Construct new tensors every time to avoid L2 cache acceleration
+                x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
+                deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
+
+            t = bench_kineto(test_func, 'fp8_gemm_bw', suppress_kineto_output=True)
+            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
 def test_m_grouped_gemm_contiguous() -> None:
     print('Testing grouped contiguous GEMM:')
 
@@ -153,6 +193,7 @@ if __name__ == '__main__':
     print('Library path:')
     print(f'  > {deep_gemm.__path__}\n')
 
+    test_gemm_backward_w()
     test_gemm()
-    test_m_grouped_gemm_contiguous()
-    test_m_grouped_gemm_masked()
+    # test_m_grouped_gemm_contiguous()
+    # test_m_grouped_gemm_masked()
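
Usage note (illustrative, not part of the patch): the sketch below shows how the new deep_gemm.gemm_fp8_fp8_bf16_bw_nt entry point is expected to be called, mirroring construct_backward_w / test_gemm_backward_w from tests/test_core.py in this diff. The inline 1x128 cast is modeled on the per_token_cast_to_fp8 helper in that test file and is only a minimal sketch under that assumption, not the canonical helper; the shape triple is one of those exercised by the test, and an SM90 GPU with DeepGEMM installed is assumed.

import torch
import deep_gemm


def per_token_cast_to_fp8(x: torch.Tensor):
    # 1x128 scaling: one FP32 scale per 128-wide slice along K, matching the
    # shapes asserted by gemm_fp8_fp8_bf16_bw_nt (scales of shape [rows, k / 128]).
    m, k = x.shape
    assert k % 128 == 0
    x_view = x.view(m, -1, 128)
    amax = x_view.abs().float().amax(dim=2).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, k), amax / 448.0


m, k, n = 4096, 7168, 2112  # one of the shapes exercised in test_gemm_backward_w
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

# Both operands carry 1x128 scales; the Python wrapper re-aligns them for TMA if needed.
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(per_token_cast_to_fp8(x), per_token_cast_to_fp8(y), out)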