Support block size 160

This commit is contained in:
Chenggang Zhao
2025-03-25 13:37:59 +08:00
parent 46eb0d08fb
commit b922e64cb2
4 changed files with 188 additions and 879 deletions

View File

@@ -11,14 +11,14 @@ template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto M = {M}, N = {N}, K = {K};
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated grouped GEMM
using GemmType = Gemm<M, N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
@@ -106,8 +106,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -171,7 +169,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
@@ -186,5 +184,3 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Run the kernel
runtime(*args)
# For debug
return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size