Init weight gradient kernels.

This commit is contained in:
Zhean Xu 2025-05-06 17:16:27 +08:00
parent d374456787
commit d5470d3b4e
9 changed files with 841 additions and 35 deletions

View File

@ -5,6 +5,8 @@ from .jit_kernels import (
gemm_fp8_fp8_bf16_nt,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
ceil_div,
set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,

View File

@ -16,21 +16,6 @@
namespace deep_gemm {
enum class Layout {
RowMajor,
ColMajor
};
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
return block_m == 64 ? 1 : 2;
}
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
template <int kNumFormerIters, int kGap, int kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
if (num_former_iters == kNumFormerIters) {

View File

@ -0,0 +1,468 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_wgrad_gemm_kernel(uint32_t shape_k,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_scales_b,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K);
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
cute::prefetch_tma_descriptor(&tensor_map_scales_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b[kNumStages];
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages + 1];
Barrier* empty_barriers[kNumStages + 1];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
+ i * SMEM_SCALES_A_SIZE_PER_STAGE);
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
+ i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
}
// Fill barriers
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
* (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
#pragma unroll
for (int i = 0; i < kNumStages + 1; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
full_barriers[kNumStages]->init(1);
empty_barriers[kNumStages]->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [&](const auto& func) {
if constexpr (kLastStages == 0) {
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(num_iterations - 1, NotDivisibleK{});
}
};
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);
cudaGridDependencySynchronize();
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
k_idx / BLOCK_K, num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
// Issue TMA D
empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
auto& full_barrier = *full_barriers[kNumStages];
tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
// Read B scales at the first warpgroup wave
if (local_idx == 0) {
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
__syncwarp();
}
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Promote with scales
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b[i].x;
const float &scale_b_1 = scales_b[i].y;
shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
}
// Wait last TMA store to be finished
if (k_iter == 0 and scheduler.current_iter > 0) {
if (threadIdx.x == 0) {
cute::tma_store_wait<0>();
empty_barriers[kNumStages]->arrive();
}
__syncwarp();
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Wait TMA D arrivals
full_barriers[kNumStages]->wait(scheduler.current_iter & 1);
// Accumulate to D shared memory
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
float2 d_0 = ld_shared(smem_d_0 + i * 4);
st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
float2 d_1 = ld_shared(smem_d_1 + i * 4);
st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
cute::tma_store_arrive();
}
__syncwarp();
}
cudaTriggerProgrammaticLaunchCompletion();
}
#else
if (blockIdx.x == 0 && threadIdx.x == 0)
DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
#endif
}
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kLastStages,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
class WgradGemm {
public:
WgradGemm() = default;
static void run(uint32_t shape_k,
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_scales_b_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_wgrad_gemm_kernel<SHAPE_M, SHAPE_N, BLOCK_M, BLOCK_N, BLOCK_K,
kNumStages, kLastStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kIsTMAMulticastOnA>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr[2];
attr[0].id = cudaLaunchAttributeClusterDimension;
attr[0].val.clusterDim = {kNumTMAMulticast, 1, 1};
attr[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attr[1].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attr;
config.numAttrs = 2;
// Launch
auto status = cudaLaunchKernelEx(&config, kernel,
shape_k,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, uint32_t a_stride) {
return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_k, BLOCK_M, BLOCK_K, a_stride);
}
template <typename T>
static CUtensorMap make_2d_tma_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k, uint32_t b_stride) {
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, shape_n, BLOCK_K, BLOCK_N, b_stride);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m, uint32_t shape_n, uint32_t d_stride) {
return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_n, BLOCK_M, BLOCK_N, d_stride,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_m, ceil_div(shape_k, BLOCK_K), BLOCK_M, 1, shape_m,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_n = ceil_div(shape_n, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_n, ceil_div(shape_k, BLOCK_K), BLOCK_N, 1, shape_n,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,
uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols,
uint32_t gmem_stride,
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type);
}
}
};
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@ -66,6 +66,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
@ -74,6 +80,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}
template <int N>
__device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
@ -181,4 +191,19 @@ struct FP8MMASelector {
using type = decltype(select_type());
};
enum class Layout {
RowMajor,
ColMajor
};
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
return block_m == 64 ? 1 : 2;
}
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
} // namespace deep_gemm

View File

