Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-05 17:44:54 +00:00

Commit 6e53c6613d ("tested"), parent 9d3222a93e
deep_gemm/__init__.py
@@ -3,6 +3,7 @@ import torch
 from . import jit
 from .jit_kernels import (
     gemm_fp8_fp8_bf16_nt,
+    gemm_fp8_fp8_bf16_bw_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     ceil_div,
deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh (new file, 442 lines)
@@ -0,0 +1,442 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#pragma once

#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>

#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>

#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
#include "fp8_gemm.cuh"

namespace deep_gemm {

template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
          uint32_t kNumTMAMulticast,
          GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                   uint32_t shape_m,
                   const __grid_constant__ CUtensorMap tensor_map_a,
                   const __grid_constant__ CUtensorMap tensor_map_b,
                   const __grid_constant__ CUtensorMap tensor_map_scales_a,
                   const __grid_constant__ CUtensorMap tensor_map_scales_b,
                   const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Scaling checks
    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // Shared memory
    // static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div<uint32_t>(BLOCK_N * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
    // static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
    // static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);

    // Configs
    constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
    constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
    constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
    constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_id();

    // Prefetch TMA descriptors at the very beginning
    if (threadIdx.x == kNumMathThreads) {
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of D must be aligned to 1024 bytes");

    // Data on shared memory
    auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
    __nv_fp8_e4m3* smem_a[kNumStages];
    __nv_fp8_e4m3* smem_b[kNumStages];
    float* smem_scales_a[kNumStages];
    float* smem_scales_b[kNumStages];

    // TMA Barrier for both divisible and non-divisible cases
    Barrier* full_barriers[kNumStages];
    Barrier* empty_barriers[kNumStages];

    // Fill shared memory pointers
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
        smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
        smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
        smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) + i * SMEM_SCALES_B_SIZE_PER_STAGE);
    }

    // Fill barriers
    // auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE_PER_STAGE);
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b[kNumStages - 1]) + SMEM_SCALES_B_SIZE_PER_STAGE);
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        full_barriers[i] = barrier_start_ptr + i;
        empty_barriers[i] = barrier_start_ptr + kNumStages + i;
    }

    // Initialize barriers
    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
    if (threadIdx.x == kNumMathThreads) {
        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
        // even with TMA multicast disabled, to keep the behavior aligned
        #pragma unroll
        for (int i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_view_async_shared();
        (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // For pipeline unrolling
    struct DivisibleK {};
    struct NotDivisibleK {};
    auto launch_k_iterations = [](const auto& func) {
        if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
            for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
                func(k_iter, DivisibleK{});
        } else {
            for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
                func(k_iter, DivisibleK{});
            func(kNumIterations - 1, NotDivisibleK{});
        }
    };

    // Register reconfigurations
    constexpr int kNumTMARegisters = 40;
    constexpr int kNumMathRegisters = 232;

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);

    if (threadIdx.x >= kNumMathThreads) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (threadIdx.x == kNumMathThreads) {
            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                launch_k_iterations([&](int k_iter, auto type) {
                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                    // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
                    // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
                    #pragma unroll
                    for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                        // Wait consumer release
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);

                        // Issue TMA A with broadcasting
                        auto& full_barrier = *full_barriers[s];
                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                        tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                        // Only supports normal GEMM for now. @kavioyu
                        tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_scales_a[s], m_block_idx * BLOCK_M,
                                                   scheduler.get_global_idx(0, 1, k_idx / BLOCK_K));

                        // Issue TMA B without broadcasting
                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                        // Only supports normal GEMM for now. @kavioyu
                        tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_scales_b[s], n_block_idx * BLOCK_N, scheduler.get_global_idx(0, 1, k_idx / BLOCK_K));
                        full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
                        // full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                    }

                    // Wait unaligned cases
                    #pragma unroll
                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                        full_barriers[s]->arrive();
                    }
                });
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s)
                    empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
        const int laneid_div_4 = lane_idx / 4;
        const auto r_0 = warp_idx * 16 + laneid_div_4, r_1 = r_0 + 8;
        const unsigned int scale_b_idx[2] = {laneid_div_4 * 8 + lane_idx % 4 * 2, (8 + laneid_div_4) * 8 + lane_idx % 4 * 2};

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Accumulation for WGMMA or CUDA promotion
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};

            // Empty barrier arrival
            auto empty_barrier_arrive = [&](int s) {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                } else {
                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
                }
            };

            // Launch MMAs
            launch_k_iterations([&](int k_iter, auto type) {
                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                #pragma unroll
                for (int s = 0; s < kNumInnerStages; ++ s) {
                    // Wait TMA arrivals
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                    // Read B scales
                    float2 scale_b[2];
                    #pragma unroll
                    for (int i = 0; i < 2; ++ i) {
                        scale_b[i].x = ld_shared(smem_scales_b[s] + scale_b_idx[i]);
                        scale_b[i].y = ld_shared(smem_scales_b[s] + scale_b_idx[i] + 1);
                    }

                    // Read A scales
                    // NOTES: all shared memory reads must happen before `warpgroup_arrive` to avoid the next scheduled block polluting the results
                    auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);

                    // Commit WGMMA instructions
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_arrive();
                    #pragma unroll
                    for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                        auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
                        auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                        WGMMA::wgmma(desc_a, desc_b, accum, k);
                    }
                    warpgroup_commit_batch();
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_wait<0>();

                    // Notify barrier arrival
                    empty_barrier_arrive(s);

                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {  // e.g. for BLOCK_N = 128, WGMMA::kNumAccum = 64, so this loops 16 times
                        int src_lane_id = threadIdx.x % 4 + (i % 8) * 4;
                        float scale_b_0 = __shfl_sync(0xffffffff, scale_b[i / 8].x, src_lane_id);
                        float scale_b_1 = __shfl_sync(0xffffffff, scale_b[i / 8].y, src_lane_id);
                        final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
                        final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
                        final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
                        final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
                    }
                }

                // Wait unaligned cases
                #pragma unroll
                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                    empty_barrier_arrive(s);
                }
            });

            // Write back to shared memory using STSM
            DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
            #pragma unroll
            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
                );
            }
            if constexpr (WGMMA::kNumAccum % 8 != 0) {
                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
                );
            }
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Use TMA store to write back to global memory
            if (threadIdx.x == 0) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                cute::tma_store_arrive();
                cute::tma_store_wait<0>();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
#endif
}

template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAMulticast,
          GemmType kGemmType>
class GemmBW {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

public:
    GemmBW() = default;

    static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                    uint32_t shape_m,
                    const CUtensorMap& tma_a_desc,
                    const CUtensorMap& tma_b_desc,
                    const CUtensorMap& tma_scales_a_desc,
                    const CUtensorMap& tma_scales_b_desc,
                    const CUtensorMap& tma_d_desc,
                    cudaStream_t stream,
                    int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
        auto kernel = fp8_gemm_bw_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                         kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
                                         kNumTMAMulticast, kGemmType>;
        DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

        // Cluster launch
        cudaLaunchConfig_t config;
        config.gridDim = num_sms;
        config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
        config.dynamicSmemBytes = smem_size;
        config.stream = stream;

        // Clusters for TMA multicast
        // NOTES: `>= 4` cluster size will cause performance degradation
        cudaLaunchAttribute attr;
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
        config.attrs = &attr;
        config.numAttrs = 1;

        // Launch
        auto status = cudaLaunchKernelEx(&config, kernel,
                                         gmem_d, scales_b, grouped_layout,
                                         shape_m,
                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc);
        DG_HOST_ASSERT(status == cudaSuccess);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_b_desc(T* global_address) {
        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
                                min(BLOCK_M, shape_m), BLOCK_N,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
        // Make TMA aligned to 16 bytes
        constexpr uint32_t kAlignment = 16 / sizeof(T);
        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;

        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_scales_b_desc(T* global_address, uint32_t shape_m) {
        // Make TMA aligned to 16 bytes
        constexpr uint32_t kAlignment = 16 / sizeof(T);
        constexpr uint32_t shape_n = ceil_div(SHAPE_N, kAlignment) * kAlignment;

        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                shape_n, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_N, 1,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_desc(
            T* global_address, Layout layout,
            uint32_t gmem_rows, uint32_t gmem_cols,
            uint32_t smem_rows, uint32_t smem_cols,
            CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
        if (layout == Layout::RowMajor) {
            uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
            uint32_t smem_dim[2] = {smem_cols, smem_rows};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
        } else {
            uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
            uint32_t smem_dim[2] = {smem_rows, smem_cols};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
        }
    }
};

};  // namespace deep_gemm

#pragma clang diagnostic pop
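For reference, the math this backward-W kernel implements (an FP8 GEMM whose partial products are promoted per 128-wide K block, with 1x128 scaling factors on both operands) can be sketched in plain PyTorch. The sketch is illustrative only; `reference_fp8_bw_gemm` is not part of DeepGEMM, and it uses the logical row-major scale layouts from the Python docstring rather than the TMA-aligned layout the kernel actually loads:

import torch

def reference_fp8_bw_gemm(a_fp8, a_scales, b_fp8, b_scales, block_k=128):
    # a_fp8: [m, k] float8_e4m3fn, a_scales: [m, k // 128] float32 (1x128 LHS scales)
    # b_fp8: [n, k] float8_e4m3fn, b_scales: [n, k // 128] float32 (1x128 RHS scales)
    m, k = a_fp8.shape
    n, _ = b_fp8.shape
    out = torch.zeros((m, n), device=a_fp8.device, dtype=torch.float32)
    for i in range(0, k, block_k):
        s = i // block_k
        # Promote one K block to FP32, scale each row of A and B, then accumulate
        a_blk = a_fp8[:, i:i + block_k].to(torch.float32) * a_scales[:, s:s + 1]
        b_blk = b_fp8[:, i:i + block_k].to(torch.float32) * b_scales[:, s:s + 1]
        out += a_blk @ b_blk.t()
    return out.to(torch.bfloat16)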
@@ -1,3 +1,4 @@
#pragma once
#include "utils.cuh"

namespace deep_gemm {
deep_gemm/jit_kernels/__init__.py
@@ -1,4 +1,5 @@
 from .gemm import gemm_fp8_fp8_bf16_nt
+from .gemm_bw import gemm_fp8_fp8_bf16_bw_nt
 from .m_grouped_gemm import (
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked
deep_gemm/jit_kernels/gemm_bw.py (new file, 186 lines)
@@ -0,0 +1,186 @@
from typing import Tuple

import torch

from .tuner import jit_tuner
from .utils import (
    ceil_div,
    get_col_major_tma_aligned_tensor,
    get_m_alignment_for_contiguous_layout,
    get_num_sms,
)

# C++ code templates
includes = ('"deep_gemm/fp8_gemm_backward_w.cuh"',)
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};

// Make a templated GEMM
using GemmType = GemmBW<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;

// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_scales_b_desc = GemmType::make_2d_tma_scales_b_desc(rhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
              m,
              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc,
              stream, num_sms, smem_size);
"""


def is_tma_multicast_legal(
    n: int, block_n: int, num_tma_multicast: int, num_sms: int
) -> bool:
    if num_tma_multicast == 1:
        return True
    return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0


def get_smem_size(
    num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128
) -> int:
    smem_d = block_m * block_n * 2
    smem_a_per_stage = block_m * block_k
    smem_scales_a_per_stage = block_m * 4
    smem_scales_b_per_stage = block_n * 4
    smem_b_per_stage = block_n * block_k
    # smem_scales_b = ceil_div(k, block_k) * 4
    smem_barrier = num_stages * 8 * 2

    smem_size = 0
    smem_size += smem_d
    smem_size += num_stages * smem_a_per_stage
    smem_size += num_stages * smem_scales_a_per_stage
    smem_size += num_stages * smem_scales_b_per_stage
    smem_size += num_stages * smem_b_per_stage
    # smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
    smem_size += smem_barrier
    return smem_size
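As a quick sanity check on the budget this function computes, here is a worked example under an assumed typical configuration (plain arithmetic, not extra library code). With block_m=64, block_n=128 and 8 stages, each stage needs 8192 B for A, 16384 B for B, 256 B for A scales and 512 B for B scales, plus 16384 B for the BF16 D tile and 128 B of barriers:

smem = get_smem_size(num_stages=8, k=7168, block_m=64, block_n=128)  # k is only used by the commented-out path
# 64*128*2 + 8*(64*128 + 64*4 + 128*4 + 128*128) + 8*8*2
# = 16384 + 8 * 25344 + 128
assert smem == 219264  # fits under the 232448-byte SM90 capacity used by get_best_configs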

def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
    if not is_grouped_contiguous:
        # TODO: for some cases, a smaller M block is better; add them to the tuning space
        block_ms = (64 if m <= 64 else 128, )
    else:
        block_ms = (get_m_alignment_for_contiguous_layout(), )
    block_ns = tuple(range(32, 129, 32))

    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

    # Decide block sizes by waves
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        for block_n in block_ns:
            success = False
            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                success = True
            elif num_waves < best_num_waves:
                success = True
            elif num_waves == best_num_waves:
                # Check last-wave utilization
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
    assert best_block_m is not None and best_block_n is not None

    # Always pick the largest number of stages that fits
    # NOTES: for double B scales, the best number of stages may be reduced
    best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
    for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
        best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
        if best_smem_size <= sm90_capacity:
            best_num_stages = num_stages
            break
    assert best_num_stages is not None

    # Decide the number of TMA multicasts
    best_num_tma_multicast = 1
    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
        best_num_tma_multicast = 2

    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size

def gemm_fp8_fp8_bf16_bw_nt(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
) -> None:
    """
    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 1x128 RHS scaling.
    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
    RHS and RHS scaling factors are required to be transposed.
    The LHS and RHS scaling tensors require a TMA-aligned, transposed format; if your inputs do not match
    the requirement, this function will do the transposing with a set of slow PyTorch operations.

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
        out: the BF16 output tensor of shape `[m, n]`, representing the result.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape

    assert n % 64 == 0 and k % 128 == 0

    # Type and shape checks
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and k > 0
    assert lhs_scales.shape == (m, (k + 127) // 128)
    assert rhs_scales.shape == (n, (k + 127) // 128)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

    # Both LHS and RHS scales must be transposed (column-major, TMA-aligned) for the TMA load
    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)

    # Do nothing if `m` is zero
    if m == 0:
        return

    # Auto-tuning with compilation
    global includes, template
    num_sms = get_num_sms()
    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='gemm_fp8_fp8_bf16_bw_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
        space=(),
        includes=includes,
        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                  ('out', torch.bfloat16), ('m', int),
                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
        template=template,
        args=args
    )

    # Run the kernel
    runtime(*args)
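A minimal call sketch for the new entry point, assuming a helper such as the test file's `per_token_cast_to_fp8` (shown further below) to quantize both operands with 1x128 scales:

import torch
import deep_gemm

m, k, n = 128, 7168, 4096  # n % 64 == 0 and k % 128 == 0 are required
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

# Both operands use 1x128 (per-token) scaling for the backward-W GEMM
x_fp8 = per_token_cast_to_fp8(x)  # (float8_e4m3fn [m, k], float32 [m, k // 128])
y_fp8 = per_token_cast_to_fp8(y)  # (float8_e4m3fn [n, k], float32 [n, k // 128])
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)  # out ≈ x @ y.t()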
@@ -1,9 +1,15 @@
 import random
-import torch
 from typing import Tuple
 
+import torch
+
 import deep_gemm
-from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
+from deep_gemm import (
+    bench_kineto,
+    calc_diff,
+    ceil_div,
+    get_col_major_tma_aligned_tensor,
+)
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -33,8 +39,18 @@ def construct(m: int, k: int, n: int) -> \
     ref_out = x @ y.t()
 
     x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
     # Transpose earlier so that the testing will not trigger transposing kernels
     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
     return x_fp8, y_fp8, out, ref_out
 
 
+def construct_backward_w(m: int, k: int, n: int) -> \
+        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    ref_out = x @ y.t()
+    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)
+    # x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
+    # y_fp8 = (y_fp8[0], get_col_major_tma_aligned_tensor(y_fp8[1]))
+    return x_fp8, y_fp8, out, ref_out
@@ -84,6 +100,30 @@ def test_gemm() -> None:
     print()
 
 
+def test_gemm_backward_w() -> None:
+    print('Testing GEMM Backward W:')
+    for m in (64, 128, 4096):
+        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
+            x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
+            deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
+            diff = calc_diff(out, ref_out)
+            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+            torch.cuda.synchronize()
+            print(diff)
+
+            # noinspection PyShadowingNames
+            def test_func():
+                # Construct new tensors every time to avoid L2 cache acceleration
+                x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
+                deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
+
+            t = bench_kineto(test_func, 'fp8_gemm_bw', suppress_kineto_output=True)
+            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
 def test_m_grouped_gemm_contiguous() -> None:
     print('Testing grouped contiguous GEMM:')
 
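The performance line in the test above counts 2*m*n*k FLOPs and m*k + k*n + 2*m*n bytes (FP8 A, FP8 B and the BF16 output). A worked example for one of the tested shapes, using a purely hypothetical timing:

m, k, n = 4096, 2048, 7168
flops = 2 * m * n * k                      # ~1.20e11 FLOPs per GEMM
bytes_moved = m * k + k * n + m * n * 2    # ~8.18e7 bytes
# With an assumed t = 200 us, the print above would report ~601 TFLOPS and ~409 GB/s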
@@ -153,6 +193,7 @@ if __name__ == '__main__':
     print('Library path:')
     print(f' > {deep_gemm.__path__}\n')
 
+    test_gemm_backward_w()
     test_gemm()
-    test_m_grouped_gemm_contiguous()
-    test_m_grouped_gemm_masked()
+    # test_m_grouped_gemm_contiguous()
+    # test_m_grouped_gemm_masked()