Weight gradient kernels for dense and MoE models (#95)

* Init weight gradient kernels.

* Support unaligned n,k and gmem stride

* Update docs

* Several cleanups

* Remove restrictions on N

* Add stride(0) assertions

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Zhean Xu
2025-05-14 14:47:58 +08:00
committed by GitHub
parent d75b218b7b
commit 04278f6dee
12 changed files with 911 additions and 72 deletions

View File

@@ -8,6 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
## News
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may cause performance loss in some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
@@ -22,9 +23,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] NVRTC as a faster compiler
- [ ] Stolen JIT cache
- [ ] Sanitizer for testing
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (as a pre-built CUDA library)
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang))
- [ ] CUDA PDL support
- [ ] More scaling granularity support via templates
- [ ] Larger TMA multicast size for some shapes

View File

@@ -5,6 +5,8 @@ from .jit_kernels import (
gemm_fp8_fp8_bf16_nt,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
ceil_div,
set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,

View File

@@ -16,21 +16,6 @@
namespace deep_gemm {
// Memory layout tag for a 2D matrix
enum class Layout {
    RowMajor = 0,
    ColMajor = 1
};
// Number of math warpgroups for a tile of `block_m` rows:
// one warpgroup when `block_m` is 64, two for every other value
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
    return block_m != 64 ? 2 : 1;
}
// Total threads per block: the TMA threads plus `kNumMathThreadsPerGroup` threads
// for each math warpgroup selected by `get_num_math_warpgroups(block_m)`
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return kNumTMAThreads + kNumMathThreadsPerGroup * get_num_math_warpgroups(block_m);
}
template <int kNumFormerIters, int kGap, int kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
if (num_former_iters == kNumFormerIters) {

View File

@@ -0,0 +1,362 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
namespace deep_gemm {

/**
 * FP8 weight-gradient GEMM kernel with FP32 accumulation into the output tile.
 *
 * For every scheduled `(m_block_idx, n_block_idx)` tile the kernel:
 *   1. streams A, B and their scaling factors over K through a `kNumStages`-deep
 *      shared-memory pipeline filled by TMA;
 *   2. runs WGMMA per stage and promotes the raw accumulator into `final_accum`
 *      with `scale_a * scale_b`;
 *   3. loads the current D tile from global memory into shared memory via TMA,
 *      adds `final_accum` onto it and writes the tile back with a TMA store.
 *
 * Thread roles:
 *   - threads `[0, kNumMathThreads)`: math warpgroups (WGMMA, promotion, epilogue);
 *   - threads `>= kNumMathThreads`: TMA warpgroup, of which only the first thread issues copies.
 *
 * Dynamic shared memory layout (in order):
 *   D tile | A stages | B stages | A-scale stages | B-scale stages (128-byte aligned each) |
 *   `kNumStages + 1` full barriers | `kNumStages + 1` empty barriers.
 *   Barrier index `kNumStages` guards the D tile buffer.
 *
 * Template parameters:
 *   SHAPE_M, SHAPE_N            problem sizes handed to the block scheduler
 *   BLOCK_M, BLOCK_N, BLOCK_K   tile sizes (`BLOCK_K` must be 128)
 *   kNumStages                  pipeline depth
 *   kLastStages                 number of valid stages in the last K iteration (0 if K divides evenly)
 *   kNumTMAThreads              threads in the TMA warpgroup
 *   kNumMathThreadsPerGroup     threads per math warpgroup
 *   kNumTMAMulticast            TMA multicast width (at most 2)
 *   kIsTMAMulticastOnA          whether multicast applies to A (otherwise to B)
 *
 * Arguments:
 *   shape_k                     runtime K extent
 *   tensor_map_*                TMA descriptors for A, B, both scale tensors and D
 *
 * Requires sm_90a; on other architectures the kernel only raises a device assertion.
 */
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumStages, uint32_t kLastStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_wgrad_gemm_kernel(uint32_t shape_k,
                      const __grid_constant__ CUtensorMap tensor_map_a,
                      const __grid_constant__ CUtensorMap tensor_map_b,
                      const __grid_constant__ CUtensorMap tensor_map_scales_a,
                      const __grid_constant__ CUtensorMap tensor_map_scales_b,
                      const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
    // Scaling checks
    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;
    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");

    // Shared memory
    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
    // Each B-scale stage is padded up to a multiple of 128 bytes
    static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;

    // Configs
    constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
    constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
    constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
    const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_id();

    // Prefetch TMA descriptors at the very beginning
    if (threadIdx.x == kNumMathThreads) {
        // NOTES: `reinterpret_cast` must be here, or NVRTC will fail
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    // The D tile comes first, so its size decides the alignment of the A/B stages behind it
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Data on shared memory
    auto smem_d = reinterpret_cast<float*>(smem_buffer);
    __nv_fp8_e4m3* smem_a[kNumStages];
    __nv_fp8_e4m3* smem_b[kNumStages];
    float* smem_scales_a[kNumStages];
    float* smem_scales_b[kNumStages];

    // TMA Barrier for both divisible and non-divisible cases
    // NOTES: the extra barrier at index `kNumStages` guards the D tile buffer
    Barrier* full_barriers[kNumStages + 1];
    Barrier* empty_barriers[kNumStages + 1];

    // Fill shared memory pointers
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
        smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
        smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
                                                    + i * SMEM_SCALES_A_SIZE_PER_STAGE);
        smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
                                                    + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
    }

    // Fill barriers: all full barriers first, then all empty barriers
    DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
                             * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
    #pragma unroll
    for (int i = 0; i < kNumStages + 1; ++ i) {
        full_barriers[i] = barrier_start_ptr + i;
        empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
    }

    // Initialize barriers
    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
    if (threadIdx.x == kNumMathThreads) {
        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
        // even with TMA multicast disabled, we want to make the behavior aligned
        #pragma unroll
        for (int i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }
        // The D tile barriers have a single arriving thread on each side
        full_barriers[kNumStages]->init(1);
        empty_barriers[kNumStages]->init(1);

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_view_async_shared();
        (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // For pipeline unrolling
    struct DivisibleK {};
    struct NotDivisibleK {};
    // Calls `func(k_iter, tag)` for every K iteration; the tag type tells `func` at compile time
    // whether all `kNumStages` stages are valid or only the first `kLastStages` (last iteration)
    auto launch_k_iterations = [&](const auto& func) {
        if constexpr (kLastStages == 0) {
            for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
                func(k_iter, DivisibleK{});
        } else {
            for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
                func(k_iter, DivisibleK{});
            func(num_iterations - 1, NotDivisibleK{});
        }
    };

    // Register reconfigurations
    constexpr int kNumTMARegisters = 40;
    constexpr int kNumMathRegisters = 232;

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);

    if (threadIdx.x >= kNumMathThreads) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (threadIdx.x == kNumMathThreads) {
            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                launch_k_iterations([&](int k_iter, auto type) {
                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                    // Assign TMA multicast number into A and B
                    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
                    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
                    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                    DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

                    #pragma unroll
                    for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                        // Wait consumer release
                        empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);

                        // Issue TMA A
                        auto& full_barrier = *full_barriers[s];
                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                        tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
                        tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_scales_a[s], m_block_idx * BLOCK_M,
                                 k_idx / BLOCK_K, num_tma_multicast_a);

                        // Issue TMA B
                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
                        tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
                        // NOTES: the expected bytes use the unpadded B-scale size, not the aligned stage size
                        full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
                    }

                    // Wait unaligned cases
                    // NOTES: unused stages are still cycled so every barrier advances once per K iteration
                    #pragma unroll
                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                        empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
                        full_barriers[s]->arrive();
                    }
                });

                // Issue TMA D: load the current output tile once the math threads released the D buffer
                empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
                auto& full_barrier = *full_barriers[kNumStages];
                tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
                         smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
                full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s)
                    empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
        // Rows (`r_0`, `r_1`) and column pair (`col_idx * 2`) this thread reads and writes
        const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
        const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;

        // Empty barrier arrival
        auto empty_barrier_arrive = [&](int s) {
            if constexpr (kNumTMAMulticast == 1) {
                lane_idx == 0 ? empty_barriers[s]->arrive() : void();
            } else {
                auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
            }
        };

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // B scales are read as `float2` pairs at offsets `i * 8 + col_idx * 2`
            DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Accumulation for WGMMA or CUDA promotion
            constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
            float2 scales_b[WGMMA::kNumAccum / 4];

            // Launch MMAs
            launch_k_iterations([&](int k_iter, auto type) {
                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                #pragma unroll
                for (int s = 0; s < kNumInnerStages; ++ s) {
                    // Wait TMA arrivals
                    full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);

                    #pragma unroll
                    for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                        auto m_offset = local_idx * WAVE_BLOCK_M;

                        // Read A scales
                        auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
                        auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);

                        // Commit WGMMA instructions
                        #pragma unroll
                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                            warpgroup_fence_operand(accum[i]);
                        warpgroup_arrive();
                        #pragma unroll
                        for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                            auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                            auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                            WGMMA::wgmma(desc_a, desc_b, accum, k);
                        }
                        warpgroup_commit_batch();

                        // Read B scales at the first warpgroup wave
                        if (local_idx == 0) {
                            #pragma unroll
                            for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
                                scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
                            __syncwarp();
                        }

                        #pragma unroll
                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                            warpgroup_fence_operand(accum[i]);
                        warpgroup_wait<0>();

                        // Notify barrier arrival at the last warpgroup wave
                        if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
                            empty_barrier_arrive(s);

                        // Promote with scales
                        auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                        #pragma unroll
                        for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                            const float &scale_b_0 = scales_b[i].x;
                            const float &scale_b_1 = scales_b[i].y;
                            shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
                            shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
                            shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
                            shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
                        }
                    }
                }

                // Wait last TMA store to be finished
                // NOTES: done after the first K iteration of every block except the first one; thread 0 then
                // releases the D buffer so the TMA thread may load the next D tile into it
                if (k_iter == 0 and scheduler.current_iter > 0) {
                    if (threadIdx.x == 0) {
                        cute::tma_store_wait<0>();
                        empty_barriers[kNumStages]->arrive();
                    }
                    __syncwarp();
                }

                // Wait unaligned cases
                #pragma unroll
                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                    full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
                    empty_barrier_arrive(s);
                }
            });

            // Wait TMA D arrivals
            full_barriers[kNumStages]->wait(scheduler.current_iter & 1);

            // Accumulate to D shared memory
            #pragma unroll
            for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                auto m_offset = local_idx * WAVE_BLOCK_M;
                auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
                auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);

                #pragma unroll
                for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                    // Read-modify-write: new D = loaded D + promoted accumulator
                    float2 d_0 = ld_shared(smem_d_0 + i * 4);
                    st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
                    float2 d_1 = ld_shared(smem_d_1 + i * 4);
                    st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
                }
            }
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Use TMA store to write back to global memory
            if (threadIdx.x == 0) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false && "This kernel only supports sm_90a");
