mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Init weight gradient kernels.
This commit is contained in:
parent
d374456787
commit
d5470d3b4e
@ -5,6 +5,8 @@ from .jit_kernels import (
|
||||
gemm_fp8_fp8_bf16_nt,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
ceil_div,
|
||||
set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
|
||||
@ -16,21 +16,6 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
|
||||
return block_m == 64 ? 1 : 2;
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <int kNumFormerIters, int kGap, int kEnd>
|
||||
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
|
||||
if (num_former_iters == kNumFormerIters) {
|
||||
|
||||
468
deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
Normal file
468
deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
Normal file
@ -0,0 +1,468 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumStages, uint32_t kLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
|
||||
const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K);
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_scales_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<float*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b[kNumStages];
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages + 1];
|
||||
Barrier* empty_barriers[kNumStages + 1];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
|
||||
+ i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
|
||||
+ i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
|
||||
* (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages + 1; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
|
||||
// even with TMA multicast disabled, we want to make the behavior aligned
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
full_barriers[kNumStages]->init(1);
|
||||
empty_barriers[kNumStages]->init(1);
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
if constexpr (kLastStages == 0) {
|
||||
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(num_iterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);
|
||||
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
k_idx / BLOCK_K, num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
|
||||
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
|
||||
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
|
||||
// Issue TMA D
|
||||
empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
|
||||
auto& full_barrier = *full_barriers[kNumStages];
|
||||
tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
|
||||
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
float2 scales_b[WGMMA::kNumAccum / 4];
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
|
||||
// Read A scales
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Read B scales at the first warpgroup wave
|
||||
if (local_idx == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
|
||||
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
const float &scale_b_0 = scales_b[i].x;
|
||||
const float &scale_b_1 = scales_b[i].y;
|
||||
shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait last TMA store to be finished
|
||||
if (k_iter == 0 and scheduler.current_iter > 0) {
|
||||
if (threadIdx.x == 0) {
|
||||
cute::tma_store_wait<0>();
|
||||
empty_barriers[kNumStages]->arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Wait TMA D arrivals
|
||||
full_barriers[kNumStages]->wait(scheduler.current_iter & 1);
|
||||
|
||||
// Accumulate to D shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
|
||||
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
float2 d_0 = ld_shared(smem_d_0 + i * 4);
|
||||
st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
|
||||
float2 d_1 = ld_shared(smem_d_1 + i * 4);
|
||||
st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
|
||||
}
|
||||
}
|
||||
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 && threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumStages, uint32_t kLastStages,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
||||
class WgradGemm {
|
||||
public:
|
||||
WgradGemm() = default;
|
||||
|
||||
static void run(uint32_t shape_k,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_scales_b_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_wgrad_gemm_kernel<SHAPE_M, SHAPE_N, BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
kNumStages, kLastStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kIsTMAMulticastOnA>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr[2];
|
||||
attr[0].id = cudaLaunchAttributeClusterDimension;
|
||||
attr[0].val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
attr[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attr[1].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.attrs = attr;
|
||||
config.numAttrs = 2;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
shape_k,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, uint32_t a_stride) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_k, BLOCK_M, BLOCK_K, a_stride);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k, uint32_t b_stride) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, shape_n, BLOCK_K, BLOCK_N, b_stride);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m, uint32_t shape_n, uint32_t d_stride) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor, shape_m, shape_n, BLOCK_M, BLOCK_N, d_stride,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_m, ceil_div(shape_k, BLOCK_K), BLOCK_M, 1, shape_m,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_n = ceil_div(shape_n, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_n, ceil_div(shape_k, BLOCK_K), BLOCK_N, 1, shape_n,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
uint32_t gmem_stride,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_stride * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@ -66,6 +66,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
|
||||
float2 ret;
|
||||
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
@ -74,6 +80,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
|
||||
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
@ -181,4 +191,19 @@ struct FP8MMASelector {
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
|
||||
return block_m == 64 ? 1 : 2;
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@ -3,6 +3,10 @@ from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .wgrad_gemm import (
|
||||
wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
|
||||
)
|
||||
from .utils import (
|
||||
ceil_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
|
||||
@ -61,16 +61,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
|
||||
|
||||
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
smem_d = block_m * (block_n + block_n_padding) * 2
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
if is_wgrad:
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
|
||||
else:
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
@ -78,7 +82,10 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
if is_wgrad:
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
else:
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@ -89,13 +96,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
|
||||
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
|
||||
if not is_grouped_contiguous:
|
||||
block_ms = (64, 128, 256)
|
||||
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
# Avoid bank conflicts for fp32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
@ -135,7 +147,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||
stage_candidates = (4, 3)
|
||||
for num_stages in stage_candidates:
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
|
||||
if best_smem_config[0] <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
|
||||
171
deep_gemm/jit_kernels/wgrad_gemm.py
Normal file
171
deep_gemm/jit_kernels/wgrad_gemm.py
Normal file
@ -0,0 +1,171 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_wgrad_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto M = {M}, N = {N};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto BLOCK_K = 128;
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kLastStages = {LAST_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
|
||||
|
||||
// Make a templated GEMM
|
||||
using gemm_t = WgradGemm<M, N, BLOCK_M, BLOCK_N, BLOCK_K, kNumStages, kLastStages, kNumTMAMulticast, kIsTMAMulticastOnA>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m, k, a_stride);
|
||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs, n, k, b_stride);
|
||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m, k);
|
||||
auto tma_scales_b_desc = gemm_t::make_2d_tma_scales_b_desc(rhs_scales, n, k);
|
||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m, n, d_stride);
|
||||
gemm_t::run(k,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: Tuple[torch.Tensor, torch.Tensor]):
|
||||
"""
|
||||
Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
|
||||
out: the FP32 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and m > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
|
||||
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.float
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
lhs_stride = lhs.stride(0)
|
||||
rhs_stride = rhs.stride(0)
|
||||
out_stride = out.stride(0)
|
||||
|
||||
# LHS and RHS scales must be transposed for TMA load
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
if lhs_scales.shape == ((k + 127) // 128, m):
|
||||
lhs_scales = lhs_scales.permute(1, 0)
|
||||
assert get_tma_aligned_size(m, 4) == m
|
||||
else:
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert lhs_scales.stride(0) == 1
|
||||
|
||||
if rhs_scales.shape == ((k + 127) // 128, n):
|
||||
rhs_scales = rhs_scales.permute(1, 0)
|
||||
assert get_tma_aligned_size(n, 4) == n
|
||||
else:
|
||||
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
|
||||
assert rhs_scales.stride(0) == 1
|
||||
|
||||
# Do nothing if `k` is zero
|
||||
if k == 0:
|
||||
return
|
||||
|
||||
aligned_n = (n + 63) // 64 * 64
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||
last_stages = (k + 127) // 128 % num_stages
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, n, k,
|
||||
lhs_stride, rhs_stride, out_stride,
|
||||
torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_fp32_nt_dptp128c_dyn',
|
||||
keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages,
|
||||
'LAST_STAGES': last_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.float), ('m', int), ('n', int), ('k', int),
|
||||
('a_stride', int), ('b_stride', int), ('d_stride', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
|
||||
|
||||
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
batch_sizes: List[int]):
|
||||
"""
|
||||
Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
|
||||
Each batch's LHS, RHS, and output tensors must be contiguous.
|
||||
The RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
|
||||
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
|
||||
the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
|
||||
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
|
||||
the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
out: The FP32 output tensor of shape [num_batches, m, n], representing the result.
|
||||
batch_sizes: A list of integers specifying the k-dimension for each batch.
|
||||
"""
|
||||
lhs, lhs_scales = lhs[0].view(-1), lhs[1]
|
||||
rhs, rhs_scales = rhs[0].view(-1), rhs[1]
|
||||
num_batches, m, n = out.shape
|
||||
|
||||
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
|
||||
|
||||
for idx in range(num_batches):
|
||||
k = batch_sizes[idx]
|
||||
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
||||
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
||||
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
||||
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
||||
D = out[idx]
|
||||
|
||||
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
|
||||
|
||||
lhs_offset += m * k
|
||||
rhs_offset += n * k
|
||||
scales_offset += (k + 127) // 128
|
||||
@ -78,7 +78,7 @@ class suppress_stdout_stderr:
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
|
||||
|
||||
@ -136,8 +136,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
if not is_multiple:
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
@ -147,14 +148,29 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
if not is_multiple:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
else:
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
@ -89,6 +89,70 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def construct_wgrad(m: int, k: int, n: int) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
|
||||
out = residual.clone()
|
||||
ref_out = residual + (x.float() @ y.float().t())
|
||||
|
||||
x_fp8 = per_token_cast_to_fp8(x)
|
||||
y_fp8 = per_token_cast_to_fp8(y)
|
||||
|
||||
# NOTES: please do inplace add on the `out` later
|
||||
return x_fp8, y_fp8, residual, out, ref_out
|
||||
|
||||
|
||||
def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
|
||||
num_groups, total_k = len(k_sizes), sum(k_sizes)
|
||||
|
||||
x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
|
||||
y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
|
||||
ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
|
||||
|
||||
# Fill tensors with data and compute reference output
|
||||
x_offset, y_offset = 0, 0
|
||||
for idx, k in enumerate(k_sizes):
|
||||
x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten())
|
||||
y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten())
|
||||
ref_out[idx] = x_chunk.float() @ y_chunk.float().t()
|
||||
|
||||
x_offset += m * k
|
||||
y_offset += n * k
|
||||
|
||||
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
||||
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
||||
|
||||
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
|
||||
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
||||
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
||||
|
||||
# Cast to FP8 and prepare scale factors
|
||||
x_offset, y_offset, scale_offset = 0, 0, 0
|
||||
for k in k_sizes:
|
||||
x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k))
|
||||
y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k))
|
||||
|
||||
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
||||
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
||||
|
||||
num_scales = (k + 127) // 128
|
||||
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
||||
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
||||
|
||||
x_offset += m * k
|
||||
y_offset += n * k
|
||||
scale_offset += num_scales
|
||||
|
||||
return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for m in (64, 128, 4096):
|
||||
@ -170,6 +234,62 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print()
|
||||
|
||||
|
||||
def test_wgrad_gemm():
|
||||
print('Testing weight gradient GEMM:')
|
||||
|
||||
for k in (4096, 8192):
|
||||
for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
|
||||
# Test correctness
|
||||
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
|
||||
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
||||
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_k_grouped_wgrad_gemm():
|
||||
print('Testing grouped weight gradient GEMM:')
|
||||
|
||||
for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
|
||||
for m, n in ((7168, 4096), (2048, 7168)):
|
||||
# Vary k sizes around base_k
|
||||
k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
|
||||
k_sizes.append(base_k * num_groups - sum(k_sizes))
|
||||
|
||||
# Test correctness
|
||||
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
|
||||
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
|
||||
|
||||
for idx in range(num_groups):
|
||||
diff = calc_diff(out[idx], ref_out[idx])
|
||||
assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
|
||||
|
||||
# Construct new tensors to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
|
||||
total_k = sum(k_sizes)
|
||||
|
||||
def test_func():
|
||||
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, is_multiple=True) * num_groups
|
||||
print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@ -182,3 +302,6 @@ if __name__ == '__main__':
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
|
||||
test_wgrad_gemm()
|
||||
test_k_grouped_wgrad_gemm()
|
||||
Loading…
Reference in New Issue
Block a user