Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-03 12:40:55 +00:00)

Fix grouped GEMM cases

commit 046fab64b7
parent 7768319ffe

@@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated grouped GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
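
The template above is C++ source with {NAME} placeholders that the JIT layer fills in from the tuning keys, now including IS_TMA_MULTICAST_ON_A. A minimal sketch of that substitution, assuming a plain str.format-style rendering; the render helper and the way booleans are lowered to C++ literals are illustrative, not DeepGEMM's actual JIT code:

    # Illustrative sketch only: substitute tuned keys into the CUDA template
    # shown above. Boolean keys are lowered to C++ 'true'/'false' literals.
    CUDA_TEMPLATE = """
    constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
    constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
    """

    def render(template: str, keys: dict) -> str:
        lowered = {k: ('true' if v is True else 'false' if v is False else v)
                   for k, v in keys.items()}
        return template.format(**lowered)

    print(render(CUDA_TEMPLATE, {'NUM_TMA_MULTICAST': 2, 'IS_TMA_MULTICAST_ON_A': True}))
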
@@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
-                                                                                           is_grouped_contiguous=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
             torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedContiguous'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
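
The [0]/[1] indexing in the new keys suggests get_best_configs now returns both multicast settings as a single tuple instead of a bare count. A hedged sketch of that packing; the layout (multicast width first, A-side flag second) is inferred from this diff, and the helper name is an assumption:

    from typing import Tuple

    # Assumed layout, inferred from tma_multicast_config[0]/[1] above:
    #   [0] -> number of CTAs sharing one TMA load (the multicast width)
    #   [1] -> True if the multicast is applied to the A (LHS) operand
    def pack_tma_multicast_config(num_tma_multicast: int, is_on_a: bool) -> Tuple[int, bool]:
        return num_tma_multicast, is_on_a

    tma_multicast_config = pack_tma_multicast_config(2, True)
    keys = {'NUM_TMA_MULTICAST': tma_multicast_config[0],
            'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}
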
@@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
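
The `if num_groups > 1 and m > block_m:` context above guards the TMA store path for the masked case; its body is cut off by the hunk, so the check below is a hypothetical reconstruction of what such a guard typically enforces, not the file's actual code:

    # Hypothetical shape guard (assumption): when several groups are packed into
    # one output tensor and M spans more than one block, each group's rows must
    # end on a block boundary so a TMA store never spills into the next group.
    def check_tma_store_alignment(m: int, block_m: int, num_groups: int) -> None:
        if num_groups > 1 and m > block_m:
            assert m % block_m == 0, f'M ({m}) must be a multiple of BLOCK_M ({block_m}) for masked grouped GEMM'

    check_tma_store_alignment(m=256, block_m=128, num_groups=4)
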
@@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedMasked'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
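
Taken together, the commit threads a multicast-side choice (A vs. B) from get_best_configs down to the kernel template. How that side might be picked is sketched below as an illustrative heuristic, under the assumption that multicast pays off on the operand whose tile is shared by neighboring CTAs; it is not DeepGEMM's actual selection logic:

    from typing import Tuple

    # Illustrative heuristic only (not the library's get_best_configs logic):
    # multicasting A shares one A-tile load among CTAs that cover the same M block
    # but different N blocks, so it needs the N-block count to split evenly across
    # the cluster; multicasting B is the symmetric case on the M-block count.
    def choose_tma_multicast_config(m: int, n: int, block_m: int, block_n: int,
                                    cluster_size: int = 2) -> Tuple[int, bool]:
        num_m_blocks = (m + block_m - 1) // block_m
        num_n_blocks = (n + block_n - 1) // block_n
        if num_n_blocks % cluster_size == 0:
            return cluster_size, True    # multicast on A (LHS)
        if num_m_blocks % cluster_size == 0:
            return cluster_size, False   # multicast on B (RHS)
        return 1, True                   # no multicast; the flag is then irrelevant

    print(choose_tma_multicast_config(m=4096, n=7168, block_m=128, block_n=128))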