Support groupwise scaling for B

zhengxuegui.0 2025-03-01 23:04:15 +08:00
parent 6c5da03ba9
commit 6d0fd7de41
4 changed files with 445 additions and 26 deletions


@ -4,6 +4,8 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include "cutlass/pipeline/sm90_pipeline.hpp"
#include "cutlass/cutlass.h"
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
@ -339,6 +341,322 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#endif
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast,
GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_groupwise_gemm_kernel(__nv_bfloat16* gmem_d, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_scales_b,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
// DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_ALIGNMENT = 128; // 128 bytes
// BLOCK_M and BLOCK_K are both multiples of 64,
// so the shared memory regions for A / scales A / D are 128-byte aligned.
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, SMEM_ALIGNMENT) * SMEM_ALIGNMENT;
// Shared memory Alignment static check
DG_STATIC_ASSERT(ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE % SMEM_ALIGNMENT == 0, "Shared memory address must be 128 byte aligned.");
DG_STATIC_ASSERT(SMEM_D_SIZE % SMEM_ALIGNMENT == 0, "Shared memory address must be 128 byte aligned.");
DG_STATIC_ASSERT(SMEM_A_SIZE_PER_STAGE % SMEM_ALIGNMENT == 0, "Shared memory address must be 128 byte aligned.");
DG_STATIC_ASSERT(SMEM_B_SIZE_PER_STAGE % SMEM_ALIGNMENT == 0, "Shared memory address must be 128 byte aligned.");
DG_STATIC_ASSERT(SMEM_SCALES_A_SIZE_PER_STAGE % SMEM_ALIGNMENT == 0, "Shared memory address must be 128 byte aligned.");
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b[kNumStages];
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
}
constexpr uint32_t BARRIER_SMEM_OFFSET = (SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE) + 7) / 8 * 8;
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + BARRIER_SMEM_OFFSET);
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
}
};
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A with broadcasting
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_b[s], n_block_idx * BLOCK_N,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
constexpr uint32_t kTransactionBytesPerStage = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE;
full_barrier.arrive_and_expect_tx(kTransactionBytesPerStage);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
const auto col_offset = lane_idx % 4 * 2;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// B scales are read as one float2 per 8 output columns, so N must be a multiple of 8
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float2 scales_b_reg[BLOCK_N / 8];
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Read B scales
#pragma unroll
for (int i = 0; i < BLOCK_N / 8; ++i) {
scales_b_reg[i] = ld_shared(reinterpret_cast<float2 *>(smem_scales_b[s] + i * 8 + col_offset));
}
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with the per-block scales (m64nNk32 WGMMA accumulator layout)
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b_reg[i].x;
const float &scale_b_1 = scales_b_reg[i].y;
final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
#endif
}
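The promotion loop above (final_accum += scale_a * scale_b * accum) dequantizes each 128-wide K block with a per-row LHS scale and a per-(column, K-block) RHS scale. Below is a minimal host-side reference sketch of that math, assuming operands and scales with the shapes asserted later in gemm_fp8_fp8_bf16_nt; groupwise_scaled_gemm_ref and its argument names are illustrative, not part of the library:

import torch

def groupwise_scaled_gemm_ref(a_q: torch.Tensor, a_s: torch.Tensor,
                              b_q: torch.Tensor, b_s: torch.Tensor,
                              block_k: int = 128) -> torch.Tensor:
    # a_q: [m, k] quantized A, a_s: [m, k // block_k] per-(row, K-block) scales
    # b_q: [n, k] quantized B, b_s: [n, k // block_k] per-(row, K-block) scales
    m, k = a_q.shape
    n = b_q.shape[0]
    out = torch.zeros(m, n, dtype=torch.float32, device=a_q.device)
    for kb in range(k // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        # Partial product of one K block (the WGMMA accumulator above)...
        partial = a_q[:, ks].float() @ b_q[:, ks].float().t()
        # ...promoted by scale_a (per output row) and scale_b (per output column)
        out += a_s[:, kb:kb + 1] * b_s[:, kb].unsqueeze(0) * partial
    return out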
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
@ -351,6 +669,7 @@ private:
public:
Gemm() = default;
// GEMM with 1x128 LHS scaling and 128x128 (block-wise) RHS scaling
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const CUtensorMap& tma_a_desc,
@ -390,6 +709,48 @@ public:
DG_HOST_ASSERT(status == cudaSuccess);
}
// GEMM with 1x128 LHS scaling and 128x1 (group-wise) RHS scaling
static void run(__nv_bfloat16* gmem_d, int* grouped_layout,
uint32_t shape_m,
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_scales_b_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
DG_HOST_ASSERT(grouped_layout == nullptr);
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_groupwise_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
// Launch
auto status = cudaLaunchKernelEx(&config, kernel,
gmem_d, grouped_layout,
shape_m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
@ -421,6 +782,16 @@ public:
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_b_desc(T* global_address) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
const uint32_t shape_n = ceil_div(SHAPE_N, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor,
shape_n, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_N, 1,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
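For concreteness, the extents this descriptor spans for FP32 scales; a small worked example with arbitrary sizes, not additional library code:

# T = float -> the 16-byte TMA alignment is 4 elements along N.
SHAPE_N, SHAPE_K, BLOCK_K = 4099, 7168, 128            # arbitrary example sizes
k_alignment = 16 // 4                                  # 4 elements
shape_n = (SHAPE_N + k_alignment - 1) // k_alignment * k_alignment  # 4099 -> 4100
k_groups = (SHAPE_K + BLOCK_K - 1) // BLOCK_K          # 7168 / 128 = 56
# Column-major descriptor over a [shape_n, k_groups] scale matrix
# (times kNumGroups for GroupedMasked), read in [BLOCK_N, 1] boxes per stage.
print(shape_n, k_groups)                               # 4100 56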
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,


@ -795,6 +795,18 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
return ret;
}
__device__ __forceinline__ float4 ld_shared(const float4* __restrict__ ptr) {
float4 ret;
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}


@ -15,6 +15,7 @@ constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kGroupWiseScalingB = {IS_GROUP_WISE_SCALING_B};
// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
@ -22,12 +23,20 @@ using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMultica
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
if constexpr (not kGroupWiseScalingB) {
GemmType::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
} else {
auto tma_scales_b_desc = GemmType::make_2d_tma_scales_b_desc(rhs_scales);
GemmType::run(out, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_scales_b_desc, tma_d_desc,
stream, num_sms, smem_size);
}
"""
@ -37,12 +46,16 @@ def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, is_groupwise_scaling_b: bool = False) -> int:
# TMA (multi-dimensional bulk tensor asynchronous copy) requires
# shared memory addresses to be 128-byte aligned
TMA_SMEM_ALIGNMENT = 128 # bytes
smem_d = block_m * block_n * 2
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = ceil_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
@ -50,19 +63,26 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
if is_groupwise_scaling_b:
smem_scales_b_per_stage = ceil_div(block_n * 4, TMA_SMEM_ALIGNMENT) * TMA_SMEM_ALIGNMENT
smem_size += num_stages * smem_scales_b_per_stage
else:
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
return smem_size
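As a quick sanity check of the formula above, re-evaluating it by hand for one representative group-wise configuration (the concrete numbers are illustrative, not a prescribed config):

num_stages, block_m, block_n, block_k = 5, 128, 112, 128

smem_d = block_m * block_n * 2                              # BF16 D tile:   28672 B
smem_a_per_stage = block_m * block_k                        # FP8 A tile:    16384 B
smem_b_per_stage = block_n * block_k                        # FP8 B tile:    14336 B
smem_scales_a_per_stage = block_m * 4                       # FP32 A scales:   512 B
smem_scales_b_per_stage = (block_n * 4 + 127) // 128 * 128  # padded to 128 B: 512 B
smem_barrier = num_stages * 8 * 2                           # full + empty barriers: 80 B

smem_size = (smem_d
             + num_stages * (smem_a_per_stage + smem_b_per_stage
                             + smem_scales_a_per_stage + smem_scales_b_per_stage)
             + smem_barrier)
print(smem_size)  # 187472, under the 232448-byte sm90_capacity checked in get_best_configs()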
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_groupwise_scaling_b: bool = False,
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
if not is_grouped_contiguous:
# TODO: for some cases, smaller M block is better, add them into tuning space
block_ms = (64 if m <= 64 else 128, )
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8))
# Smaller block N performs better for group-wise scaling
max_block_n = 128 if not is_groupwise_scaling_b else 112
block_ns = tuple(range(16, max_block_n + 1, 8))
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@ -90,7 +110,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n, is_groupwise_scaling_b=is_groupwise_scaling_b)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
@ -106,9 +126,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor) -> None:
out: torch.Tensor,
is_groupwise_scaling_b: bool = False) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 or 128x1 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
@ -120,6 +141,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[n / 128, k / 128]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
is_groupwise_scaling_b: if True, apply group-wise (128x1) scaling to the RHS.
Defaults to False (128x128 block-wise scaling).
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
@ -133,7 +156,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
if not is_groupwise_scaling_b:
assert rhs_scales.is_contiguous()
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
else:
assert rhs_scales.shape == (n, (k + 127) // 128)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
@ -142,7 +169,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# LHS scales must be transposed for TMA load; RHS scales only need this for group-wise (128x1) scaling
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
if is_groupwise_scaling_b:
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
# Do nothing if `m` is zero
if m == 0:
@ -151,12 +179,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, is_groupwise_scaling_b=is_groupwise_scaling_b)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast,
'IS_GROUP_WISE_SCALING_B': "true" if is_groupwise_scaling_b else "false"},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),


@ -25,14 +25,17 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
def construct(m: int, k: int, n: int) -> \
def construct(m: int, k: int, n: int, is_groupwise_scaling_b: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
if not is_groupwise_scaling_b:
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
else:
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)
y_fp8 = (y_fp8[0], get_col_major_tma_aligned_tensor(y_fp8[1]))
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
@ -62,22 +65,25 @@ def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool)
return x_fp8, y_fp8, out, ref_out
def test_gemm() -> None:
print('Testing GEMM:')
def test_gemm(is_groupwise_scaling_b: bool = False) -> None:
print(f'Testing GEMM with 1x128 LHS scaling and {"128x1" if is_groupwise_scaling_b else "128x128"} RHS scaling:')
for m in (64, 128, 4096):
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
x_fp8, y_fp8, out, ref_out = construct(m, k, n, is_groupwise_scaling_b=is_groupwise_scaling_b)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, is_groupwise_scaling_b=is_groupwise_scaling_b)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
# Construct new tensors every time to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
x_fp8, y_fp8, out, ref_out = construct(m, k, n, is_groupwise_scaling_b=is_groupwise_scaling_b)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, is_groupwise_scaling_b=is_groupwise_scaling_b)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
kernel_name = 'fp8_groupwise_gemm_kernel' if is_groupwise_scaling_b else 'fp8_gemm'
t = bench_kineto(test_func, kernel_name, suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
@ -153,6 +159,7 @@ if __name__ == '__main__':
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_gemm(is_groupwise_scaling_b=True)
test_gemm(is_groupwise_scaling_b=False)
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
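For completeness, a minimal standalone sketch of the new 128x1 RHS path, mirroring the flow in test_gemm above. cast_1x128 is an illustrative stand-in for the test helper per_token_cast_to_fp8, the sizes are arbitrary, and get_col_major_tma_aligned_tensor is assumed to be exported by deep_gemm as in the tests:

import torch
import deep_gemm
from deep_gemm import get_col_major_tma_aligned_tensor

def cast_1x128(x: torch.Tensor):
    # Quantize each 128-wide K group to FP8 E4M3 with one FP32 scale per group.
    rows, k = x.shape
    x_view = x.view(rows, k // 128, 128).float()
    amax = x_view.abs().amax(dim=2, keepdim=True).clamp(min=1e-4)
    q = (x_view * (448.0 / amax)).to(torch.float8_e4m3fn).view(rows, k)
    return q, (amax.squeeze(2) / 448.0)  # scales: [rows, k // 128], float32

m, k, n = 128, 7168, 4096
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

x_fp8, x_scales = cast_1x128(x)
y_fp8, y_scales = cast_1x128(y)
# Transpose the scales to the TMA-aligned, column-major layout up front
x_scales = get_col_major_tma_aligned_tensor(x_scales)
y_scales = get_col_major_tma_aligned_tensor(y_scales)

deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scales), (y_fp8, y_scales), out,
                               is_groupwise_scaling_b=True)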