mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Support block size 160
This commit is contained in:
@@ -10,14 +10,14 @@ template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto M = {M}, N = {N}, K = {K};
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
@@ -62,14 +62,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
# Current optimizations target large-scale Normal GEMMs for dense models and
|
||||
# Grouped GEMMs for MoE models (contiguous memory layout), with a potential
|
||||
# block_n tile size of 160 to enhance data reuse in block tiling.
|
||||
if m >= 4096 and num_groups == 1:
|
||||
block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160))
|
||||
else:
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
block_ns = tuple(range(16, 129, 8)) + (160, )
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
@@ -84,16 +77,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
# The value 0.02 is currently an empirically estimated threshold to
|
||||
# filter out cases unsuitable for large block tile configurations,
|
||||
# with optimizations planned for later stages to address excluded scenarios.
|
||||
if block_n == 160 and \
|
||||
(num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02:
|
||||
success = True
|
||||
elif block_n == 160:
|
||||
success = False
|
||||
else:
|
||||
success = True
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
@@ -105,44 +89,25 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Always pick the longest one
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||
if best_block_n != 160:
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
else:
|
||||
# NOTES: This is done to reduce the code footprint after unrolling.
|
||||
# Additionally, if k does not meet the following conditions, a slight performance penalty will occur.
|
||||
num_stages = 4
|
||||
assert k / 128 % num_stages == 0
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
assert best_smem_size <= sm90_capacity
|
||||
best_num_stages = num_stages
|
||||
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast
|
||||
best_num_tma_multicast = 1
|
||||
# When using large block tiling, broadcasting B is required to achieve maximum performance gains.
|
||||
if best_block_n != 160:
|
||||
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
else:
|
||||
if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
# Recompute the minimal number of SMs required
|
||||
# NOTES: less L2 cache usage and less GPU frequency drop
|
||||
num_waves = get_num_waves(best_block_m, best_block_n)
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
|
||||
if best_block_n != 160:
|
||||
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
|
||||
else:
|
||||
assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms)
|
||||
|
||||
assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
|
||||
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
|
||||
|
||||
@@ -198,7 +163,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
||||
space=(),
|
||||
includes=includes,
|
||||
@@ -212,5 +177,3 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
# For debug
|
||||
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
|
||||
|
||||
@@ -11,14 +11,14 @@ template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto M = {M}, N = {N}, K = {K};
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
@@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
@@ -106,8 +106,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
# For debug
|
||||
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -171,7 +169,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
@@ -186,5 +184,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
# For debug
|
||||
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
|
||||
|
||||
Reference in New Issue
Block a user