Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-30 21:25:09 +00:00

Commit a5645d7afa: Merge pull request #74 from deepseek-ai/larger-block

Performance: larger block tile optimizations enable 1470+ TFLOPS of FP8 on the H800 SXM
Changed file: README.md
@@ -28,21 +28,21 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed
 | 128  | 7168  | 16384 | 645 TFLOPS  | 2604 GB/s | 1.4x |
 | 128  | 4096  | 7168  | 533 TFLOPS  | 2221 GB/s | 2.0x |
 | 128  | 7168  | 2048  | 510 TFLOPS  | 2277 GB/s | 1.7x |
-| 4096 | 2112  | 7168  | 1058 TFLOPS | 527 GB/s  | 1.1x |
-| 4096 | 24576 | 1536  | 990 TFLOPS  | 786 GB/s  | 1.0x |
-| 4096 | 32768 | 512   | 590 TFLOPS  | 1232 GB/s | 1.0x |
-| 4096 | 7168  | 16384 | 1358 TFLOPS | 343 GB/s  | 1.2x |
-| 4096 | 4096  | 7168  | 1304 TFLOPS | 500 GB/s  | 1.1x |
-| 4096 | 7168  | 2048  | 1025 TFLOPS | 697 GB/s  | 1.1x |
+| 4096 | 2112  | 7168  | 1009 TFLOPS | 503 GB/s  | 1.1x |
+| 4096 | 24576 | 1536  | 1125 TFLOPS | 893 GB/s  | 1.1x |
+| 4096 | 32768 | 512   | 751 TFLOPS  | 1569 GB/s | 1.1x |
+| 4096 | 7168  | 16384 | 1426 TFLOPS | 361 GB/s  | 1.3x |
+| 4096 | 4096  | 7168  | 1265 TFLOPS | 485 GB/s  | 1.2x |
+| 4096 | 7168  | 2048  | 1168 TFLOPS | 794 GB/s  | 1.2x |
 
 ### Grouped GEMMs for MoE models (contiguous layout)
 
 | #Groups | M per group | N    | K    | Computation | Memory bandwidth | Speedup |
 |:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:|
-| 4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x |
-| 4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x |
-| 8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x |
-| 8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x |
+| 4 | 8192 | 4096 | 7168 | 1346 TFLOPS | 434 GB/s | 1.3x |
+| 4 | 8192 | 7168 | 2048 | 1214 TFLOPS | 752 GB/s | 1.3x |
+| 8 | 4096 | 4096 | 7168 | 1346 TFLOPS | 516 GB/s | 1.3x |
+| 8 | 4096 | 7168 | 2048 | 1214 TFLOPS | 826 GB/s | 1.2x |
 
 ### Grouped GEMMs for MoE models (masked layout)
@@ -27,11 +27,22 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }
 
+template <int kNumFormerIters, int kGap, int kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+    if (num_former_iters == kNumFormerIters) {
+        inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
+        return;
+    }
+
+    if constexpr (kNumFormerIters + kGap <= kEnd)
+        outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
+}
+
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
 fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
@@ -43,7 +54,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
 
     // Types
     using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
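The relaxed assertion above admits BLOCK_N larger than BLOCK_K as long as the part of the tile that overhangs a 128-wide scaling column divides 128, so each output tile touches at most two B scales. A quick standalone check (not repository code) of which candidate tile widths the new condition accepts, with BLOCK_K fixed at 128 as in the kernel:

```python
# Standalone sketch: which BLOCK_N values pass the relaxed B-scale check?
import math

BLOCK_K = 128

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def b_scale_check(block_n: int) -> bool:
    # Mirrors: ceil_div(BLOCK_N, BLOCK_K) == 1 or gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K
    return ceil_div(block_n, BLOCK_K) == 1 or math.gcd(block_n, BLOCK_K) == block_n - BLOCK_K

for block_n in (112, 128, 144, 160, 176, 192):
    print(block_n, b_scale_check(block_n))
# 112/128: True (a single scale column); 144/160/192: True (overhang 16/32/64 divides 128); 176: False (overhang 48)
```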
@@ -129,15 +140,23 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func) {
-        if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-            for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                func(k_iter, DivisibleK{});
-        } else {
-            for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                func(k_iter, DivisibleK{});
-            func(kNumIterations - 1, NotDivisibleK{});
-        }
+    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+        constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
+        constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+
+        // NOTES: for too-many branches (> 5), we disable this optimization
+        // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
+        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
+            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
+                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, num_former_iters_type);
+            } else {
+                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+            }
+        }, func, kShouldOptimize ? num_former_iters : 0);
     };
 
     // Register reconfigurations
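The new `launch_k_iterations` forwards the runtime split point `num_former_iters` (roughly, how many of the 8-column accumulator groups in the current tile still take the first B scale) through `outer_launch_k_iterations`, which recursively tries the compile-time candidates `0, kGap, 2*kGap, ..., kEnd` and invokes the matching specialization, so the value is a constant inside the unrolled loops. A standalone Python analogue of the candidate-set arithmetic (not kernel code; `must_use_uniform_scale_b` stands in for `kMustUseUniformedScaleB`):

```python
# Sketch of the kShouldOptimize / kGap / kEnd arithmetic from the diff, in Python instead of C++ templates.
import math

def former_iter_candidates(block_n: int, block_k: int = 128, must_use_uniform_scale_b: bool = False):
    should_optimize = block_k // math.gcd(block_k, block_n) <= 4 and not must_use_uniform_scale_b
    if not should_optimize:
        return [0]                               # dynamic value is dropped, no specialization
    gap = math.gcd(block_k, block_n) // 8        # step between specialized values
    end = block_k // 8                           # largest specialized value
    return list(range(0, end + 1, gap))

print(former_iter_candidates(160))  # [0, 4, 8, 12, 16]: 5 specializations, within the "> 5 branches" cutoff
print(former_iter_candidates(144))  # [0]: 128 / gcd(128, 144) = 8 > 4, so the optimization is disabled
```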
@@ -146,7 +165,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
+    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
 
     if (threadIdx.x >= kNumMathThreads) {
         // TMA warp-group for loading data
@@ -156,11 +175,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         if (threadIdx.x == kNumMathThreads) {
             // Persistently schedule over blocks
             while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-                launch_k_iterations([&](int k_iter, auto type) {
+                launch_k_iterations([&](int k_iter, auto type, auto _) {
                     constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                     constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
+                    // Assign TMA multicast number into A and B
+                    constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
+                    constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+
                     // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
                     // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
                     #pragma unroll
@@ -168,18 +191,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         // Wait consumer release
                         empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
 
-                        // Issue TMA A with broadcasting
+                        // Issue TMA A
                         auto& full_barrier = *full_barriers[s];
                         int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                        tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                   smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                        tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                   smem_scales_a[s], m_block_idx * BLOCK_M,
-                                                   scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
+                        tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                                      smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+                        tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                                      smem_scales_a[s], m_block_idx * BLOCK_M,
+                                                      scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
 
-                        // Issue TMA B without broadcasting
-                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                        // Issue TMA B
+                        tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                                                      smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                         full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                     }
 
@@ -189,7 +212,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                         full_barriers[s]->arrive();
                     }
-                });
+                }, 0);
             }
 
             // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -242,7 +265,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         };
 
         // Launch MMAs
-        launch_k_iterations([&](int k_iter, auto type) {
+        launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
             constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
             constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
             DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@@ -288,8 +311,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     float scale_0_1, scale_1_1;
                     if constexpr (not kMustUseUniformedScaleB)
                         scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
 
                     #pragma unroll
                     for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                         bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                         final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                         final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
@@ -304,7 +329,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                 empty_barrier_arrive(s);
             }
-        });
+        }, num_former_iters);
 
         // Write back to shared memory using STSM
         DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
@@ -347,7 +372,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 class Gemm {
 private:
@@ -369,7 +394,7 @@ public:
         constexpr uint32_t kNumMathThreadsPerGroup = 128;
         auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                       kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
-                                      kNumTMAMulticast, kGemmType>;
+                                      kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
         DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
 
         // Cluster launch
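With the changes above, the configured multicast degree now applies to exactly one operand: A and its scales when `kIsTMAMulticastOnA` is set, otherwise B. A small sketch (hypothetical helper, not repository code) of that assignment and of the constraint it places on the scheduler:

```python
# Hypothetical helper: which operand the TMA multicast degree applies to.
def tma_multicast_degrees(num_tma_multicast: int, is_tma_multicast_on_a: bool):
    on_a = num_tma_multicast if is_tma_multicast_on_a else 1
    on_b = 1 if is_tma_multicast_on_a else num_tma_multicast
    # Multicasting A means every CTA in the cluster receives the same A tile, so cluster peers
    # must be working on the same m block (and different n blocks); multicasting B is the mirror case.
    return on_a, on_b

print(tma_multicast_degrees(2, True))   # (2, 1): A and scales-A are broadcast across the cluster
print(tma_multicast_degrees(2, False))  # (1, 2): B is broadcast across the cluster
```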
@@ -665,6 +665,135 @@ struct SM90_64x128x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
+struct SM90_64x144x32_F32E4M3E4M3_SS {
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
+                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
+                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
+                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
+                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
+                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
+                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
+                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
+                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
+                                 float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
+                                 bool scale_d) {
+        asm volatile("{\n"
+                     ".reg .pred p;\n"
+                     "setp.ne.b32 p, %74, 0;\n"
+                     "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3"
+                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
+                     " %8, %9, %10, %11, %12, %13, %14, %15, "
+                     " %16, %17, %18, %19, %20, %21, %22, %23, "
+                     " %24, %25, %26, %27, %28, %29, %30, %31, "
+                     " %32, %33, %34, %35, %36, %37, %38, %39, "
+                     " %40, %41, %42, %43, %44, %45, %46, %47, "
+                     " %48, %49, %50, %51, %52, %53, %54, %55, "
+                     " %56, %57, %58, %59, %60, %61, %62, %63, "
+                     " %64, %65, %66, %67, %68, %69, %70, %71}, "
+                     " %72,"
+                     " %73,"
+                     " p , 1, 1;\n"
+                     "}\n"
+                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+                       "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
+                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
+    }
+
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
+        wgmma(desc_a, desc_b,
+              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
+              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
+              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
+              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
+              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
+              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
+              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
+              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
+              d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
+              scale_d);
+    }
+
+    static constexpr int M = 64;
+    static constexpr int N = 144;
+    static constexpr int K = 32;
+    static constexpr int kNumAccum = M * N / 128;
+};
+
+struct SM90_64x160x32_F32E4M3E4M3_SS {
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
+                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
+                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
+                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
+                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
+                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
+                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
+                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
+                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
+                                 float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
+                                 float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
+                                 bool scale_d) {
+        asm volatile("{\n"
+                     ".reg .pred p;\n"
+                     "setp.ne.b32 p, %82, 0;\n"
+                     "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3"
+                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
+                     " %8, %9, %10, %11, %12, %13, %14, %15, "
+                     " %16, %17, %18, %19, %20, %21, %22, %23, "
+                     " %24, %25, %26, %27, %28, %29, %30, %31, "
+                     " %32, %33, %34, %35, %36, %37, %38, %39, "
+                     " %40, %41, %42, %43, %44, %45, %46, %47, "
+                     " %48, %49, %50, %51, %52, %53, %54, %55, "
+                     " %56, %57, %58, %59, %60, %61, %62, %63, "
+                     " %64, %65, %66, %67, %68, %69, %70, %71, "
+                     " %72, %73, %74, %75, %76, %77, %78, %79}, "
+                     " %80,"
+                     " %81,"
+                     " p , 1, 1;\n"
+                     "}\n"
+                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+                       "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+                       "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
+                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
+    }
+
+    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
+        wgmma(desc_a, desc_b,
+              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
+              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
+              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
+              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
+              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
+              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
+              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
+              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
+              d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
+              d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
+              scale_d);
+    }
+
+    static constexpr int M = 64;
+    static constexpr int N = 160;
+    static constexpr int K = 32;
+    static constexpr int kNumAccum = M * N / 128;
+};
+
 struct SM90_64x192x32_F32E4M3E4M3_SS {
     __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                  float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
@@ -876,6 +1005,8 @@ struct FP8MMASelector {
         if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
         if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
         if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
+        if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS();
+        if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS();
         if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
     }
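The selector now dispatches the two new tile shapes added above. Per-thread accumulator counts follow `kNumAccum = M * N / 128`, which is where the 72 and 80 output registers of the `m64n144k32` and `m64n160k32` wrappers come from; a quick standalone check:

```python
# Standalone check: accumulator registers per thread for the WGMMA tile shapes (M * N / 128).
for n in (128, 144, 160, 192):
    print(f"m64n{n}k32 -> kNumAccum = {64 * n // 128}")
# m64n128k32 -> 64, m64n144k32 -> 72, m64n160k32 -> 80, m64n192k32 -> 96
```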
@@ -12,9 +12,10 @@ enum class GemmType {
 #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
-          uint32_t kNumGroups, uint32_t kNumTMAMulticast,
+          uint32_t kNumGroups,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
-          uint32_t kNumNBlocksPerGroup = 16>
+          uint32_t kNum1DBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
     uint32_t num_aligned_m_blocks;
@@ -43,16 +44,27 @@ struct Scheduler {
     }
 
     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
-        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
+        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
         // Swizzle for better L2 usages
-        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
-        auto group_idx = block_idx / num_blocks_per_group;
-        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
-        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-        auto in_group_idx = block_idx % num_blocks_per_group;
-        m_block_idx = in_group_idx / num_n_blocks_in_group;
-        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        // TODO: unify these 2 branches
+        if constexpr (kIsTMAMulticastOnA) {
+            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = in_group_idx / num_n_blocks_in_group;
+            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        } else {
+            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
+            n_block_idx = in_group_idx / num_m_blocks_in_group;
+        }
     }
 
     template <bool kIgnoreGroupedForGroupedContiguous=true>
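The swizzle now has one branch per multicast side: when multicasting on A, consecutive scheduled blocks keep the same m block (so cluster peers can share the A tile) and walk along N inside a 16-wide group; when multicasting on B they keep the same n block and walk along M. A standalone Python transcription of the arithmetic above for the non-grouped case:

```python
# Standalone transcription of the swizzled block mapping (non-grouped case).
def swizzled_block_idx(block_idx: int, num_m_blocks: int, num_n_blocks: int,
                       is_tma_multicast_on_a: bool, group_size: int = 16):
    if is_tma_multicast_on_a:
        # Consecutive blocks share m_block_idx, so cluster peers can receive the same A tile.
        num_blocks_per_group = num_m_blocks * group_size
        group_idx = block_idx // num_blocks_per_group
        first_n = group_idx * group_size
        n_in_group = min(group_size, num_n_blocks - first_n)
        in_group = block_idx % num_blocks_per_group
        return in_group // n_in_group, first_n + in_group % n_in_group
    else:
        # Consecutive blocks share n_block_idx, so cluster peers can receive the same B tile.
        num_blocks_per_group = num_n_blocks * group_size
        group_idx = block_idx // num_blocks_per_group
        first_m = group_idx * group_size
        m_in_group = min(group_size, num_m_blocks - first_m)
        in_group = block_idx % num_blocks_per_group
        return first_m + in_group % m_in_group, in_group // m_in_group

print([swizzled_block_idx(i, 32, 45, True) for i in range(4)])   # [(0, 0), (0, 1), (0, 2), (0, 3)]
print([swizzled_block_idx(i, 32, 45, False) for i in range(4)])  # [(0, 0), (1, 0), (2, 0), (3, 0)]
```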
@@ -98,6 +110,7 @@ struct Scheduler {
         return true;
     }
 };
 
 #pragma clang diagnostic pop
 
 } // namespace deep_gemm
@@ -46,3 +46,8 @@ template <typename T>
 __device__ __host__ constexpr T ceil_div(T a, T b) {
     return (a + b - 1) / b;
 }
+
+template <typename T>
+__device__ __host__ constexpr T gcd(T a, T b) {
+    return b == 0 ? a : gcd(b, a % b);
+}
@@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
                   '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                   # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                   '--diag-suppress=177,174,940']
-    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
+    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
     flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
     include_dirs = [get_jit_include_dir()]
@@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str:
     # We don't use `str.format` because it's not safe for C++ {} braces
     new_template = copy.deepcopy(template)
     for key, value in keys.items():
-        new_template = new_template.replace(f'{{{key}}}', f'{value}')
+        value_str = str(value)
+        if isinstance(value, bool):
+            value_str = value_str.lower()
+        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
     return new_template
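The reason for lower-casing booleans: the new `IS_TMA_MULTICAST_ON_A` key is a Python `bool`, and `str(True)` would be substituted into the C++ template as `True`, which does not compile; C++ spells the literals `true`/`false`. A minimal standalone reproduction of the substitution (same idea as the diff, not the repository's full JIT path):

```python
import copy
from typing import Any, Dict

def cpp_format(template: str, keys: Dict[str, Any]) -> str:
    # Plain str.replace over {KEY} placeholders, with bools lowered to their C++ spelling.
    new_template = copy.deepcopy(template)
    for key, value in keys.items():
        value_str = str(value)
        if isinstance(value, bool):
            value_str = value_str.lower()
        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
    return new_template

print(cpp_format('constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};',
                 {'IS_TMA_MULTICAST_ON_A': True}))
# constexpr auto kIsTMAMulticastOnA = true;
```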
@@ -1,3 +1,4 @@
+import math
 import torch
 from typing import Tuple
 
@@ -15,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -31,10 +33,10 @@ GemmType::run(out, rhs_scales, nullptr,
 """
 
 
-def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
     if num_tma_multicast == 1:
         return True
-    return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
 
 
 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
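Renaming the parameters makes the helper usable for either dimension: multicast on A requires N to split evenly into `block_n * num_tma_multicast` chunks, multicast on B requires the same of M, and the SM count must be divisible so clusters are not split. A standalone copy with illustrative calls (the shapes are examples, not benchmark settings):

```python
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
    if num_tma_multicast == 1:
        return True
    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0

# Multicast on A pairs CTAs along N; multicast on B pairs them along M.
print(is_tma_multicast_legal(7168, 160, 2, 132))  # False: 7168 % 320 != 0, cannot multicast A with BLOCK_N = 160
print(is_tma_multicast_legal(4096, 128, 2, 132))  # True:  4096 % 256 == 0, multicasting on B is fine
```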
@@ -56,13 +58,13 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
 
 
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
-                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]:
+                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
         # TODO: for some cases, smaller M block is better, add them into tuning space
         block_ms = (64 if m <= 64 else 128, )
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(16, 129, 8))
+    block_ns = tuple(range(16, 129, 8)) + (144, 160, )
 
     fix_wave_saturate = lambda x: num_sms if x == 0 else x
     get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
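The 144 and 160 candidates are what the new `m64n144k32` / `m64n160k32` WGMMA wrappers enable; a wider BLOCK_N means fewer N blocks and usually fewer waves. A rough wave count for one illustrative dense shape (assuming 132 SMs, as on an H800 SXM; the real tuner also weighs last-wave utilization and shared-memory fit):

```python
# Rough wave-count comparison for an illustrative dense shape (m = 4096, n = 7168, BLOCK_M = 128).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, n, num_sms, block_m = 4096, 7168, 132, 128
for block_n in (128, 160):
    num_blocks = ceil_div(m, block_m) * ceil_div(n, block_n)
    print(block_n, num_blocks, ceil_div(num_blocks, num_sms))
# BLOCK_N = 128 -> 1792 blocks, 14 waves; BLOCK_N = 160 -> 1440 blocks, 11 waves
```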
@@ -89,26 +91,38 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
     best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
-    for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
+    stage_candidates = (8, 7, 6, 5, 4)
+    if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
+        # Unrolling both stages and `num_former_iters` will cause large code size
+        stage_candidates = (4, )
+    for num_stages in stage_candidates:
         best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
         if best_smem_size <= sm90_capacity:
             best_num_stages = num_stages
             break
     assert best_num_stages is not None
 
-    # Decide the number of TMA multicast
-    best_num_tma_multicast = 1
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
-        best_num_tma_multicast = 2
+    # Decide the number of TMA multicast and whether broadcast on A
+    best_tma_multicast_config = (1, True)
+
+    # Try to multicast on the larger block side first
+    is_multicast_legal = {
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
+    }
+    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
+        if m >= 512 and is_multicast_legal[i] and num_groups == 1:
+            best_tma_multicast_config = (2, i == 'A')
+            break
 
     # Recompute the minimal number of SMs required
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
-    assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
+    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
+    assert num_min_sms <= num_sms
 
-    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
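`get_best_configs` now returns the multicast decision as a `(num_tma_multicast, is_multicast_on_a)` pair, chosen by trying the operand with the larger tile first. A standalone sketch of just that selection step, with the legality check inlined (shapes are illustrative):

```python
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
    return num_tma_multicast == 1 or (shape_dim % (block_dim * num_tma_multicast) == 0
                                      and num_sms % num_tma_multicast == 0)

def select_tma_multicast(m: int, n: int, block_m: int, block_n: int, num_sms: int, num_groups: int = 1):
    # Mirrors the selection in the diff: prefer multicasting the operand with the larger block tile.
    config = (1, True)
    legal = {'A': is_tma_multicast_legal(n, block_n, 2, num_sms),
             'B': is_tma_multicast_legal(m, block_m, 2, num_sms)}
    for side in ('A', 'B') if block_m > block_n else ('B', 'A'):
        if m >= 512 and legal[side] and num_groups == 1:
            config = (2, side == 'A')
            break
    return config

print(select_tma_multicast(4096, 7168, 128, 160, 132))  # (2, False): BLOCK_N is larger, so B is tried first
```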
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@ -158,12 +172,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
|
@@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated grouped GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
-                                                                                           is_grouped_contiguous=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
             torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedContiguous'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
@@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedMasked'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),