@ -3,6 +3,10 @@ from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
)
from .utils import (
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,

View File

@ -61,16 +61,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
smem_d = block_m * (block_n + block_n_padding) * 2
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = ceil_div(k, block_k) * 4
if is_wgrad:
smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
else:
smem_scales_b = ceil_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
@ -78,7 +82,10 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
if is_wgrad:
smem_size += num_stages * smem_scales_b_per_stage
else:
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@ -89,13 +96,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
@lru_cache(maxsize=None)
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
if not is_grouped_contiguous:
block_ms = (64, 128, 256)
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
# Avoid bank conflicts for fp32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@ -135,7 +147,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, 3)
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break

View File

@ -0,0 +1,171 @@
import math
import torch
from typing import List, Tuple
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
# C++ code templates
includes = ('"deep_gemm/fp8_wgrad_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kLastStages = {LAST_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated GEMM
using gemm_t = WgradGemm<M, N, BLOCK_M, BLOCK_N, BLOCK_K, kNumStages, kLastStages, kNumTMAMulticast, kIsTMAMulticastOnA>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m, k, a_stride);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs, n, k, b_stride);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m, k);
auto tma_scales_b_desc = gemm_t::make_2d_tma_scales_b_desc(rhs_scales, n, k);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m, n, d_stride);
gemm_t::run(k,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: Tuple[torch.Tensor, torch.Tensor]):
"""
Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, k / 128]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, k / 128]`.
out: the FP32 output tensor of shape `[m, n]`, representing the result.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
n, k_ = rhs.shape
m_, n_ = out.shape
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and m > 0
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m):
lhs_scales = lhs_scales.permute(1, 0)
assert get_tma_aligned_size(m, 4) == m
else:
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert lhs_scales.stride(0) == 1
if rhs_scales.shape == ((k + 127) // 128, n):
rhs_scales = rhs_scales.permute(1, 0)
assert get_tma_aligned_size(n, 4) == n
else:
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
assert rhs_scales.stride(0) == 1
# Do nothing if `k` is zero
if k == 0:
return
aligned_n = (n + 63) // 64 * 64
aligned_k = (k + 127) // 128 * 128
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
last_stages = (k + 127) // 128 % num_stages
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, n, k,
lhs_stride, rhs_stride, out_stride,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_fp32_nt_dptp128c_dyn',
keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages,
'LAST_STAGES': last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.float), ('m', int), ('n', int), ('k', int),
('a_stride', int), ('b_stride', int), ('d_stride', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
# Run the kernel
runtime(*args)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
batch_sizes: List[int]):
"""
Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
Each batch's LHS, RHS, and output tensors must be contiguous.
The RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
Arguments:
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
the second element is an FP32 scaling tensor for LHS with shape `[k / 128 for k in batch_sizes), m]`,
representing the per-128-channel scaling factors.
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
the second element is an FP32 scaling tensor for RHS with shape `[k / 128 for k in batch_sizes), n]`,
representing the per-128-channel scaling factors.
out: The FP32 output tensor of shape [num_batches, m, n], representing the result.
batch_sizes: A list of integers specifying the k-dimension for each batch.
"""
lhs, lhs_scales = lhs[0].view(-1), lhs[1]
rhs, rhs_scales = rhs[0].view(-1), rhs[1]
num_batches, m, n = out.shape
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
for idx in range(num_batches):
k = batch_sizes[idx]
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
D = out[idx]
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
lhs_offset += m * k
rhs_offset += n * k
scales_offset += (k + 127) // 128

View File

@ -78,7 +78,7 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False):
# Conflict with Nsight Systems
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
@ -136,8 +136,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
if not is_multiple:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
@ -147,14 +148,29 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
if not is_multiple:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
else:
total_time = 0
total_num = 0
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
num_str = line.split()[-1]
for unit, scale in units.items():
if unit in time_str:
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
total_num += int(num_str)
break
kernel_times.append(total_time / total_num)
return tuple(kernel_times) if is_tupled else kernel_times[0]

View File

@ -1,6 +1,6 @@
import random
import torch
from typing import Tuple
from typing import List, Tuple
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
@ -89,6 +89,70 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
return x_fp8, y_fp8, out, ref_out
def construct_wgrad(m: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
out = residual.clone()
ref_out = residual + (x.float() @ y.float().t())
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = per_token_cast_to_fp8(y)
# NOTES: please do inplace add on the `out` later
return x_fp8, y_fp8, residual, out, ref_out
def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
num_groups, total_k = len(k_sizes), sum(k_sizes)
x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
# Fill tensors with data and compute reference output
x_offset, y_offset = 0, 0
for idx, k in enumerate(k_sizes):
x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten())
y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten())
ref_out[idx] = x_chunk.float() @ y_chunk.float().t()
x_offset += m * k
y_offset += n * k
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
# Cast to FP8 and prepare scale factors
x_offset, y_offset, scale_offset = 0, 0, 0
for k in k_sizes:
x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k))
y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k))
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
num_scales = (k + 127) // 128
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
x_offset += m * k
y_offset += n * k
scale_offset += num_scales
return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
def test_gemm() -> None:
print('Testing GEMM:')
for m in (64, 128, 4096):
@ -170,6 +234,62 @@ def test_m_grouped_gemm_masked() -> None:
print()
def test_wgrad_gemm():
print('Testing weight gradient GEMM:')
for k in (4096, 8192):
for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
# Test correctness
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
# noinspection PyShadowingNames
def test_func():
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_k_grouped_wgrad_gemm():
print('Testing grouped weight gradient GEMM:')
for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
for m, n in ((7168, 4096), (2048, 7168)):
# Vary k sizes around base_k
k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
k_sizes.append(base_k * num_groups - sum(k_sizes))
# Test correctness
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
for idx in range(num_groups):
diff = calc_diff(out[idx], ref_out[idx])
assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
# Construct new tensors to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
total_k = sum(k_sizes)
def test_func():
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, is_multiple=True) * num_groups
print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@ -182,3 +302,6 @@ if __name__ == '__main__':
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_wgrad_gemm()
test_k_grouped_wgrad_gemm()