Support block size 160

This commit is contained in:
Chenggang Zhao
2025-03-25 13:37:59 +08:00
parent 46eb0d08fb
commit b922e64cb2
4 changed files with 188 additions and 879 deletions

View File

@@ -10,14 +10,14 @@ template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N}, K = {K};
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated GEMM
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -62,14 +62,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (64 if m <= 64 else 128, )
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
# Current optimizations target large-scale Normal GEMMs for dense models and
# Grouped GEMMs for MoE models (contiguous memory layout), with a potential
# block_n tile size of 160 to enhance data reuse in block tiling.
if m >= 4096 and num_groups == 1:
block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160))
else:
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8)) + (160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@@ -84,16 +77,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
# The value 0.02 is currently an empirically estimated threshold to
# filter out cases unsuitable for large block tile configurations,
# with optimizations planned for later stages to address excluded scenarios.
if block_n == 160 and \
(num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02:
success = True
elif block_n == 160:
success = False
else:
success = True
success = True
elif num_waves == best_num_waves:
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
@@ -105,44 +89,25 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
if best_block_n != 160:
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
else:
# NOTES: This is done to reduce the code footprint after unrolling.
# Additionally, if k does not meet the following conditions, a slight performance penalty will occur.
num_stages = 4
assert k / 128 % num_stages == 0
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
assert best_smem_size <= sm90_capacity
best_num_stages = num_stages
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
assert best_num_stages is not None
# Decide the number of TMA multicast
best_num_tma_multicast = 1
# When using large block tiling, broadcasting B is required to achieve maximum performance gains.
if best_block_n != 160:
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
else:
if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
if best_block_n != 160:
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
else:
assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms)
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
@@ -198,7 +163,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
space=(),
includes=includes,
@@ -212,5 +177,3 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size

View File

@@ -11,14 +11,14 @@ template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N}, K = {K};
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated grouped GEMM
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
@@ -106,8 +106,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -171,7 +169,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
@@ -186,5 +184,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size