Performance: Configuration algorithms tuned to minimize the impact of tail effects, now up to 1402 TFLOPS

sazc 2025-03-10 11:45:05 +08:00
parent 9d3222a93e
commit 50cf26cc7c
2 changed files with 33 additions and 30 deletions
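In short: `get_best_configs` no longer assumes the kernel grid always occupies every SM. It now also searches over a candidate SM count `n_block` (124 up to `num_sms`, in steps of 2) and scores each `(block_m, block_n, n_block)` triple by wave count and last-wave utilization, so the tail wave stays as full as possible; the chosen `n_block` is then passed to the kernel launch in place of `num_sms`. Below is a minimal standalone sketch of that idea; the `pick_config` name, the simplified scoring, and the 132-SM example GPU are illustrative assumptions, not the code in this commit.

# A simplified, self-contained sketch of wave-based config scoring with a variable SM count.
# Illustrative only: the real get_best_configs in the diff below uses extra tie-breaking terms.
from math import ceil

def pick_config(m, n, num_sms, block_ms=(64, 128), block_ns=tuple(range(16, 129, 8))):
    best_key, best_cfg = None, None
    for bm in block_ms:
        for bn in block_ns:
            for sms in range(124, num_sms + 1, 2):    # candidate SM counts ("n_block")
                blocks = ceil(m / bm) * ceil(n / bn)  # total output tiles to compute
                num_waves = ceil(blocks / sms)        # full passes over the chosen SMs
                last_wave = blocks % sms or sms       # tiles in the final (tail) wave
                key = (num_waves, -(last_wave / sms)) # fewer waves first, then a fuller tail
                if best_key is None or key < best_key:
                    best_key, best_cfg = key, (bm, bn, sms)
    return best_cfg

# Hypothetical 132-SM GPU, m=4096, n=7168: 128x128 tiles give 32 * 56 = 1792 blocks.
# Across all 132 SMs that is 14 waves with a 76-block (~58%) tail wave; capping the
# grid at 128 SMs keeps 14 waves and makes every wave full, so (128, 128, 128) wins.
print(pick_config(4096, 7168, 132))

The actual scoring in the diff also adds a `num_waves * n_block / num_sms` term and different tie-breaking on `block_n`, but the goal is the same: keep the last wave saturated instead of always launching the grid across every SM.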

View File

@@ -6,6 +6,7 @@ from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_
 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
+# includes = ('"deep_gemm/fp8_gemm_inter_wg.cuh"', )
 template = """
 using namespace deep_gemm;
@@ -27,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, nullptr,
               m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
@@ -63,28 +64,30 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
     block_ns = tuple(range(16, 129, 8))
+    n_blocks = tuple(range(124, num_sms+1, 2))
 
-    fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
+    fix_wave_saturate = lambda x, nb: nb if x == 0 else x
+    get_num_waves = lambda bm, bn, nb: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, nb) if bm else None)
+    get_last_wave_util = lambda bm, bn, nb: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % nb, nb)
 
     # Decide block sizes by waves
-    best_block_m, best_block_n = None, None
+    best_block_m, best_block_n, best_n_block = None, None, None
     for block_m in block_ms:
         for block_n in block_ns:
-            success = False
-            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
-            if best_block_m is None or best_block_n is None:
-                success = True
-            elif num_waves < best_num_waves:
-                success = True
-            elif num_waves == best_num_waves:
-                # Check last wave utilization
-                util = get_last_wave_util(block_m, block_n)
-                best_util = get_last_wave_util(best_block_m, best_block_n)
-                success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
-            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
-    assert best_block_m is not None and best_block_n is not None
+            for n_block in n_blocks:
+                success = False
+                num_waves, best_num_waves = get_num_waves(block_m, block_n, n_block), get_num_waves(best_block_m, best_block_n, best_n_block)
+                if best_block_m is None or best_block_n is None or best_n_block is None:
+                    success = True
+                elif num_waves < best_num_waves:
+                    success = True
+                elif num_waves == best_num_waves:
+                    # Check last wave utilization
+                    util = get_last_wave_util(block_m, block_n, n_block) + num_waves * n_block / num_sms
+                    best_util = get_last_wave_util(best_block_m, best_block_n, best_n_block) + best_num_waves * best_n_block / num_sms
+                    success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n > best_block_n) or (block_m == best_block_m and block_n == best_block_n and n_block < best_n_block)))
+                best_block_m, best_block_n, best_n_block = (block_m, block_n, n_block) if success else (best_block_m, best_block_n, best_n_block)
+    assert best_block_m is not None and best_block_n is not None and best_n_block is not None
 
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
@@ -98,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Decide the number of TMA multicast
     best_num_tma_multicast = 1
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
+    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, best_n_block) and num_groups == 1:
         best_num_tma_multicast = 2
 
-    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return best_block_m, best_block_n, best_n_block, best_num_stages, best_num_tma_multicast, best_smem_size
 
 
 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -151,8 +154,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
-    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -162,7 +165,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )

View File

@@ -28,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, grouped_layout,
               m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
@@ -84,11 +84,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
                                                                                   is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -99,7 +99,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16),
                   ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )
@@ -158,7 +158,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -177,7 +177,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16),
                   ('grouped_layout', torch.int32), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )