Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-10 18:55:28 +00:00
Performance: Configuration algorithms tuned to minimize the impact of tail effects, now up to 1402 TFLOPS
parent 9d3222a93e · commit 50cf26cc7c
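The tail effect this commit targets is the partially filled last wave of a kernel launch: the grid executes in waves of roughly one thread block per SM, and when the total block count is not a multiple of the SMs in use, the final wave leaves some SMs idle. Below is a minimal standalone sketch of the wave arithmetic involved, using the same ceil-div and last-wave formulas that appear in the diff; the helper names are illustrative, not the library's API.

def ceil_div(a: int, b: int) -> int:
    # Smallest integer >= a / b.
    return (a + b - 1) // b

def num_waves(num_blocks: int, sms_used: int) -> int:
    # Full passes over the SMs needed to schedule every thread block.
    return ceil_div(num_blocks, sms_used)

def last_wave_util(num_blocks: int, sms_used: int) -> int:
    # SMs kept busy by the final wave; a remainder of 0 means the wave is full.
    rem = num_blocks % sms_used
    return sms_used if rem == 0 else rem

Previously the tuner searched only over block_m and block_n with all num_sms SMs in play; the change below also searches over n_block, which (judging from the diff) is the number of SMs the kernel is actually launched on, so a schedule whose last wave is fuller can win even if it uses slightly fewer SMs.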
@@ -6,6 +6,7 @@ from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_
 
 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
+# includes = ('"deep_gemm/fp8_gemm_inter_wg.cuh"', )
 template = """
 using namespace deep_gemm;
 
@@ -27,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, nullptr,
               m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
 
 
@@ -63,28 +64,30 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
     block_ns = tuple(range(16, 129, 8))
+    n_blocks = tuple(range(124, num_sms+1, 2))
 
-    fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
+    fix_wave_saturate = lambda x, nb: nb if x == 0 else x
+    get_num_waves = lambda bm, bn, nb: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, nb) if bm else None)
+    get_last_wave_util = lambda bm, bn, nb: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % nb, nb)
 
     # Decide block sizes by waves
-    best_block_m, best_block_n = None, None
+    best_block_m, best_block_n, best_n_block = None, None, None
     for block_m in block_ms:
         for block_n in block_ns:
+            for n_block in n_blocks:
                 success = False
-                num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
-                if best_block_m is None or best_block_n is None:
+                num_waves, best_num_waves = get_num_waves(block_m, block_n, n_block), get_num_waves(best_block_m, best_block_n, best_n_block)
+                if best_block_m is None or best_block_n is None or best_n_block is None:
                     success = True
                 elif num_waves < best_num_waves:
                     success = True
                 elif num_waves == best_num_waves:
                     # Check last wave utilization
-                    util = get_last_wave_util(block_m, block_n)
-                    best_util = get_last_wave_util(best_block_m, best_block_n)
-                    success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
-                best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
-    assert best_block_m is not None and best_block_n is not None
+                    util = get_last_wave_util(block_m, block_n, n_block) + num_waves * n_block / num_sms
+                    best_util = get_last_wave_util(best_block_m, best_block_n, best_n_block) + best_num_waves * best_n_block / num_sms
+                    success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n > best_block_n) or (block_m == best_block_m and block_n == best_block_n and n_block < best_n_block)))
+                best_block_m, best_block_n, best_n_block = (block_m, block_n, n_block) if success else (best_block_m, best_block_n, best_n_block)
+    assert best_block_m is not None and best_block_n is not None and best_n_block is not None
 
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
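To see why searching over n_block can pay off, here is a small worked example; the GEMM shape, the 128x128 tile, and the 132-SM count are illustrative assumptions rather than values taken from this commit.

# Hypothetical shape: m = 4096, n = 7168, block_m = block_n = 128, one group.
num_blocks = ((4096 + 127) // 128) * ((7168 + 127) // 128)   # 32 * 56 = 1792 blocks

for sms_used in (132, 128):                                   # full GPU vs. a candidate n_block
    waves = (num_blocks + sms_used - 1) // sms_used
    tail = num_blocks % sms_used or sms_used                  # SMs busy in the last wave
    print(f"{sms_used} SMs -> {waves} waves, last wave uses {tail} SMs")

# 132 SMs -> 14 waves, last wave uses  76 SMs
# 128 SMs -> 14 waves, last wave uses 128 SMs
# Same wave count, but with 128 SMs the tail is completely full, which is the
# kind of schedule the updated utilization criterion rewards.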
@@ -98,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
 
     # Decide the number of TMA multicast
     best_num_tma_multicast = 1
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
+    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, best_n_block) and num_groups == 1:
         best_num_tma_multicast = 2
 
-    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return best_block_m, best_block_n, best_n_block, best_num_stages, best_num_tma_multicast, best_smem_size
 
 
 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -151,8 +154,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
-    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -162,7 +165,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )
@@ -28,7 +28,7 @@ auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
 GemmType::run(out, rhs_scales, grouped_layout,
               m,
               tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+              stream, n_block, smem_size);
 """
 
 
@@ -84,11 +84,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
                                                                                   is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -99,7 +99,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16),
                   ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )
@@ -158,7 +158,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    block_m, block_n, n_block, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
 
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), n_block, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
@@ -177,7 +177,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                   ('out', torch.bfloat16),
                   ('grouped_layout', torch.int32), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+                  ('stream', torch.cuda.Stream), ('n_block', int), ('smem_size', int)),
         template=template,
         args=args
     )