diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 15b22ca..8e6b299 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -5,6 +5,8 @@ from .jit_kernels import ( gemm_fp8_fp8_bf16_nt, m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked, + wgrad_gemm_fp8_fp8_fp32_nt, + k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, ceil_div, set_num_sms, get_num_sms, get_col_major_tma_aligned_tensor, diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index e8370af..a7d5480 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -16,21 +16,6 @@ namespace deep_gemm { -enum class Layout { - RowMajor, - ColMajor -}; - -__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { - return block_m == 64 ? 1 : 2; -} - -template -__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { - DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; -} - template __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { if (num_former_iters == kNumFormerIters) { diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh new file mode 100644 index 0000000..915552b --- /dev/null +++ b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh @@ -0,0 +1,468 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include "mma_utils.cuh" +#include "scheduler.cuh" +#include "tma_utils.cuh" +#include "utils.cuh" + +namespace deep_gemm { + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) +fp8_wgrad_gemm_kernel(uint32_t shape_k, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_scales_a, + const __grid_constant__ CUtensorMap tensor_map_scales_b, + const __grid_constant__ CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U; + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + + const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = 
__shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_scales_a); + cute::prefetch_tma_descriptor(&tensor_map_scales_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of D must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<float*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b[kNumStages]; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages + 1]; + Barrier* empty_barriers[kNumStages + 1]; + + // Fill shared memory pointers + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + i * SMEM_SCALES_A_SIZE_PER_STAGE); + smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) + + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE); + } + + // Fill barriers + DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); + auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages + * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE)); + #pragma unroll + for (int i = 0; i < kNumStages + 1; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicasts"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster; + // even with TMA multicast disabled, we want to keep the behavior aligned + #pragma unroll + for (int i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + full_barriers[kNumStages]->init(1); + empty_barriers[kNumStages]->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ?
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + auto launch_k_iterations = [&](const auto& func) { + if constexpr (kLastStages == 0) { + for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(num_iterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(SHAPE_M); + + cudaGridDependencySynchronize(); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Issue TMA A + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + k_idx / BLOCK_K, num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b); + tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), + smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + + // Issue TMA D + empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1); + auto& full_barrier = *full_barriers[kNumStages]; + tma_copy(&tensor_map_d, reinterpret_cast(&full_barrier), + smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1); + full_barrier.arrive_and_expect_tx(SMEM_D_SIZE); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for 
(uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + + // Read B scales at the first warpgroup wave + if (local_idx == 0) { + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + col_idx * 2)); + __syncwarp(); + } + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 
2]; + shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + } + + // Wait last TMA store to be finished + if (k_iter == 0 and scheduler.current_iter > 0) { + if (threadIdx.x == 0) { + cute::tma_store_wait<0>(); + empty_barriers[kNumStages]->arrive(); + } + __syncwarp(); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Wait TMA D arrivals + full_barriers[kNumStages]->wait(scheduler.current_iter & 1); + + // Accumulate to D shared memory + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2); + auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + float2 d_0 = ld_shared(smem_d_0 + i * 4); + st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]}); + float2 d_1 = ld_shared(smem_d_1 + i * 4); + st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]}); + } + } + + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + cudaTriggerProgrammaticLaunchCompletion(); + } +#else + if (blockIdx.x == 0 && threadIdx.x == 0) + DG_DEVICE_ASSERT(false && "This kernel only supports sm_90a"); +#endif +} + +template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumStages, uint32_t kLastStages, uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA> +class WgradGemm { +public: + WgradGemm() = default; + + static void run(uint32_t shape_k, + const CUtensorMap& tma_a_desc, + const CUtensorMap& tma_b_desc, + const CUtensorMap& tma_scales_a_desc, + const CUtensorMap& tma_scales_b_desc, + const CUtensorMap& tma_d_desc, + cudaStream_t stream, + int num_sms, uint32_t smem_size) { + // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + auto kernel = fp8_wgrad_gemm_kernel<SHAPE_M, SHAPE_N, BLOCK_M, BLOCK_N, BLOCK_K, kNumStages, kLastStages, kNumTMAThreads, kNumMathThreadsPerGroup, kNumTMAMulticast, kIsTMAMulticastOnA>; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr[2]; + attr[0].id = cudaLaunchAttributeClusterDimension; + attr[0].val.clusterDim = {kNumTMAMulticast, 1, 1}; + attr[1].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr[1].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attr; + config.numAttrs = 2; + + // Launch + auto status = cudaLaunchKernelEx(&config, kernel, + shape_k, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); + } + + template <typename T> + static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, uint32_t a_stride) { + return
make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_k, BLOCK_M, BLOCK_K, a_stride); + } + + template + static CUtensorMap make_2d_tma_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k, uint32_t b_stride) { + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, shape_n, BLOCK_K, BLOCK_N, b_stride); + } + + template + static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m, uint32_t shape_n, uint32_t d_stride) { + return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_n, BLOCK_M, BLOCK_N, d_stride, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_m = ceil_div(shape_m, kAlignment) * kAlignment; + + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_m, ceil_div(shape_k, BLOCK_K), BLOCK_M, 1, shape_m, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_scales_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_n = ceil_div(shape_n, kAlignment) * kAlignment; + + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_n, ceil_div(shape_k, BLOCK_K), BLOCK_N, 1, shape_n, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_desc( + T* global_address, Layout layout, + uint32_t gmem_rows, uint32_t gmem_cols, + uint32_t smem_rows, uint32_t smem_cols, + uint32_t gmem_stride, + CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {smem_cols, smem_rows}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {smem_rows, smem_cols}; + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type); + } + } +}; + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index a442af7..343ca38 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -66,6 +66,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { return ret; } +__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); + return ret; +} + __device__ __forceinline__ void st_shared(const float* ptr, float val) { asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); } @@ -74,6 +80,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); } +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); +} + template __device__ void warpgroup_wait() { DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); @@ -181,4 +191,19 @@ struct FP8MMASelector { using type = decltype(select_type()); }; +enum class Layout { + RowMajor, 
+ ColMajor +}; + +__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { + return block_m == 64 ? 1 : 2; +} + +template +__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { + DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); + return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; +} + } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py index 2a8624b..f1fa7bb 100644 --- a/deep_gemm/jit_kernels/__init__.py +++ b/deep_gemm/jit_kernels/__init__.py @@ -3,6 +3,10 @@ from .m_grouped_gemm import ( m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked ) +from .wgrad_gemm import ( + wgrad_gemm_fp8_fp8_fp32_nt, + k_grouped_wgrad_gemm_fp8_fp8_fp32_nt +) from .utils import ( ceil_div, set_num_sms, get_num_sms, get_col_major_tma_aligned_tensor, diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index c6fd29d..bd11155 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -61,16 +61,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size -def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: +def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, + is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: # Try swizzle first, as it does not waste shared memory swizzle_mode = get_swizzle_mode(block_n) block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 - smem_d = block_m * (block_n + block_n_padding) * 2 + smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) smem_a_per_stage = block_m * block_k smem_scales_a_per_stage = block_m * 4 smem_b_per_stage = block_n * block_k - smem_scales_b = ceil_div(k, block_k) * 4 + if is_wgrad: + smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128 + else: + smem_scales_b = ceil_div(k, block_k) * 4 smem_barrier = num_stages * 8 * 2 smem_size = 0 @@ -78,7 +82,10 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_size += num_stages * smem_a_per_stage smem_size += num_stages * smem_scales_a_per_stage smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + if is_wgrad: + smem_size += num_stages * smem_scales_b_per_stage + else: + smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 smem_size += smem_barrier # Swizzle and padding are not compatible @@ -89,13 +96,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k @lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ + is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, + is_fp32_out: bool = False, is_wgrad: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: if not is_grouped_contiguous: - block_ms = (64, 128, 256) + block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + (144, 160, ) + block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) + + 
# Avoid bank conflicts for fp32 output + if is_fp32_out: + block_ns = [x for x in block_ns if x % 16 == 8] fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -135,7 +147,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3) for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py new file mode 100644 index 0000000..dbefa57 --- /dev/null +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -0,0 +1,171 @@ +import math +import torch +from typing import List, Tuple + +from .gemm import get_best_configs +from .tuner import jit_tuner +from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size + +# C++ code templates +includes = ('"deep_gemm/fp8_wgrad_gemm.cuh"', ) +template = """ +using namespace deep_gemm; + +// Templated args from Python JIT call +constexpr auto M = {M}, N = {N}; +constexpr auto BLOCK_M = {BLOCK_M}; +constexpr auto BLOCK_N = {BLOCK_N}; +constexpr auto BLOCK_K = 128; +constexpr auto kNumStages = {NUM_STAGES}; +constexpr auto kLastStages = {LAST_STAGES}; +constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; + +// Make a templated GEMM +using gemm_t = WgradGemm<M, N, BLOCK_M, BLOCK_N, BLOCK_K, kNumStages, kLastStages, kNumTMAMulticast, kIsTMAMulticastOnA>; + +// Launch kernel +auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m, k, a_stride); +auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs, n, k, b_stride); +auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m, k); +auto tma_scales_b_desc = gemm_t::make_2d_tma_scales_b_desc(rhs_scales, n, k); +auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m, n, d_stride); +gemm_t::run(k, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc, + stream, num_sms, smem_size); +""" + + +def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor): + """ + Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. + LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. + The RHS and RHS scaling factors are required to be transposed. + The LHS and RHS scaling tensors require a TMA-aligned, transposed format; if your input does not match this requirement, + this function will transpose it with a set of slow PyTorch operations. + + Arguments: + lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. + rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, + the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`. + out: the FP32 output tensor of shape `[m, n]`, representing the result.
+ """ + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + m, k = lhs.shape + n, k_ = rhs.shape + m_, n_ = out.shape + + # Type and shape checks + assert m == m_ and n == n_ and k == k_ + assert n > 0 and m > 0 + assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m) + assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert out.dtype == torch.float + assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 + + lhs_stride = lhs.stride(0) + rhs_stride = rhs.stride(0) + out_stride = out.stride(0) + + # LHS and RHS scales must be transposed for TMA load + # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels + if lhs_scales.shape == ((k + 127) // 128, m): + lhs_scales = lhs_scales.permute(1, 0) + assert get_tma_aligned_size(m, 4) == m + else: + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert lhs_scales.stride(0) == 1 + + if rhs_scales.shape == ((k + 127) // 128, n): + rhs_scales = rhs_scales.permute(1, 0) + assert get_tma_aligned_size(n, 4) == n + else: + rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales) + assert rhs_scales.stride(0) == 1 + + # Do nothing if `k` is zero + if k == 0: + return + + aligned_n = (n + 63) // 64 * 64 + aligned_k = (k + 127) // 128 * 128 + + # Auto-tuning with compilation + global includes, template + num_sms = get_num_sms() + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) + last_stages = (k + 127) // 128 % num_stages + + args = (lhs, lhs_scales, rhs, rhs_scales, out, m, n, k, + lhs_stride, rhs_stride, out_stride, + torch.cuda.current_stream(), num_sms, smem_config[0]) + runtime = jit_tuner.compile_and_tune( + name='gemm_fp8_fp8_fp32_nt_dptp128c_dyn', + keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + 'NUM_STAGES': num_stages, + 'LAST_STAGES': last_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, + space=(), + includes=includes, + arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), + ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), + ('out', torch.float), ('m', int), ('n', int), ('k', int), + ('a_stride', int), ('b_stride', int), ('d_stride', int), + ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) + + # Run the kernel + runtime(*args) + + +def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + batch_sizes: List[int]): + """ + Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. + This function handles multiple batches with varying k-dimensions, processing each batch sequentially. + Each batch's LHS, RHS, and output tensors must be contiguous. + The RHS and RHS scaling factors are required to be transposed. + The LHS scaling and RHS scaling tensors require TMA-aligned transposed format. + + Arguments: + lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, + and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. 
+ the second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`, + representing the per-128-channel scaling factors. + rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, + and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. + the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`, + representing the per-128-channel scaling factors. + out: the FP32 output tensor of shape `[num_batches, m, n]`, representing the result. + batch_sizes: a list of integers specifying the k-dimension for each batch. + """ + lhs, lhs_scales = lhs[0].view(-1), lhs[1] + rhs, rhs_scales = rhs[0].view(-1), rhs[1] + num_batches, m, n = out.shape + + lhs_offset, rhs_offset, scales_offset = 0, 0, 0 + + for idx in range(num_batches): + k = batch_sizes[idx] + A = lhs[lhs_offset:lhs_offset + m * k].view(m, k) + B = rhs[rhs_offset:rhs_offset + n * k].view(n, k) + A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] + B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] + D = out[idx] + + wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D) + + lhs_offset += m * k + rhs_offset += n * k + scales_offset += (k + 127) // 128 \ No newline at end of file diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py index d5cdd01..6b5f898 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/utils.py @@ -78,7 +78,7 @@ class suppress_stdout_stderr: def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, - trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): + trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False): # Conflict with Nsight Systems using_nsys = os.environ.get('DG_NSYS_PROFILING', False) @@ -136,8 +136,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) - for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + if not is_multiple: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' # Save chrome traces if trace_path is not None: @@ -147,14 +148,29 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: units = {'ms': 1e3, 'us': 1e6} kernel_times = [] for name in kernel_names: - for line in prof_lines: - if name in line: - time_str = line.split()[-2] - for unit, scale in units.items(): - if unit in time_str: - kernel_times.append(float(time_str.replace(unit, '')) / scale) - break - break + if not is_multiple: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_times.append(float(time_str.replace(unit, '')) / scale) + break + break + else: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num +=
int(num_str) + break + kernel_times.append(total_time / total_num) + return tuple(kernel_times) if is_tupled else kernel_times[0] diff --git a/tests/test_core.py b/tests/test_core.py index bdc1841..effabd4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,6 @@ import random import torch -from typing import Tuple +from typing import List, Tuple import deep_gemm from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor @@ -89,6 +89,70 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ return x_fp8, y_fp8, out, ref_out +def construct_wgrad(m: int, k: int, n: int) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10 + out = residual.clone() + ref_out = residual + (x.float() @ y.float().t()) + + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = per_token_cast_to_fp8(y) + + # NOTES: please do inplace add on the `out` later + return x_fp8, y_fp8, residual, out, ref_out + + +def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]: + num_groups, total_k = len(k_sizes), sum(k_sizes) + + x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16) + y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16) + out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) + ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) + + # Fill tensors with data and compute reference output + x_offset, y_offset = 0, 0 + for idx, k in enumerate(k_sizes): + x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + + x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten()) + y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten()) + ref_out[idx] = x_chunk.float() @ y_chunk.float().t() + + x_offset += m * k + y_offset += n * k + + x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) + y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) + + total_scale_factors = sum((k + 127) // 128 for k in k_sizes) + x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) + y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) + + # Cast to FP8 and prepare scale factors + x_offset, y_offset, scale_offset = 0, 0, 0 + for k in k_sizes: + x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k)) + y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k)) + + x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) + y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) + + num_scales = (k + 127) // 128 + x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) + y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) + + x_offset += m * k + y_offset += n * k + scale_offset += num_scales + + return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes + + def test_gemm() -> None: print('Testing GEMM:') for m in (64, 128, 4096): @@ -170,6 +234,62 @@ def test_m_grouped_gemm_masked() -> None: print() +def 
test_wgrad_gemm(): + print('Testing weight gradient GEMM:') + + for k in (4096, 8192): + for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)): + # Test correctness + x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) + x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) + + t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True) + print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_wgrad_gemm(): + print('Testing grouped weight gradient GEMM:') + + for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)): + for m, n in ((7168, 4096), (2048, 7168)): + # Vary k sizes around base_k + k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)] + k_sizes.append(base_k * num_groups - sum(k_sizes)) + + # Test correctness + x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) + deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) + + for idx in range(num_groups): + diff = calc_diff(out[idx], ref_out[idx]) + assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}' + + # Construct new tensors to avoid L2 cache acceleration + x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) + total_k = sum(k_sizes) + + def test_func(): + deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) + + t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, is_multiple=True) * num_groups + print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, ' + f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -182,3 +302,6 @@ if __name__ == '__main__': test_gemm() test_m_grouped_gemm_contiguous() test_m_grouped_gemm_masked() + + test_wgrad_gemm() + test_k_grouped_wgrad_gemm() \ No newline at end of file
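
Usage sketch (not part of the diff): the snippet below shows how the two new entry points are meant to be called, distilled from the docstrings and the tests added in `tests/test_core.py`. The shapes and the `per_token_cast_to_fp8` helper are illustrative assumptions that mirror the test code; only `wgrad_gemm_fp8_fp8_fp32_nt` and `k_grouped_wgrad_gemm_fp8_fp8_fp32_nt` come from this PR, and an sm_90a GPU plus a PyTorch build with `torch.float8_e4m3fn` are assumed.

```python
import torch
import deep_gemm

# Assumed 1x128 FP8 quantization helper, mirroring the one in tests/test_core.py (not part of this diff)
def per_token_cast_to_fp8(x: torch.Tensor):
    m, k = x.shape
    assert k % 128 == 0, 'k must be a multiple of 128 for 1x128 scaling'
    x_view = x.view(m, -1, 128)
    amax = x_view.abs().float().amax(dim=2).clamp(min=1e-4)    # [m, k / 128] per-block absolute maxima
    x_fp8 = (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, k), amax / 448.0                      # FP8 data + FP32 scaling factors

m, n, k = 256, 512, 1024                                       # illustrative shapes
x = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
y = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)
out = torch.zeros(m, n, device='cuda', dtype=torch.float)      # FP32, accumulated in place

# Dense weight-gradient GEMM: out += x @ y.T with per-128-channel scaling.
# Row-major [m, k/128] scales are accepted; they are realigned internally via the slow PyTorch path.
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(per_token_cast_to_fp8(x), per_token_cast_to_fp8(y), out)
```

For the k-grouped variant, each batch's `[m, k_i]` / `[n, k_i]` chunk is flattened into a 1-D FP8 tensor, the scales are stacked transposed into `[sum(⌈k_i / 128⌉), m]` and `[sum(⌈k_i / 128⌉), n]` tensors, and the output is `[num_batches, m, n]` with `batch_sizes` carrying each `k_i`, exactly as `construct_k_grouped_wgrad` prepares them above.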