#endif
}

}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -68,6 +68,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
return ret;
}
// Loads two consecutive floats from `ptr` with a single `ld.shared.v2.f32` instruction
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
    float lo, hi;
    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(lo), "=f"(hi) : "l"(ptr));
    return float2{lo, hi};
}
// Stores the float `val` at address `ptr` using the PTX `st.shared.f32` instruction.
// NOTE(review): `ptr` is const-qualified although the pointee is written through it;
// presumably this mirrors the other `st_shared` overloads - confirm before changing.
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
@@ -76,6 +82,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
// Stores both components of `val` (`x` first, then `y`) at address `ptr`
// with a single PTX `st.shared.v2.f32` instruction.
// NOTE(review): `ptr` is const-qualified although the pointee is written through it;
// presumably this mirrors the other `st_shared` overloads - confirm before changing.
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}
template <int N>
__device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
@@ -170,6 +180,7 @@ struct FP8MMASelector {
if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
@@ -183,4 +194,19 @@ struct FP8MMASelector {
using type = decltype(select_type());
};
// Memory layout tag for a 2D matrix
enum class Layout {
    RowMajor = 0,
    ColMajor = 1
};
// Number of math warpgroups for a tile of `block_m` rows:
// one warpgroup when `block_m` is 64, two for every other value
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
    return block_m != 64 ? 2 : 1;
}
// Total threads per block: the TMA threads plus `kNumMathThreadsPerGroup` threads
// for each math warpgroup selected by `get_num_math_warpgroups(block_m)`
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return kNumTMAThreads + kNumMathThreadsPerGroup * get_num_math_warpgroups(block_m);
}
} // namespace deep_gemm

