mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,41 +1,12 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto BLOCK_K = 128;
|
||||
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
|
||||
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
|
||||
constexpr auto kNumGroups = {NUM_GROUPS};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
|
||||
gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': m_indices,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': 'GroupedContiguous'},
|
||||
'GEMM_TYPE': GemmType.GroupedContiguous},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
|
||||
# Extra checks for TMA store
|
||||
if num_groups > 1 and m > block_m:
|
||||
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': masked_m,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': 'GroupedMasked'},
|
||||
'GEMM_TYPE': GemmType.GroupedMasked},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user