Performance: Configuration algorithms tuned to minimize the impact of tail effects, now up to 1402 TFLOPS

This commit is contained in:
sazc
2025-03-10 11:45:05 +08:00
parent 9d3222a93e
commit 50cf26cc7c
2 changed files with 33 additions and 30 deletions

View File

@@ -28,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
stream, n_block, smem_size);
"""
@@ -84,11 +84,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_size)
torch.cuda.current_stream(), n_block, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -99,7 +99,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
template=template,
args=args
)
@@ -158,7 +158,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
@@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_size)
torch.cuda.current_stream(), n_block, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -177,7 +177,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
template=template,
args=args
)