View File

@@ -3,6 +3,10 @@ from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
)
from .utils import (
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,

View File

@@ -34,17 +34,22 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
assert block_k == 128
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
smem_d = block_m * (block_n + block_n_padding) * 2
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = ceil_div(k, block_k) * 4
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
smem_barrier = num_stages * 8 * 2
smem_size = 0
@@ -52,8 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k %
block_n == 0 else 2), 8) * 8
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@@ -64,13 +69,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
@lru_cache(maxsize=None)
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
if not is_grouped_contiguous:
block_ms = (64, 128, 256)
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@@ -110,7 +120,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break
@@ -145,11 +155,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
@@ -164,8 +177,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
@@ -174,7 +185,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
@@ -185,6 +203,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
if m == 0:
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
@@ -194,11 +215,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
@@ -223,7 +244,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
keys={'N': n, 'K': aligned_k,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,

View File

@@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, m_indices: torch.Tensor) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,

View File

@@ -87,45 +87,48 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
gmem_rows: int, gmem_cols: int,
gmem_rows: int, gmem_cols: int, gmem_stride: int,
smem_rows: int, smem_cols: int,
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m: int, shape_k: int,
block_m: int, block_k: int,
num_groups: int) -> cbd.CUtensorMap:
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
a_stride = shape_k if a_stride == 0 else a_stride
return make_2d_tma_desc(global_address, Layout.RowMajor,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_k: int, shape_n: int,
block_k: int, block_n: int,
num_groups: int) -> cbd.CUtensorMap:
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
b_stride = shape_k if b_stride == 0 else b_stride
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1),
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
                       shape_m: int, shape_n: int,
                       block_m: int, block_n: int,
                       num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
    """
    Build the 2D TMA descriptor for the row-major output matrix.

    Arguments:
        gemm_type: the GEMM variant; for `GemmType.GroupedMasked` the row count is `shape_m * num_groups`.
        global_address: the tensor the descriptor refers to.
        shape_m: number of rows (per group).
        shape_n: number of columns.
        block_m: number of rows of the shared memory box.
        block_n: number of columns of the shared memory box when no swizzling is used.
        num_groups: number of groups, only used for `GemmType.GroupedMasked`.
        swizzle_mode: swizzle width in bytes, `0` disables swizzling; also the key into `swizzle_type_map`.
        d_stride: row stride in elements; `0` (the default) means densely packed rows, i.e. `shape_n`.

    Returns:
        The tensor map produced by `make_2d_tma_desc`.
    """
    # A zero stride is the "not specified" sentinel: fall back to a packed layout
    d_stride = shape_n if d_stride == 0 else d_stride
    # Masked grouped GEMMs stack all groups along the row dimension
    num_rows = shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
    # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
    inner_box_cols = block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size()
    return make_2d_tma_desc(global_address, Layout.RowMajor,
                            num_rows, shape_n, d_stride,
                            block_m, inner_box_cols,
                            swizzle_type_map[swizzle_mode])
