diff --git a/README.md b/README.md
index f14601c..2aa53ce 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 ## News
+- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
 - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may cause performance loss in some cases).
 - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
@@ -22,9 +23,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] NVRTC as a faster compiler
 - [ ] Stolen JIT cache
 - [ ] Sanitizer for testing
-- [ ] Weight gradient kernels for dense models
-- [ ] Weight gradient kernels for MoE models
-- [ ] Utility kernels for MoE models (as a pre-built CUDA library)
+- [x] Weight gradient kernels for dense models
+- [x] Weight gradient kernels for MoE models
+- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang))
 - [ ] CUDA PDL support
 - [ ] More scaling granularity support via templates
 - [ ] Larger TMA multicast size for some shapes
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index 15b22ca..8e6b299 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -5,6 +5,8 @@ from .jit_kernels import (
     gemm_fp8_fp8_bf16_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    wgrad_gemm_fp8_fp8_fp32_nt,
+    k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
     ceil_div, set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index c57691b..b9dfe9f 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -16,21 +16,6 @@ namespace deep_gemm {
-enum class Layout {
-    RowMajor,
-    ColMajor
-};
-
-__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
-    return block_m == 64 ? 1 : 2;
-}
-
-template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
-__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
-    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
-    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
-}
-
 template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
 __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
     if (num_former_iters == kNumFormerIters) {
diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
new file mode 100644
index 0000000..4bf179e
--- /dev/null
+++ b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
@@ -0,0 +1,362 @@
+#pragma once
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunknown-attributes"
+
+#include <cutlass/arch/barrier.h>
+#include <cutlass/arch/reg_reconfig.h>
+
+#include <cute/arch/cluster_sm90.hpp>
+#include <cute/arch/copy_sm90_desc.hpp>
+#include <cute/arch/copy_sm90_tma.hpp>
+
+#include "mma_utils.cuh"
+#include "scheduler.cuh"
+#include "tma_utils.cuh"
+#include "utils.cuh"
+
+namespace deep_gemm {
+
+template <uint32_t SHAPE_M, uint32_t SHAPE_N,
+          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
+          uint32_t kNumStages, uint32_t kLastStages,
+          uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
+__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
+fp8_wgrad_gemm_kernel(uint32_t shape_k,
+                      const __grid_constant__ CUtensorMap tensor_map_a,
+                      const __grid_constant__ CUtensorMap tensor_map_b,
+                      const __grid_constant__ CUtensorMap tensor_map_scales_a,
+                      const __grid_constant__ CUtensorMap tensor_map_scales_b,
+                      const __grid_constant__ CUtensorMap tensor_map_d) {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
+    // Scaling checks
+    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
+
+    // Types
+    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
+    using Barrier = cutlass::arch::ClusterTransactionBarrier;
+    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
+
+    // Shared memory
+    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
+    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
+    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
+    static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
+    static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
+    static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;
+
+    // Configs
+    constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
+    constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
+    constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
+
+    const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K);
+    const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
+    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
+    const uint32_t lane_idx = get_lane_id();
+
+    // Prefetch TMA descriptors at the very beginning
+    if (threadIdx.x == kNumMathThreads) {
+        // NOTES: `reinterpret_cast` must be here, or NVRTC will fail
+        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
+        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
+        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
+        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_b));
+        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
+    }
+    __syncwarp();
+
+    // Align to 1024 bytes for swizzle-128B
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of D must be aligned to 1024 bytes");
+
+    // Data on shared memory
+    auto smem_d = reinterpret_cast<float*>(smem_buffer);
+    __nv_fp8_e4m3* smem_a[kNumStages];
+    __nv_fp8_e4m3* smem_b[kNumStages];
+    float* smem_scales_a[kNumStages];
+    float* smem_scales_b[kNumStages];
+
+    // TMA Barrier for both divisible and non-divisible cases
+    Barrier* full_barriers[kNumStages + 1];
+    Barrier* empty_barriers[kNumStages + 1];
+
+    // Fill shared memory pointers
+    #pragma unroll
+    for (int i = 0; i < kNumStages; ++ i) {
+        smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
+        smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
+        smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
+                                                    + i * SMEM_SCALES_A_SIZE_PER_STAGE);
+        smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
+                                                    + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
+    }
+
+    // Fill barriers
+    DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
+    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
+        * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
+    #pragma unroll
+    for (int i = 0; i < kNumStages + 1; ++ i) {
+        full_barriers[i] = barrier_start_ptr + i;
+        empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
+    }
+
+    // Initialize barriers
+    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
+    if (threadIdx.x == kNumMathThreads) {
+        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
+        // even with TMA multicast disabled, we want to make the behavior aligned
+        #pragma unroll
+        for (int i = 0; i < kNumStages; ++ i) {
+            full_barriers[i]->init(1);
+            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
+        }
+        full_barriers[kNumStages]->init(1);
+        empty_barriers[kNumStages]->init(1);
+
+        // Make initialized barrier visible in async proxy
+        cutlass::arch::fence_view_async_shared();
+        (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
+    }
+
+    // Synchronize all threads to make barrier visible in normal memory model
+    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
+
+    // For pipeline unrolling
+    struct DivisibleK {};
+    struct NotDivisibleK {};
+    auto launch_k_iterations = [&](const auto& func) {
+        if constexpr (kLastStages == 0) {
+            for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
+                func(k_iter, DivisibleK{});
+        } else {
+            for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
+                func(k_iter, DivisibleK{});
+            func(num_iterations - 1, NotDivisibleK{});
+        }
+    };
+
+    // Register reconfigurations
+    constexpr int kNumTMARegisters = 40;
+    constexpr int kNumMathRegisters = 232;
+
+    // Block scheduler
+    uint32_t m_block_idx, n_block_idx;
+    auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);
+
+    if (threadIdx.x >= kNumMathThreads) {
+        // TMA warp-group for loading data
+        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
+
+        // NOTES: only one thread (or warp) will be used
+        if (threadIdx.x == kNumMathThreads) {
+            // Persistently schedule over blocks
+            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
+                launch_k_iterations([&](int k_iter, auto type) {
+                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
+                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
+                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+
+                    // Assign TMA multicast number into A and B
+                    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
+                    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
+                    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+                    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+                    DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
+
+                    #pragma unroll
+                    for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
+                        // Wait consumer release
+                        empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
+
+                        // Issue TMA A
+                        auto& full_barrier = *full_barriers[s];
+                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
+                        tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
+                        tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_scales_a[s], m_block_idx * BLOCK_M,
+                                 k_idx / BLOCK_K, num_tma_multicast_a);
+
+                        // Issue TMA B
+                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
+                        tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
+
+                        full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
+                    }
+
+                    // Wait unaligned cases
+                    #pragma unroll
+                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
+                        empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
+                        full_barriers[s]->arrive();
+                    }
+                });
+
+                // Issue TMA D
+                empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
+                auto& full_barrier = *full_barriers[kNumStages];
+                tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
+                full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
+            }
+
+            // To safely deconstruct distributed shared barriers, we need another round of empty waits
+            if constexpr (kNumTMAMulticast > 1) {
+                #pragma unroll
+                for (uint32_t s = 0; s < kNumStages; ++ s)
+                    empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
+            }
+        }
+    } else {
+        // Math warp-groups for WGMMA
+        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
+
+        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
+        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
+        const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
+        const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
+
+        // Empty barrier arrival
+        auto empty_barrier_arrive = [&](int s) {
+            if constexpr (kNumTMAMulticast == 1) {
+                lane_idx == 0 ? empty_barriers[s]->arrive() : void();
+            } else {
+                auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
+                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
+            }
+        };
+
+        // Persistently schedule over blocks
+        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
+            // Decide the number of scales B to load
+            DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
+            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+            // Accumulation for WGMMA or CUDA promotion
+            constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
+            float2 scales_b[WGMMA::kNumAccum / 4];
+
+            // Launch MMAs
+            launch_k_iterations([&](int k_iter, auto type) {
+                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
+                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
+                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

+                #pragma unroll
+                for (int s = 0; s < kNumInnerStages; ++ s) {
+                    // Wait TMA arrivals
+                    full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
+
+                    #pragma unroll
+                    for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
+                        auto m_offset = local_idx * WAVE_BLOCK_M;
+
+                        // Read A scales
+                        auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
+                        auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
+
+                        // Commit WGMMA instructions
+                        #pragma unroll
+                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                            warpgroup_fence_operand(accum[i]);
+                        warpgroup_arrive();
+                        #pragma unroll
+                        for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                            auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
+                            auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
+                            WGMMA::wgmma(desc_a, desc_b, accum, k);
+                        }
+                        warpgroup_commit_batch();
+
+                        // Read B scales at the first warpgroup wave
+                        if (local_idx == 0) {
+                            #pragma unroll
+                            for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
+                                scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
+                            __syncwarp();
+                        }
+
+                        #pragma unroll
+                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                            warpgroup_fence_operand(accum[i]);
+                        warpgroup_wait<0>();
+
+                        // Notify barrier arrival at the last warpgroup wave
+                        if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
+                            empty_barrier_arrive(s);
+
+                        // Promote with scales
+                        auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
+                        #pragma unroll
+                        for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                            const float &scale_b_0 = scales_b[i].x;
+                            const float &scale_b_1 = scales_b[i].y;
+                            shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
+                            shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
+                            shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
+                            shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
+                        }
+                    }
+                }
+
+                // Wait last TMA store to be finished
+                if (k_iter == 0 and scheduler.current_iter > 0) {
+                    if (threadIdx.x == 0) {
+                        cute::tma_store_wait<0>();
+                        empty_barriers[kNumStages]->arrive();
+                    }
+                    __syncwarp();
+                }
+
+                // Wait unaligned cases
+                #pragma unroll
+                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
+                    full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
+                    empty_barrier_arrive(s);
+                }
+            });
+
+            // Wait TMA D arrivals
+            full_barriers[kNumStages]->wait(scheduler.current_iter & 1);
+
+            // Accumulate to D shared memory
+            #pragma unroll
+            for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
+                auto m_offset = local_idx * WAVE_BLOCK_M;
+                auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
+                auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
+                auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);
+                #pragma unroll
+                for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                    float2 d_0 = ld_shared(smem_d_0 + i * 4);
+                    st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
+                    float2 d_1 = ld_shared(smem_d_1 + i * 4);
+                    st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
+                }
+            }
+
+            cute::tma_store_fence();
+            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+            // Use TMA store to write back to global memory
+            if (threadIdx.x == 0) {
+                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
+                cute::tma_store_arrive();
+            }
+            __syncwarp();
+        }
+    }
+#else
+    if (blockIdx.x == 0 and threadIdx.x == 0)
+        DG_DEVICE_ASSERT(false && "This kernel only supports sm_90a");
+#endif
+}
+
+} // namespace deep_gemm
+
+#pragma clang diagnostic pop
diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index c6c7e28..85b2ccc 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -68,6 +68,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
     return ret;
 }
 
+__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
+    float2 ret;
+    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
+    return ret;
+}
+
 __device__ __forceinline__ void st_shared(const float* ptr, float val) {
     asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
 }
@@ -76,6 +82,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
     asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
 }
 
+__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
+    asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
+}
+
 template <int N>
 __device__ void warpgroup_wait() {
     DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
@@ -170,6 +180,7 @@ struct FP8MMASelector {
         if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
+        if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
@@ -183,4 +194,19 @@ struct FP8MMASelector {
     using type = decltype(select_type());
 };
 
+enum class Layout {
+    RowMajor,
+    ColMajor
+};
+
+__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
+    return block_m == 64 ? 1 : 2;
+}
+
+template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
+__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
+    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
+    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
+}
+
 } // namespace deep_gemm
diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py
index 2a8624b..f1fa7bb 100644
--- a/deep_gemm/jit_kernels/__init__.py
+++ b/deep_gemm/jit_kernels/__init__.py
@@ -3,6 +3,10 @@ from .m_grouped_gemm import (
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked
 )
+from .wgrad_gemm import (
+    wgrad_gemm_fp8_fp8_fp32_nt,
+    k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
+)
 from .utils import (
     ceil_div, set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 2122683..c782f28 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -34,17 +34,22 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
     return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
 
 
-def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
+def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
+                    is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
+    assert block_k == 128
+
     # Try swizzle first, as it does not waste shared memory
     swizzle_mode = get_swizzle_mode(block_n)
     block_n_padding = get_block_n_padding_for_smem_d(
         block_n) if swizzle_mode == 0 else 0
 
-    smem_d = block_m * (block_n + block_n_padding) * 2
+    # NOTES: `smem_scales_b` is allocated as a whole for normal GEMMs, or per stage for wgrad GEMMs
+    smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
-    smem_scales_b = ceil_div(k, block_k) * 4
+    smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
+    smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
     smem_barrier = num_stages * 8 * 2
 
     smem_size = 0
@@ -52,8 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
     smem_size += num_stages * smem_a_per_stage
     smem_size += num_stages * smem_scales_a_per_stage
     smem_size += num_stages * smem_b_per_stage
-    smem_size += ceil_div(smem_scales_b * (1 if block_k %
-                          block_n == 0 else 2), 8) * 8
+    smem_size += num_stages * smem_scales_b_per_stage
+    smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
     smem_size += smem_barrier
 
     # Swizzle and padding are not compatible
@@ -64,13 +69,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
 
 @lru_cache(maxsize=None)
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
-                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
+                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
+                     is_fp32_out: bool = False, is_wgrad: bool = False) -> \
         Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
     if not is_grouped_contiguous:
-        block_ms = (64, 128, 256)
+        block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
    else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(16, 129, 8)) + (144, 160, )
+    block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
+
+    # Avoid bank conflicts for FP32 output
+    if is_fp32_out:
+        block_ns = [x for x in block_ns if x % 16 == 8]
 
     fix_wave_saturate = lambda x: num_sms if x == 0 else x
     get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@@ -110,7 +120,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Unrolling both stages and `num_former_iters` will cause large code size
     stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
     for num_stages in stage_candidates:
-        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
+        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
         if best_smem_config[0] <= sm90_capacity:
             best_num_stages = num_stages
             break
@@ -145,11 +155,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                          rhs: Tuple[torch.Tensor, torch.Tensor],
                          out: torch.Tensor) -> None:
     """
-    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
+    Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
+        this function will transpose it with a set of slow PyTorch operations.
Arguments: lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, @@ -164,8 +177,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], n, k_ = rhs.shape m_, n_ = out.shape - assert n % 64 == 0 and k % 128 == 0 - # Type and shape checks assert m == m_ and n == n_ and k == k_ assert n > 0 and k > 0 @@ -174,7 +185,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert out.dtype == torch.bfloat16 - assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() + assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 + + lhs_stride = lhs.stride(0) + rhs_stride = rhs.stride(0) + out_stride = out.stride(0) + + # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes + assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0 # LHS scales must be transposed for TMA loads, but not for RHS scales # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels @@ -185,6 +203,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], if m == 0: return + # K must be aligned to 128 + aligned_k = (k + 127) // 128 * 128 + # Auto-tuning with compilation num_sms = get_num_sms() num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( @@ -194,11 +215,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], num_math_threads_per_group = 128 tensor_map_a = make_2d_tma_a_desc( - GemmType.Normal, lhs, m, k, block_m, block_k, 1) + GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride) tensor_map_b = make_2d_tma_b_desc( - GemmType.Normal, rhs, k, n, block_k, block_n, 1) + GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride) tensor_map_d = make_2d_tma_d_desc( - GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1]) + GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride) tensor_map_scales_a = make_2d_tma_scales_a_desc( GemmType.Normal, lhs_scales, m, k, block_m, block_k) @@ -223,7 +244,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], runtime, best_keys = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + keys={'N': n, 'K': aligned_k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'SWIZZLE_D_MODE': smem_config[1], 'BLOCK_N_PADDING': smem_config[2], 'NUM_STAGES': num_stages, diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 24a2183..e8c1922 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, m_indices: torch.Tensor) -> None: """ - Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. 
-    On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
-        `get_m_alignment_for_contiguous_layout()` (128).
+    Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
+        this function will transpose it with a set of slow PyTorch operations.
+        On the M axis, inputs are grouped into several batches, whose batch sizes are aligned to
+        `get_m_alignment_for_contiguous_layout()` (128).
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
                                           rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor,
                                           masked_m: torch.Tensor, expected_m: int) -> None:
     """
-    Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
-    Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
-    should be separately transposed.
+    Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
+        this function will transpose it with a set of slow PyTorch operations.
+        Moreover, this alignment requirement is different from the contiguous-format kernel, as we require that each batch
+        should be separately transposed.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py
index fa0a61d..8fb1a28 100644
--- a/deep_gemm/jit_kernels/runtime.py
+++ b/deep_gemm/jit_kernels/runtime.py
@@ -87,45 +87,48 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
 
 def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
-                     gmem_rows: int, gmem_cols: int,
+                     gmem_rows: int, gmem_cols: int, gmem_stride: int,
                      smem_rows: int, smem_cols: int,
                      swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
     if layout == Layout.RowMajor:
         gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
         smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
-        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
+        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
     else:
         gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
         smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
-        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
+        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
 
 
 def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int,
-                       num_groups: int) -> cbd.CUtensorMap:
+                       num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
+    a_stride = shape_k if a_stride == 0 else a_stride
     return make_2d_tma_desc(global_address, Layout.RowMajor,
-                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k,
+                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
                             block_m, block_k)
 
 
 def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int,
-                       num_groups: int) -> cbd.CUtensorMap:
+                       num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
+    b_stride = shape_k if b_stride == 0 else b_stride
     return make_2d_tma_desc(global_address, Layout.ColMajor,
-                            shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1),
+                            shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
                             block_k, block_n)
 
 
 def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int,
-                       num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap:
+                       num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
     # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
     # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
+    d_stride = shape_n if d_stride == 0 else d_stride
     return make_2d_tma_desc(global_address, Layout.RowMajor,
-                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
+                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
                            block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
                            swizzle_type_map[swizzle_mode])
@@ -136,10 +139,20 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
     shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
     return make_2d_tma_desc(global_address, Layout.ColMajor,
-                            shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
+                            shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
                             block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
 
 
+def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
+    # Make TMA aligned to 16 bytes
+    tma_alignment = 16 / global_address.element_size()
+    shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
+
+    return make_2d_tma_desc(global_address, Layout.ColMajor,
+                            shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
+                            block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
+
+
 class FP8GemmRuntime(Runtime):
     def __init__(self, path: str) -> None:
         super().__init__(path, [
@@ -254,3 +267,111 @@ static void __instantiate_kernel() {{
         None,
     )
     return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
+
+
+class FP8WGradGemmRuntime(Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, [
+            'NUM_TMA_MULTICAST',
+            'K',
+            'BLOCK_M',
+            'GMEM_D',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_SCALES_B',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+    @staticmethod
+    def generate(**kwargs) -> str:
+        code = f'''
+#ifdef __CUDACC_RTC__
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#include <string>
+#endif
+
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+
+#include <deep_gemm/fp8_wgrad_gemm.cuh>
+
+using namespace deep_gemm;
+
+static void __instantiate_kernel() {{
+    auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
+        {kwargs['M']},
+        {kwargs['N']},
+        {kwargs['BLOCK_M']},
+        {kwargs['BLOCK_N']},
+        {kwargs['BLOCK_K']},
+        {kwargs['NUM_STAGES']},
+        {kwargs['LAST_STAGES']},
+        {kwargs['NUM_TMA_THREADS']},
+        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
+        {kwargs['NUM_TMA_MULTICAST']},
+        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
+    >);
+}};
+'''
+        if int(os.getenv('DG_JIT_DEBUG', 0)):
+            print(f'Generated FP8 WGrad GEMM code:\n{code}')
+        return code
+
+    # noinspection PyMethodOverriding
+    @staticmethod
+    def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
+               block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
+               tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
+               tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
+               tensor_map_d: cbd.CUtensorMap,
+               stream: cbd.CUstream) -> cbd.CUresult:
+        num_tma_threads = 128
+        num_math_threads_per_group = 128
+
+        res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
+        if res != cbd.CUresult.CUDA_SUCCESS:
+            raise Exception(f'Failed to set max dynamic shared memory size: {res}')
+
+        attr_val = cbd.CUlaunchAttributeValue()
+        attr_val.clusterDim.x = num_tma_multicast
+        attr_val.clusterDim.y = 1
+        attr_val.clusterDim.z = 1
+        attr = cbd.CUlaunchAttribute()
+        attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
+        attr.value = attr_val
+
+        config = cbd.CUlaunchConfig()
+        config.numAttrs = 1
+        config.attrs = [attr]
+        config.gridDimX = num_sms
+        config.gridDimY = 1
+        config.gridDimZ = 1
+        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
+        config.blockDimY = 1
+        config.blockDimZ = 1
+        config.sharedMemBytes = smem_size
+        config.hStream = stream
+
+        arg_values = (
+            shape_k,
+            tensor_map_a,
+            tensor_map_b,
+            tensor_map_scales_a,
+            tensor_map_scales_b,
+            tensor_map_d,
+        )
+        arg_types = (
+            ctypes.c_uint32,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py
new file mode 100644
index 0000000..7dd5fc5
--- /dev/null
+++ b/deep_gemm/jit_kernels/wgrad_gemm.py
@@ -0,0 +1,179 @@
+import torch
+from typing import List, Tuple
+
+from .runtime import (
+    FP8WGradGemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
+from .gemm import get_best_configs
+from .tuner import jit_tuner
+from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
+
+
+def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+                               rhs: Tuple[torch.Tensor, torch.Tensor],
+                               out: torch.Tensor):
+    """
+    Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format; if your input does not match the requirement,
+        this function will transpose them with a set of slow PyTorch operations.
+
+    Arguments:
+        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
+        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
+             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
+        out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
+    """
+    lhs, lhs_scales = lhs
+    rhs, rhs_scales = rhs
+    m, k = lhs.shape
+    n, k_ = rhs.shape
+    m_, n_ = out.shape
+
+    # Type and shape checks
+    assert m == m_ and n == n_ and k == k_
+    assert n > 0 and m > 0
+    assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
+    assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
+    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+    assert out.dtype == torch.float
+    assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
+
+    lhs_stride = lhs.stride(0)
+    rhs_stride = rhs.stride(0)
+    out_stride = out.stride(0)
+
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
+
+    # LHS and RHS scales must be transposed for TMA load
+    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
+    if lhs_scales.shape == ((k + 127) // 128, m):
+        lhs_scales = lhs_scales.permute(1, 0)
+        assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
+    else:
+        lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+        assert lhs_scales.stride(0) == 1
+
+    if rhs_scales.shape == ((k + 127) // 128, n):
+        rhs_scales = rhs_scales.permute(1, 0)
+        assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
+    else:
+        rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
+        assert rhs_scales.stride(0) == 1
+
+    # Do nothing if `k` is zero
+    if k == 0:
+        return
+
+    # K must be aligned to 128
+    aligned_k = (k + 127) // 128 * 128
+
+    # Auto-tuning with compilation
+    num_sms = get_num_sms()
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
+        m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
+    last_stages = (k + 127) // 128 % num_stages
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+
+    tensor_map_a = make_2d_tma_a_desc(
+        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
+    tensor_map_b = make_2d_tma_b_desc(
+        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
+    tensor_map_d = make_2d_tma_d_desc(
+        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
+    tensor_map_scales_a = make_2d_tma_scales_a_desc(
+        GemmType.Normal, lhs_scales, m, k, block_m, block_k)
+    tensor_map_scales_b = make_2d_tma_scales_b_desc(
+        GemmType.Normal, rhs_scales, n, k, block_n, block_k)
+
+    kwargs = {
+        'GEMM_TYPE': GemmType.Normal,
+        'NUM_TMA_THREADS': num_tma_threads,
+        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
+        'K': aligned_k,
+        'NUM_GROUPS': 1,
+        'BLOCK_K': block_k,
+        'GMEM_D': out,
+        'NUM_SMS': num_sms,
+        'SMEM_SIZE': smem_config[0],
+        'TENSOR_MAP_A': tensor_map_a,
+        'TENSOR_MAP_B': tensor_map_b,
+        'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
+        'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
+        'TENSOR_MAP_D': tensor_map_d,
+        'STREAM': torch.cuda.current_stream().cuda_stream,
+    }
+
+    runtime, best_keys = jit_tuner.compile_and_tune(
+        name='wgrad_gemm_fp8_fp8_fp32_nt',
+        keys={'M': m, 'N': n,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'NUM_STAGES': num_stages,
+              'LAST_STAGES': last_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
+        space=(),
+        kwargs=kwargs,
+        runtime_cls=FP8WGradGemmRuntime,
+    )
+
+    # Run the kernel
+    runtime(**best_keys, **kwargs)
+
+
+def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+                                         rhs: Tuple[torch.Tensor, torch.Tensor],
+                                         out: torch.Tensor,
+                                         batch_sizes: List[int]):
+    """
+    Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
+        This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
+        Each batch's LHS, RHS, and output tensors must be contiguous.
+        The RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
+
+    Arguments:
+        lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
+             and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
+             the second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`,
+             representing the per-128-channel scaling factors.
+        rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
+             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
+             the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
+             representing the per-128-channel scaling factors.
+        out: the FP32 output tensor of shape `[num_batches, m, n]`, which will be accumulated.
+        batch_sizes: a list of integers specifying the k-dimension for each batch.
+    """
+    lhs, lhs_scales = lhs[0].view(-1), lhs[1]
+    rhs, rhs_scales = rhs[0].view(-1), rhs[1]
+    num_batches, m, n = out.shape
+
+    lhs_offset, rhs_offset, scales_offset = 0, 0, 0
+
+    for idx in range(num_batches):
+        k = batch_sizes[idx]
+        A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
+        B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
+        A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
+        B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
+        D = out[idx]
+
+        wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
+
+        lhs_offset += m * k
+        rhs_offset += n * k
+        scales_offset += (k + 127) // 128
\ No newline at end of file
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index f99ecd4..55a9aff 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
 
 def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
-                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
+                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
+                 with_multiple_kernels: bool = False):
     # Conflict with Nsight Systems
     using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
 
@@ -119,8 +120,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
     kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
     assert all([isinstance(name, str) for name in kernel_names])
-    for name in kernel_names:
-        assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
+    if not with_multiple_kernels:
+        for name in kernel_names:
+            assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
 
     # Save chrome traces
     if trace_path is not None:
@@ -130,14 +132,19 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     units = {'ms': 1e3, 'us': 1e6}
     kernel_times = []
     for name in kernel_names:
+        total_time = 0
+        total_num = 0
         for line in prof_lines:
             if name in line:
                 time_str = line.split()[-2]
+                num_str = line.split()[-1]
                 for unit, scale in units.items():
                     if unit in time_str:
-                        kernel_times.append(float(time_str.replace(unit, '')) / scale)
+                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
+                        total_num += int(num_str)
                         break
-                break
+        kernel_times.append(total_time / total_num)
+
     return tuple(kernel_times) if is_tupled else kernel_times[0]
diff --git a/tests/test_core.py b/tests/test_core.py
index de544c4..36c1c34 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -5,7 +5,7 @@ print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
 
 import random
 import torch
-from typing import Tuple
+from typing import List, Tuple
 
 import deep_gemm
 from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
@@ -13,11 +13,14 @@ from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2 and x.size(1) % 128 == 0
+    assert x.dim() == 2
     m, n = x.shape
+    pad_size = (128 - (n % 128)) % 128
+    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
     x_view = x.view(m, -1, 128)
     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
+    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -94,10 +97,74 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
     return x_fp8, y_fp8, out, ref_out
 
 
+def construct_wgrad(m: int, k: int, n: int) -> \
+        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
+    out = residual.clone()
+    ref_out = residual + (x.float() @ y.float().t())
+
+    x_fp8 = per_token_cast_to_fp8(x)
+    y_fp8 = per_token_cast_to_fp8(y)
+
+    # NOTES: please do the in-place add on `out` later
+    return x_fp8, y_fp8, residual, out, ref_out
+
+
+def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
+        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
+    num_groups, total_k = len(k_sizes), sum(k_sizes)
+
+    x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
+    y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
+    out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
+    ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
+
+    # Fill tensors with data and compute reference output
+    x_offset, y_offset = 0, 0
+    for idx, k in enumerate(k_sizes):
+        x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+        y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+
+        x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten())
+        y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten())
+        ref_out[idx] = x_chunk.float() @ y_chunk.float().t()
+
+        x_offset += m * k
+        y_offset += n * k
+
+    x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
+    y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
+
+    total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
+    x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
+    y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
+
+    # Cast to FP8 and prepare scale factors
+    x_offset, y_offset, scale_offset = 0, 0, 0
+    for k in k_sizes:
+        x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k))
+        y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k))
+
+        x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
+        y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
+
+        num_scales = (k + 127) // 128
+        x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
+        y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
+
+        x_offset += m * k
+        y_offset += n * k
+        scale_offset += num_scales
+
+    return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
+
+
 def test_gemm() -> None:
     print('Testing GEMM:')
     for m in (64, 128, 4096):
-        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
+        for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
             x_fp8, y_fp8, out, ref_out = construct(m, k, n)
             deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
             diff = calc_diff(out, ref_out)
@@ -175,6 +242,62 @@ def test_m_grouped_gemm_masked() -> None:
     print()
 
 
+def test_wgrad_gemm():
+    print('Testing weight gradient GEMM:')
+
+    for k in (4096, 8192):
+        for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
+            # Test correctness
+            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
+            diff = calc_diff(out, ref_out)
+            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+
+            # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
+            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
+
+            # noinspection PyShadowingNames
+            def test_func():
+                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
+
+            t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
+            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
+def test_k_grouped_wgrad_gemm():
+    print('Testing grouped weight gradient GEMM:')
+
+    for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
+        for m, n in ((7168, 4096), (2048, 7168)):
+            # Vary k sizes around base_k
+            k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
+            k_sizes.append(base_k * num_groups - sum(k_sizes))
+
+            # Test correctness
+            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
+            deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
+
+            for idx in range(num_groups):
+                diff = calc_diff(out[idx], ref_out[idx])
+                assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
+
+            # Construct new tensors to avoid L2 cache acceleration
+            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
+            total_k = sum(k_sizes)
+
+            def test_func():
+                deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
+
+            t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
+            print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
+                  f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
 if __name__ == '__main__':
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
@@ -187,3 +310,6 @@ if __name__ == '__main__':
     test_gemm()
     test_m_grouped_gemm_contiguous()
     test_m_grouped_gemm_masked()
+
+    test_wgrad_gemm()
+    test_k_grouped_wgrad_gemm()
\ No newline at end of file
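
A usage sketch for the new weight-gradient API (not part of the diff above): it mirrors `test_wgrad_gemm` in `tests/test_core.py`. The shapes are illustrative, and `per_token_cast_to_fp8` below is a minimal restatement of the test helper that assumes the inner dimension is already a multiple of 128.

```python
import torch
import deep_gemm


def per_token_cast_to_fp8(t: torch.Tensor):
    # Minimal 1x128 per-token FP8 quantization, mirroring tests/test_core.py
    # (assumes t.size(1) is a multiple of 128)
    m, n = t.shape
    t_view = t.view(m, -1, 128)
    amax = t_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    data = (t_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return data.view(m, n), (amax / 448.0)


m, n, k = 2048, 7168, 4096  # illustrative shapes
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)  # e.g. output gradients
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)  # e.g. input activations
out = torch.zeros((m, n), device='cuda', dtype=torch.float)   # FP32 accumulator, updated in place

# Each operand is an (FP8 data, FP32 1x128 scales) pair; both operands are K-major, hence "_nt"
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = per_token_cast_to_fp8(y)

# Accumulates x @ y.T into `out`; scale tensors are transposed/TMA-aligned internally if needed
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
```

For MoE backward with a varying `k` per expert, `k_grouped_wgrad_gemm_fp8_fp8_fp32_nt` takes flattened operands, a `[num_batches, m, n]` output, and a `batch_sizes` list, and dispatches the same kernel once per batch (see `test_k_grouped_wgrad_gemm` in `tests/test_core.py`).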