[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@@ -1,41 +1,12 @@
import torch
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = {NUM_GROUPS};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated grouped GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
@@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedContiguous'},
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedMasked'},
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)