From 50cf26cc7cfbf2d6b243bfce995cef976969a880 Mon Sep 17 00:00:00 2001
From: sazc
Date: Mon, 10 Mar 2025 11:45:05 +0800
Subject: [PATCH] Performance: Configuration algorithms tuned to minimize the impact of tail effects, now up to 1402 TFLOPS

---
 deep_gemm/jit_kernels/gemm.py           | 49 +++++++++++++------------
 deep_gemm/jit_kernels/m_grouped_gemm.py | 14 +++----
 2 files changed, 33 insertions(+), 30 deletions(-)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 1ba413a..302db88 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -6,6 +6,7 @@ from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_
 
 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
+# includes = ('"deep_gemm/fp8_gemm_inter_wg.cuh"', )
 template = """
 using namespace deep_gemm;
 
@@ -27,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, nullptr, m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
 
 
@@ -63,28 +64,30 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
     block_ns = tuple(range(16, 129, 8))
+    n_blocks = tuple(range(124, num_sms+1, 2))
 
-    fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
+    fix_wave_saturate = lambda x, nb: nb if x == 0 else x
+    get_num_waves = lambda bm, bn, nb: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, nb) if bm else None)
+    get_last_wave_util = lambda bm, bn, nb: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % nb, nb)
 
     # Decide block sizes by waves
-    best_block_m, best_block_n = None, None
+    best_block_m, best_block_n, best_n_block = None, None, None
     for block_m in block_ms:
         for block_n in block_ns:
-            success = False
-            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
-            if best_block_m is None or best_block_n is None:
-                success = True
-            elif num_waves < best_num_waves:
-                success = True
-            elif num_waves == best_num_waves:
-                # Check last wave utilization
-                util = get_last_wave_util(block_m, block_n)
-                best_util = get_last_wave_util(best_block_m, best_block_n)
-                success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
-            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
-    assert best_block_m is not None and best_block_n is not None
+            for n_block in n_blocks:
+                success = False
+                num_waves, best_num_waves = get_num_waves(block_m, block_n, n_block), get_num_waves(best_block_m, best_block_n, best_n_block)
+                if best_block_m is None or best_block_n is None or best_n_block is None:
+                    success = True
+                elif num_waves < best_num_waves:
+                    success = True
+                elif num_waves == best_num_waves:
+                    # Check last wave utilization
+                    util = get_last_wave_util(block_m, block_n, n_block) + num_waves * n_block / num_sms
+                    best_util = get_last_wave_util(best_block_m, best_block_n, best_n_block) + best_num_waves * best_n_block / num_sms
+                    success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n > best_block_n) or (block_m == best_block_m and block_n == best_block_n and n_block < best_n_block)))
+                best_block_m, best_block_n, best_n_block = (block_m, block_n, n_block) if success else (best_block_m, best_block_n, best_n_block)
+    assert best_block_m is not None and best_block_n is not None and best_n_block is not None
 
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
@@ -98,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
 
     # Decide the number of TMA multicast
     best_num_tma_multicast = 1
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
+    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, best_n_block) and num_groups == 1:
         best_num_tma_multicast = 2
 
-    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return best_block_m, best_block_n, best_n_block, best_num_stages, best_num_tma_multicast, best_smem_size
 
 
 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -151,8 +154,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
-    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -162,7 +165,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 97fb636..b45f50a 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -28,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, grouped_layout, m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
 
 
@@ -84,11 +84,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
                                                                                   is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -99,7 +99,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16), ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )
@@ -158,7 +158,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
 
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -177,7 +177,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16), ('grouped_layout', torch.int32), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
    )
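
Reviewer note, not part of the diff: a minimal standalone sketch of the tail-effect arithmetic that the new n_blocks sweep in get_best_configs() is based on. The helpers ceil_div/tail_report below and the shape and SM count in the example are illustrative assumptions, not code from this patch.

# Illustrative sketch only: a GEMM is decomposed into
# ceil(m / block_m) * ceil(n / block_n) * num_groups blocks that run in waves
# of at most n_block SMs; a nearly empty last wave still costs a full wave, so
# dedicating slightly fewer SMs can fill the last wave better without
# increasing the wave count.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def tail_report(m: int, n: int, block_m: int, block_n: int,
                num_groups: int, num_sms: int) -> None:
    num_blocks = ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups
    for n_block in range(124, num_sms + 1, 2):    # same candidate range as the patch
        num_waves = ceil_div(num_blocks, n_block)
        last_wave = num_blocks % n_block or n_block   # blocks in the last wave
        # Mirrors the patched scoring: last-wave occupancy plus
        # num_waves * n_block / num_sms; the patch prefers the larger value
        # when two candidates need the same number of waves.
        util = last_wave + num_waves * n_block / num_sms
        print(f'n_block={n_block:3d}  waves={num_waves:2d}  '
              f'last_wave={last_wave:3d}/{n_block}  util={util:.1f}')


if __name__ == '__main__':
    # Hypothetical shape on a 132-SM GPU: m=4096, n=7168 with 128x128 blocks.
    # With all 132 SMs the last of 14 waves runs only 76 of 132 blocks; with
    # 128 SMs the same 14 waves finish with every wave completely full.
    tail_report(4096, 7168, 128, 128, 1, 132)

The chosen SM count is then returned alongside the block sizes, and the call sites pass n_block to the kernel launch in place of num_sms, so the persistent kernel presumably occupies only the SMs the scoring decided to use.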