@@ -136,10 +139,20 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
                              shape_n: int, shape_k: int,
                              block_n: int, block_k: int,
                              num_groups: int = 1) -> cbd.CUtensorMap:
    """
    Build the 2D TMA descriptor for the column-major RHS scaling factors.

    Arguments:
        gemm_type: the GEMM variant; for `GemmType.GroupedMasked` the column count is multiplied by `num_groups`.
        global_address: the scaling factor tensor the descriptor refers to.
        shape_n: number of rows of the column-major view, rounded up here to a 16-byte multiple.
        shape_k: the K extent; the descriptor has `ceil(shape_k / block_k)` columns per group.
        block_n: number of rows of the shared memory box.
        block_k: the K extent covered by one scaling factor.
        num_groups: number of groups, only used for `GemmType.GroupedMasked`.

    Returns:
        The unswizzled tensor map produced by `make_2d_tma_desc`.
    """
    # Make TMA aligned to 16 bytes
    # NOTES: integer division keeps the aligned shape (and the stride derived from it) an `int`;
    # true division would silently turn every following value into a `float`
    tma_alignment = 16 // global_address.element_size()
    shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
    num_k_blocks = (shape_k + block_k - 1) // block_k
    num_cols = num_k_blocks * (num_groups if gemm_type == GemmType.GroupedMasked else 1)
    return make_2d_tma_desc(global_address, Layout.ColMajor,
                            shape_n, num_cols, shape_n,
                            block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
@@ -254,3 +267,111 @@ static void __instantiate_kernel() {{
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
class FP8WGradGemmRuntime(Runtime):
    """Runtime for the FP8 weight gradient GEMM kernel: generates the instantiation source and launches it."""

    def __init__(self, path: str) -> None:
        # Argument names passed to the base `Runtime`; they are listed in the same order
        # as the parameters of `launch` after `kernel`
        launch_arg_names = [
            'NUM_TMA_MULTICAST',
            'K',
            'BLOCK_M',
            'GMEM_D',
            'NUM_SMS',
            'SMEM_SIZE',
            'TENSOR_MAP_A',
            'TENSOR_MAP_B',
            'TENSOR_MAP_SCALES_A',
            'TENSOR_MAP_SCALES_B',
            'TENSOR_MAP_D',
            'STREAM',
        ]
        super().__init__(path, launch_arg_names)

    @staticmethod
    def generate(**kwargs) -> str:
        """
        Return the source that instantiates `fp8_wgrad_gemm_kernel` with the template
        arguments taken from `kwargs`. The source is printed when `DG_JIT_DEBUG` is set
        to a non-zero integer.
        """
        code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_wgrad_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
{kwargs['M']},
{kwargs['N']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['NUM_STAGES']},
{kwargs['LAST_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
>);
}};
'''
        debug_enabled = int(os.getenv('DG_JIT_DEBUG', 0))
        if debug_enabled:
            print(f'Generated FP8 WGrad GEMM code:\n{code}')
        return code

    # noinspection PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
               block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
               tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
               tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
               tensor_map_d: cbd.CUtensorMap,
               stream: cbd.CUstream) -> cbd.CUresult:
        """
        Launch the kernel on `stream` with `num_sms` blocks and a cluster of `num_tma_multicast` blocks.

        Raises:
            Exception: if the maximum dynamic shared memory size cannot be set on the kernel.

        Returns:
            The value returned by `cuLaunchKernelEx`.
        """
        num_tma_threads = 128
        num_math_threads_per_group = 128

        # Allow the kernel to use `smem_size` bytes of dynamic shared memory on the output tensor's device
        device = cbd.CUdevice(gmem_d.device.index)
        res = cbd.cuKernelSetAttribute(
            cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            smem_size, kernel, device)[0]
        if res != cbd.CUresult.CUDA_SUCCESS:
            raise Exception(f'Failed to set max dynamic shared memory size: {res}')

        # The cluster dimension is the only launch attribute
        cluster_value = cbd.CUlaunchAttributeValue()
        cluster_value.clusterDim.x = num_tma_multicast
        cluster_value.clusterDim.y = 1
        cluster_value.clusterDim.z = 1

        cluster_attr = cbd.CUlaunchAttribute()
        cluster_attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
        cluster_attr.value = cluster_value

        # 1D grid and 1D block
        config = cbd.CUlaunchConfig()
        config.numAttrs = 1
        config.attrs = [cluster_attr]
        config.gridDimX = num_sms
        config.gridDimY = 1
        config.gridDimZ = 1
        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
        config.blockDimY = 1
        config.blockDimZ = 1
        config.sharedMemBytes = smem_size
        config.hStream = stream

        # Kernel arguments paired with their ctypes type; the tensor maps are passed with `None` as type
        typed_args = (
            (shape_k, ctypes.c_uint32),
            (tensor_map_a, None),
            (tensor_map_b, None),
            (tensor_map_scales_a, None),
            (tensor_map_scales_b, None),
            (tensor_map_d, None),
        )
        arg_values = tuple(value for value, _ in typed_args)
        arg_types = tuple(ctype for _, ctype in typed_args)
        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -0,0 +1,179 @@
import torch
from typing import List, Tuple
from .runtime import (
FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def _as_tma_aligned_col_major_scales(scales: torch.Tensor, num_rows: int, num_k_blocks: int) -> torch.Tensor:
    """Return `scales` viewed as `[num_rows, num_k_blocks]` with `stride(0) == 1` (transposed layout)."""
    # NOTES: if `num_rows == num_k_blocks` both layouts have the same shape, and the input is
    # treated as already transposed
    if scales.shape == (num_k_blocks, num_rows):
        # Already stored transposed: only a view is needed, but the rows must be TMA-aligned and packed
        scales = scales.permute(1, 0)
        assert get_tma_aligned_size(num_rows, 4) == num_rows and scales.stride(1) == num_rows
    else:
        # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
        scales = get_col_major_tma_aligned_tensor(scales)
    assert scales.stride(0) == 1
    return scales


def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                               rhs: Tuple[torch.Tensor, torch.Tensor],
                               out: torch.Tensor):
    """
    Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
    Results will be accumulated into the output tensor.

    Requirements:
        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
        The RHS and the RHS scaling factors are required to be transposed.
        The LHS scaling and RHS scaling tensors require the TMA-aligned transposed format; if your input does not
            match the requirement, this function will do a transposing with a set of slow PyTorch operations.

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`
             (or already transposed, of shape `[⌈k / 128⌉, m]`).
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`
             (or already transposed, of shape `[⌈k / 128⌉, n]`).
        out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape
    num_k_blocks = (k + 127) // 128

    # Type and shape checks
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and m > 0
    assert lhs_scales.shape == (m, num_k_blocks) or lhs_scales.shape == (num_k_blocks, m)
    assert rhs_scales.shape == (n, num_k_blocks) or rhs_scales.shape == (num_k_blocks, n)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.float
    assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

    lhs_stride = lhs.stride(0)
    rhs_stride = rhs.stride(0)
    out_stride = out.stride(0)

    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
    # (16 one-byte FP8 elements, or 4 four-byte FP32 elements)
    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0

    # LHS and RHS scales must be transposed for TMA load
    lhs_scales = _as_tma_aligned_col_major_scales(lhs_scales, m, num_k_blocks)
    rhs_scales = _as_tma_aligned_col_major_scales(rhs_scales, n, num_k_blocks)

    # Do nothing if `k` is zero
    if k == 0:
        return

    # K must be aligned to 128
    aligned_k = num_k_blocks * 128

    # Auto-tuning with compilation
    num_sms = get_num_sms()
    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
        m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
    last_stages = num_k_blocks % num_stages
    block_k = 128
    num_tma_threads = 128
    num_math_threads_per_group = 128

    # The data descriptors use the real `k` and the user strides, the kernel gets the aligned `k`
    tensor_map_a = make_2d_tma_a_desc(
        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
    tensor_map_b = make_2d_tma_b_desc(
        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
    tensor_map_d = make_2d_tma_d_desc(
        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
    tensor_map_scales_a = make_2d_tma_scales_a_desc(
        GemmType.Normal, lhs_scales, m, k, block_m, block_k)
    tensor_map_scales_b = make_2d_tma_scales_b_desc(
        GemmType.Normal, rhs_scales, n, k, block_n, block_k)

    kwargs = {
        'GEMM_TYPE': GemmType.Normal,
        'NUM_TMA_THREADS': num_tma_threads,
        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
        'K': aligned_k,
        'NUM_GROUPS': 1,
        'BLOCK_K': block_k,
        'GMEM_D': out,
        'NUM_SMS': num_sms,
        'SMEM_SIZE': smem_config[0],
        'TENSOR_MAP_A': tensor_map_a,
        'TENSOR_MAP_B': tensor_map_b,
        'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
        'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
        'TENSOR_MAP_D': tensor_map_d,
        'STREAM': torch.cuda.current_stream().cuda_stream,
    }

    runtime, best_keys = jit_tuner.compile_and_tune(
        name='wgrad_gemm_fp8_fp8_fp32_nt',
        keys={'M': m, 'N': n,
              'BLOCK_M': block_m, 'BLOCK_N': block_n,
              'NUM_STAGES': num_stages,
              'LAST_STAGES': last_stages,
              'NUM_TMA_MULTICAST': tma_multicast_config[0],
              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
        space=(),
        kwargs=kwargs,
        runtime_cls=FP8WGradGemmRuntime,
    )

    # Run the kernel
    runtime(**best_keys, **kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                         rhs: Tuple[torch.Tensor, torch.Tensor],
                                         out: torch.Tensor,
                                         batch_sizes: List[int]):
    """
    Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
    Results will be accumulated into the output tensor.

    Requirements:
        This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
        Each batch's LHS, RHS, and output tensors must be contiguous.
        The RHS and RHS scaling factors are required to be transposed.
        The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.

    Arguments:
        lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
             and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
             the second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`,
             representing the per-128-channel scaling factors.
        rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
             the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
        out: The FP32 output tensor of shape `[num_batches, m, n]`, which will be accumulated.
        batch_sizes: A list of integers specifying the k-dimension for each batch, one entry per batch of `out`.
    """
    lhs, lhs_scales = lhs[0].view(-1), lhs[1]
    rhs, rhs_scales = rhs[0].view(-1), rhs[1]
    num_batches, m, n = out.shape

    # Every output batch needs exactly one k; a mismatch means the offsets below would be wrong
    assert len(batch_sizes) == num_batches, \
        f'Expected {num_batches} batch sizes (one per output batch), but got {len(batch_sizes)}'

    lhs_offset, rhs_offset, scales_offset = 0, 0, 0
    for idx, k in enumerate(batch_sizes):
        num_k_blocks = (k + 127) // 128

        # Slice this batch out of the flattened buffers; the scales are stacked along dimension 0
        A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
        B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
        A_scales = lhs_scales[scales_offset:scales_offset + num_k_blocks]
        B_scales = rhs_scales[scales_offset:scales_offset + num_k_blocks]
        D = out[idx]

        wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)

        lhs_offset += m * k
        rhs_offset += n * k
        scales_offset += num_k_blocks

View File

@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
with_multiple_kernels: bool = False):
# Conflict with Nsight Systems
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
@@ -119,8 +120,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
@@ -130,14 +132,19 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
total_time = 0
total_num = 0
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
num_str = line.split()[-1]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
total_num += int(num_str)
break
break
kernel_times.append(total_time / total_num)
return tuple(kernel_times) if is_tupled else kernel_times[0]

View File

@@ -5,7 +5,7 @@ print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
import random
import torch
from typing import Tuple
from typing import List, Tuple
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
@@ -13,11 +13,14 @@ from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -94,10 +97,74 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
return x_fp8, y_fp8, out, ref_out
def construct_wgrad(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create random inputs for the weight gradient GEMM test.

    Returns:
        `(x_fp8, y_fp8, residual, out, ref_out)`: the per-token FP8 casts of the `[m, k]` and `[n, k]`
        operands, the random FP32 residual of shape `[m, n]`, a clone of it to accumulate into,
        and the reference `residual + x @ y^T` computed in FP32.
    """
    lhs_bf16 = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    rhs_bf16 = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10

    # NOTES: please do inplace add on the `out` later
    out = residual.clone()
    product = lhs_bf16.float() @ rhs_bf16.float().t()
    ref_out = residual + product

    x_fp8 = per_token_cast_to_fp8(lhs_bf16)
    y_fp8 = per_token_cast_to_fp8(rhs_bf16)
    return x_fp8, y_fp8, residual, out, ref_out
def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
    """
    Create random inputs for the k-grouped weight gradient GEMM test.

    Returns:
        `((x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes)`: the flattened FP8
        operands with their transposed scaling factors stacked along dimension 0, a zeroed output of
        shape `[len(k_sizes), m, n]`, the FP32 reference of the same shape, and `k_sizes` unchanged.
    """
    num_groups = len(k_sizes)
    total_k = sum(k_sizes)

    x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
    y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
    out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
    ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)

    # Fill tensors with data and compute reference output
    x_begin, y_begin = 0, 0
    for group_idx, k in enumerate(k_sizes):
        x_end, y_end = x_begin + m * k, y_begin + n * k
        x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
        y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
        x_flat[x_begin:x_end].copy_(x_chunk.flatten())
        y_flat[y_begin:y_end].copy_(y_chunk.flatten())
        ref_out[group_idx] = x_chunk.float() @ y_chunk.float().t()
        x_begin, y_begin = x_end, y_end

    x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
    y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)

    # One row of scaling factors per 128 channels of every group
    total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
    x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
    y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)

    # Cast to FP8 and prepare scale factors
    x_begin, y_begin, scale_begin = 0, 0, 0
    for k in k_sizes:
        x_end, y_end = x_begin + m * k, y_begin + n * k
        scale_end = scale_begin + (k + 127) // 128

        x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_begin:x_end].view(m, k))
        y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_begin:y_end].view(n, k))

        x_fp8_flat[x_begin:x_end].copy_(x_fp8_chunk.flatten())
        y_fp8_flat[y_begin:y_end].copy_(y_fp8_chunk.flatten())

        # The scaling factors are stored transposed, i.e. `[num_scales, rows]`
        x_scales[scale_begin:scale_end].copy_(x_scale_chunk.T)
        y_scales[scale_begin:scale_end].copy_(y_scale_chunk.T)

        x_begin, y_begin, scale_begin = x_end, y_end, scale_end

    return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
def test_gemm() -> None:
print('Testing GEMM:')
for m in (64, 128, 4096):
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
@@ -175,6 +242,62 @@ def test_m_grouped_gemm_masked() -> None:
print()
def test_wgrad_gemm():
    """Check the weight gradient GEMM against an FP32 reference and print its performance."""
    print('Testing weight gradient GEMM:')

    k_values = (4096, 8192)
    mn_shapes = ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168))
    for k in k_values:
        for m, n in mn_shapes:
            # Test correctness
            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
            diff = calc_diff(out, ref_out)
            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

            # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)

            # noinspection PyShadowingNames
            def test_func():
                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)

            t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)

            # NOTE(review): the byte count uses 2 bytes per output element, while `out` is FP32 here;
            # the reported GB/s may be underestimated — confirm
            tflops = 2 * m * n * k / t / 1e12
            gbps = (m * k + k * n + m * n * 2) / 1e9 / t
            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
                  f'throughput: {tflops:4.0f} TFLOPS, '
                  f'{gbps:4.0f} GB/s')
    print()
def test_k_grouped_wgrad_gemm():
    """Check the k-grouped weight gradient GEMM against an FP32 reference and print its performance."""
    print('Testing grouped weight gradient GEMM:')

    group_configs = ((4, 4096), (4, 8192), (8, 4096))
    mn_shapes = ((7168, 4096), (2048, 7168))
    for num_groups, base_k in group_configs:
        for m, n in mn_shapes:
            # Vary k sizes around base_k; the last group takes the remainder so the total stays fixed
            k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
            k_sizes.append(base_k * num_groups - sum(k_sizes))

            # Test correctness
            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
            deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)

            for idx in range(num_groups):
                diff = calc_diff(out[idx], ref_out[idx])
                assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'

            # Construct new tensors to avoid L2 cache acceleration
            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
            total_k = sum(k_sizes)

            def test_func():
                deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)

            # One kernel is launched per group, so the averaged kernel time is scaled by the group count
            avg_kernel_time = bench_kineto(test_func, 'fp8_wgrad_gemm',
                                           suppress_kineto_output=True, with_multiple_kernels=True)
            t = avg_kernel_time * num_groups

            # NOTE(review): the byte count uses 2 bytes per output element, while `out` is FP32 here;
            # the reported GB/s may be underestimated — confirm
            print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
                  f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
    print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@@ -187,3 +310,6 @@ if __name__ == '__main__':
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_wgrad_gemm()
test_k_grouped_wgrad_gemm()