Merge pull request #74 from deepseek-ai/larger-block

Performance: Larger BlockTile optimizations enable 1470+ TF FP8 on the "H800"-SXM
This commit is contained in:
Chenggang Zhao 2025-03-25 18:07:33 +08:00 committed by GitHub
commit a5645d7afa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 269 additions and 70 deletions

View File

@ -28,21 +28,21 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed
| 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x |
| 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x |
| 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x |
| 4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x |
| 4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x |
| 4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x |
| 4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x |
| 4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x |
| 4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x |
| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x |
| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x |
| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x |
| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x |
| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x |
| 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x |
### Grouped GEMMs for MoE models (contiguous layout)
| #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup |
|:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:|
| 4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x |
| 4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x |
| 8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x |
| 8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x |
| 4 | 8192 | 4096 | 7168 | 1346 TFLOPS | 434 GB/s | 1.3x |
| 4 | 8192 | 7168 | 2048 | 1214 TFLOPS | 752 GB/s | 1.3x |
| 8 | 4096 | 4096 | 7168 | 1346 TFLOPS | 516 GB/s | 1.3x |
| 8 | 4096 | 7168 | 2048 | 1214 TFLOPS | 826 GB/s | 1.2x |
### Grouped GEMMs for MoE models (masked layout)

View File

@ -27,11 +27,22 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
template <int kNumFormerIters, int kGap, int kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
if (num_former_iters == kNumFormerIters) {
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
return;
}
if constexpr (kNumFormerIters + kGap <= kEnd)
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
@ -43,7 +54,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@ -129,15 +140,23 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
}
auto launch_k_iterations = [](const auto& func, int num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{}, num_former_iters_type);
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}, num_former_iters_type);
func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
}
}, func, kShouldOptimize ? num_former_iters : 0);
};
// Register reconfigurations
@ -146,7 +165,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
@ -156,11 +175,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
launch_k_iterations([&](int k_iter, auto type, auto _) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B
constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
@ -168,18 +191,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A with broadcasting
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
// Issue TMA B
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
@ -189,7 +212,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}, 0);
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
@ -242,7 +265,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@ -288,8 +311,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
@ -304,7 +329,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
}, num_former_iters);
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
@ -347,7 +372,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
class Gemm {
private:
@ -369,7 +394,7 @@ public:
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kGemmType>;
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch

View File

@ -665,6 +665,135 @@ struct SM90_64x128x32_F32E4M3E4M3_SS {
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x144x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %74, 0;\n"
"wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71}, "
" %72,"
" %73,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 144;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x160x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %82, 0;\n"
"wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71, "
" %72, %73, %74, %75, %76, %77, %78, %79}, "
" %80,"
" %81,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 160;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x192x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
@ -876,6 +1005,8 @@ struct FP8MMASelector {
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS();
if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS();
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
}

View File

@ -12,9 +12,10 @@ enum class GemmType {
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumGroups,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
uint32_t kNum1DBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
uint32_t num_aligned_m_blocks;
@ -43,16 +44,27 @@ struct Scheduler {
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
// TODO: unify these 2 branches
if constexpr (kIsTMAMulticastOnA) {
auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
} else {
auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
n_block_idx = in_group_idx / num_m_blocks_in_group;
}
}
template <bool kIgnoreGroupedForGroupedContiguous=true>
@ -98,6 +110,7 @@ struct Scheduler {
return true;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@ -46,3 +46,8 @@ template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T>
__device__ __host__ constexpr T gcd(T a, T b) {
return b == 0 ? a : gcd(b, a % b);
}

View File

@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=177,174,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]

View File

@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
new_template = new_template.replace(f'{{{key}}}', f'{value}')
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template

View File

@ -1,3 +1,4 @@
import math
import torch
from typing import Tuple
@ -15,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@ -31,10 +33,10 @@ GemmType::run(out, rhs_scales, nullptr,
"""
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
return True
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
@ -56,13 +58,13 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]:
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
if not is_grouped_contiguous:
# TODO: for some cases, smaller M block is better, add them into tuning space
block_ms = (64 if m <= 64 else 128, )
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@ -89,26 +91,38 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
stage_candidates = (8, 7, 6, 5, 4)
if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, )
for num_stages in stage_candidates:
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
assert best_num_stages is not None
# Decide the number of TMA multicast
best_num_tma_multicast = 1
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
# Decide the number of TMA multicast and whether broadcast on A
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i] and num_groups == 1:
best_tma_multicast_config = (2, i == 'A')
break
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -158,12 +172,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),

View File

@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated grouped GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
is_grouped_contiguous=True)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),