Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-05 19:14:21 +00:00)

Merge pull request #81 from deepseek-ai/blocktile-256x128

Performance: BlockTile 256x128 optimizations enable 1500+ TF FP8

This commit is contained in:
commit fed3e4d701

README.md (16 changed lines)
```diff
@@ -6,6 +6,10 @@ Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address
 
 Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
 
+## News
+
+- 2025.04.09: DeepGEMM now achieves up to **1520 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), and [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81) for details.
+
 ## Performance
 
 We test all shapes potentially used in DeepSeek-V3/R1 inference (including both prefilling and decoding, but without tensor parallelism) on H800 SXM5 with NVCC 12.8. All speedup metrics are calculated in comparison to our internally and carefully optimized implementation based on CUTLASS 3.6.
```
```diff
@@ -28,11 +32,11 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed
 | 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x |
 | 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x |
 | 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x |
-| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x |
-| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x |
-| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x |
-| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x |
-| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x |
+| 4096 | 2112 | 7168 | 1127 TFLOPS | 562 GB/s | 1.2x |
+| 4096 | 24576 | 1536 | 1212 TFLOPS | 962 GB/s | 1.2x |
+| 4096 | 32768 | 512 | 775 TFLOPS | 1620 GB/s | 1.2x |
+| 4096 | 7168 | 16384 | 1520 TFLOPS | 384 GB/s | 1.4x |
+| 4096 | 4096 | 7168 | 1410 TFLOPS | 541 GB/s | 1.3x |
 | 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x |
 
 ### Grouped GEMMs for MoE models (contiguous layout)
```
```diff
@@ -160,6 +164,8 @@ The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide
 
 - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction
 - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups
+- Larger block sizes
+- Less bank conflicts via 3D TMA 🐳
 - Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳
 
 #### A unified and optimized block scheduler
```
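Why the larger tile helps: with a 256x128 block, each loaded B fragment is reused across twice as many output rows, so the kernel does more math per byte of shared-memory traffic. The sketch below is a back-of-the-envelope estimate, not code from the repository; it assumes 1-byte FP8 operands, a 128-deep K slice per stage, and ignores scaling-factor traffic.

```cpp
#include <cstdio>

// Rough FLOPs-per-byte estimate for one FP8 GEMM block tile.
// Assumptions (not from the repo): 1-byte FP8 operands, BLOCK_K = 128,
// scaling-factor loads ignored.
constexpr double tile_arithmetic_intensity(int block_m, int block_n, int block_k = 128) {
    const double flops = 2.0 * block_m * block_n * block_k;    // one MMA pass over the K slice
    const double bytes = 1.0 * (block_m + block_n) * block_k;  // A-tile + B-tile loads (1 byte each)
    return flops / bytes;
}

static_assert(tile_arithmetic_intensity(256, 128) > tile_arithmetic_intensity(128, 128),
              "a taller M tile reuses each loaded B element more often");

int main() {
    std::printf("128x128 tile: %.1f FLOP/byte\n", tile_arithmetic_intensity(128, 128));  // 128.0
    std::printf("256x128 tile: %.1f FLOP/byte\n", tile_arithmetic_intensity(256, 128));  // ~170.7
}
```

Under this crude model a 128x128 tile sustains about 128 FLOP per loaded byte while a 256x128 tile sustains about 171, which is directionally consistent with the large-M speedups reported in the table above.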
```diff
@@ -21,10 +21,14 @@ enum class Layout {
     ColMajor
 };
 
+__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
+    return block_m == 64 ? 1 : 2;
+}
+
 template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
 __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
-    return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
+    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }
 
 template <int kNumFormerIters, int kGap, int kEnd>
```
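For reference, here is what the refactored helpers evaluate to for each tile height. This is a standalone sketch; the kNumTMAThreads value of 128 is an assumption for illustration and is not taken from this diff.

```cpp
// Standalone mirror of the two helpers above, to check the numbers.
// Assumption (not in this diff): kNumTMAThreads == 128, i.e. one TMA warpgroup.
constexpr int get_num_math_warpgroups(int block_m) {
    return block_m == 64 ? 1 : 2;
}

template <unsigned kNumTMAThreads, unsigned kNumMathThreadsPerGroup>
constexpr int get_num_threads_per_sm(int block_m) {
    static_assert(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
}

// BLOCK_M = 64  -> 1 math warpgroup  -> 128 + 128 = 256 threads per SM
// BLOCK_M = 128 -> 2 math warpgroups -> 256 + 128 = 384 threads per SM
// BLOCK_M = 256 -> 2 math warpgroups -> 256 + 128 = 384 threads per SM
static_assert(get_num_threads_per_sm<128, 128>(64) == 256, "");
static_assert(get_num_threads_per_sm<128, 128>(128) == 384, "");
static_assert(get_num_threads_per_sm<128, 128>(256) == 384, "");
```

Under that assumption, moving to BLOCK_M = 256 keeps the same 384-thread launch shape per SM; the extra rows are covered by iterating each math warpgroup over two waves, which is what the kernel changes below implement.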
```diff
@@ -257,7 +261,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
         // Accumulation for WGMMA or CUDA promotion
-        float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
+        constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+        DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
+        float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
 
         // Empty barrier arrival
         auto empty_barrier_arrive = [&](int s) {
```
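The accumulator change is easiest to see with concrete numbers. The sketch below assumes WGMMA::M = 64 (the Hopper FP8 wgmma tile height), which is not shown in this hunk.

```cpp
// Wave geometry as a function of BLOCK_M.
// Assumption (not in this hunk): WGMMA::M == 64.
constexpr int kWgmmaM = 64;

constexpr int num_math_warpgroups(int block_m) { return block_m == 64 ? 1 : 2; }
constexpr int wave_block_m(int block_m)        { return kWgmmaM * num_math_warpgroups(block_m); }
constexpr int num_waves(int block_m)           { return block_m / wave_block_m(block_m); }

// BLOCK_M = 64  -> WAVE_BLOCK_M = 64  -> 1 wave  -> final_accum holds 1 * kNumAccum floats
// BLOCK_M = 128 -> WAVE_BLOCK_M = 128 -> 1 wave  -> final_accum holds 1 * kNumAccum floats
// BLOCK_M = 256 -> WAVE_BLOCK_M = 128 -> 2 waves -> final_accum holds 2 * kNumAccum floats
static_assert(num_waves(64) == 1 && num_waves(128) == 1 && num_waves(256) == 2, "");
```

The live WGMMA accumulator `accum` stays `kNumAccum` wide regardless of BLOCK_M; only the promoted FP32 results in `final_accum` grow, one segment per wave.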
```diff
@@ -285,9 +291,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 // Wait TMA arrivals
                 full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
 
+                // TODO: remove some useless computation for unaligned Ms
+                #pragma unroll
+                for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
+                    auto m_offset = local_idx * WAVE_BLOCK_M;
+
                 // Read A scales
                 // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
-                auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
+                auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
+                auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
 
                 // Commit WGMMA instructions
                 #pragma unroll
```
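Concretely, for BLOCK_M = 256 the new loop runs twice per pipeline stage, with m_offset = 0 and 128, and every per-row shared-memory index picks up that offset. The snippet below only illustrates the indexing; `r_0` and `r_1` are assumed to be the thread's two A-scale row indices within a single wave, set up elsewhere in the kernel and not shown in this diff.

```cpp
// Per-wave A-scale reads for BLOCK_M = 256 (illustrative indexing only).
// Assumption: r_0 / r_1 are this thread's two scale rows inside one wave.
constexpr unsigned kBlockM = 256, kWaveBlockM = 128;

void read_a_scales(const float* smem_scales_a, int r_0, int r_1, float out[][2]) {
    for (unsigned local_idx = 0; local_idx < kBlockM / kWaveBlockM; ++local_idx) {
        const int m_offset = local_idx * kWaveBlockM;       // 0 for wave 0, 128 for wave 1
        out[local_idx][0] = smem_scales_a[r_0 + m_offset];  // scale_a_0 for this wave
        out[local_idx][1] = smem_scales_a[r_1 + m_offset];  // scale_a_1 for this wave
    }
}
```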
```diff
@@ -296,7 +308,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     warpgroup_arrive();
                     #pragma unroll
                     for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
-                        auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
+                        auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                         auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                         WGMMA::wgmma(desc_a, desc_b, accum, k);
                     }
```
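The only change inside the WGMMA issue loop is the A-descriptor base row, which now adds the wave offset. With assumed constants WGMMA::M = 64, BLOCK_K = 128, and WGMMA::K = 32 (none of which is visible in this hunk), the element offsets into `smem_a` land as follows:

```cpp
// Element offset of the A sub-tile fed to one wgmma, by warpgroup and wave.
// Assumptions (not in this hunk): WGMMA::M == 64, BLOCK_K == 128, WGMMA::K == 32, FP8 = 1 byte/element.
constexpr int kWgmmaM = 64, kBlockK = 128, kWgmmaK = 32;

constexpr int smem_a_offset(int math_wg_idx, int m_offset, int k) {
    return (math_wg_idx * kWgmmaM + m_offset) * kBlockK + k * kWgmmaK;
}

// Wave 0 (m_offset = 0): warpgroup 0 starts at row 0, warpgroup 1 at row 64.
static_assert(smem_a_offset(0, 0, 0) == 0, "");
static_assert(smem_a_offset(1, 0, 0) == 64 * 128, "");
// Wave 1 (m_offset = 128, only for BLOCK_M = 256): both warpgroups shift down 128 rows.
static_assert(smem_a_offset(0, 128, 0) == 128 * 128, "");
static_assert(smem_a_offset(1, 128, 0) == 192 * 128, "");
```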
```diff
@@ -306,7 +318,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_wait<0>();
 
-                // Notify barrier arrival
-                empty_barrier_arrive(s);
+                // Notify barrier arrival at the last warpgroup wave
+                if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
+                    empty_barrier_arrive(s);
 
                 // Promote with scales
```
```diff
@@ -316,14 +329,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 if constexpr (not kMustUseUniformedScaleB)
                     scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
 
+                auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                 #pragma unroll
                 for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                     // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                     bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
-                    final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
-                    final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
-                    final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
-                    final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
+                    shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
+                    shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
+                    shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
+                    shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                 }
+                }
             }
 
```
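For clarity, here is the promotion step for one wave written as a plain host-side function. It is an illustrative mirror of the loop above, not repository code; the only structural change versus the old kernel is that results go into the wave's own slice of `final_accum` via `shifted_accum`.

```cpp
#include <cstddef>

// Host-side mirror of the per-wave promotion loop above (illustration only).
// `accum` holds this wave's raw WGMMA FP32 outputs; `final_accum` holds
// num_waves * num_accum promoted values, and each wave writes only its own slice.
void promote_wave(float* final_accum, const float* accum, int local_idx, int num_accum,
                  float scale_0_0, float scale_0_1, float scale_1_0, float scale_1_1,
                  bool uniformed_scale_b, int num_former_iters) {
    float* shifted_accum = final_accum + static_cast<std::size_t>(num_accum) * local_idx;
    for (int i = 0; i < num_accum / 4; ++i) {
        // Elements 0/1 of each 4-wide fragment belong to the scale_a_0 row, elements 2/3
        // to the scale_a_1 row; the predicate selects which B-scale column applies.
        const bool predicate = uniformed_scale_b || i < num_former_iters;
        shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
        shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
        shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
        shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
    }
}
```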
```diff
@@ -338,22 +353,27 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         // Write back to shared memory using STSM
         DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
         #pragma unroll
+        for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
+            auto m_offset = local_idx * WAVE_BLOCK_M;
+            auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
+            #pragma unroll
         for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
             SM90_U32x4_STSM_N<nv_bfloat162>::copy(
-                __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
-                __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
-                __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
-                __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
+                __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
+                __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
+                __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
+                __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
+                smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
             );
         }
         if constexpr (WGMMA::kNumAccum % 8 != 0) {
             SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
+                __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
+                __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
+                smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
             );
         }
+        }
         cute::tma_store_fence();
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
```
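The STSM write-back follows the same per-wave pattern: each wave lands its 16-row warp slabs `m_offset` rows further down the shared-memory output tile. Below is a small sketch of the destination row computed by the changed `smem_d` index; it assumes two math warpgroups (warps 0-7) with 32 lanes each, and omits the unchanged column term involving BLOCK_N.

```cpp
// Destination row inside the smem_d staging tile for one STSM store (sketch only).
// Assumptions: warp_idx in [0, 8) spans the two math warpgroups, lane_idx in [0, 32).
constexpr int stsm_dst_row(int m_offset, int warp_idx, int lane_idx) {
    return m_offset + warp_idx * 16 + lane_idx % 16;
}

// Wave 0 (m_offset = 0): warps 0..7 cover rows 0..127 in 16-row slabs.
static_assert(stsm_dst_row(0, 0, 0) == 0 && stsm_dst_row(0, 7, 15) == 127, "");
// Wave 1 (m_offset = 128, BLOCK_M = 256): the same warps cover rows 128..255.
static_assert(stsm_dst_row(128, 0, 0) == 128 && stsm_dst_row(128, 7, 15) == 255, "");
```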
```diff
@@ -73,8 +73,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                      is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
         Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
-        # TODO: for some cases, smaller M block is better, add them into tuning space
-        block_ms = (64 if m <= 64 else 128, )
+        block_ms = (64, 128, 256)
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
     block_ns = tuple(range(16, 129, 8)) + (144, 160, )
```
```diff
@@ -86,7 +85,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Decide block sizes by waves
     best_block_m, best_block_n = None, None
     for block_m in block_ms:
-        for block_n in block_ns:
+        # NOTES: the block sizes can not be too large, so at least one dim less than 128
+        for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
             success = False
             num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
             if best_block_m is None or best_block_n is None:
```
```diff
@@ -97,7 +97,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
             # Check last wave utilization
             util = get_last_wave_util(block_m, block_n)
             best_util = get_last_wave_util(best_block_m, best_block_n)
-            success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
+            success = util > best_util
+            if util == best_util:
+                # Case 1: same `block_m`, smaller `block_n` (wasted)
+                success |= block_m == best_block_m and block_n < best_block_n
+                # Case 2: same `block_n`, smaller `block_m` (wasted)
+                success |= block_n == best_block_n and block_m < best_block_m
+                # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
+                success |= block_m != best_block_m and block_n > best_block_n
             best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
     assert best_block_m is not None and best_block_n is not None
 
```
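Taken together, the tuner now enumerates BLOCK_M in {64, 128, 256}, drops candidates where both dimensions exceed 128, and breaks last-wave-utilization ties with three explicit cases instead of simply preferring the larger `block_m`. Below is a compact C++ mirror of that selection logic for illustration only; `util` values stand in for the repository's `get_last_wave_util` results and the exact types are assumptions.

```cpp
// Illustrative C++ mirror of the tuner's candidate filter and tie-break (not repository code).

// At least one dimension must stay <= 128, matching the new filter() in the Python code.
bool admissible(int block_m, int block_n) {
    return block_m <= 128 || block_n <= 128;
}

// Returns true if (block_m, block_n) should replace the current best tile.
// `util` / `best_util` stand in for get_last_wave_util() results.
bool better(int block_m, int block_n, int util,
            int best_block_m, int best_block_n, int best_util) {
    bool success = util > best_util;
    if (util == best_util) {
        // Case 1: same block_m, smaller block_n (less wasted N)
        success |= block_m == best_block_m && block_n < best_block_n;
        // Case 2: same block_n, smaller block_m (less wasted M)
        success |= block_n == best_block_n && block_m < best_block_m;
        // Case 3: both dimensions differ, larger block_n wins
        success |= block_m != best_block_m && block_n > best_block_n;
    }
    return success;
}
```

Note that the tie-break cases only apply when last-wave utilization is exactly equal; otherwise utilization still decides, as in the old heuristic.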