mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -3,40 +3,11 @@ import torch
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto BLOCK_K = 128;
|
||||
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
|
||||
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
|
||||
constexpr auto kNumGroups = 1;
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
|
||||
|
||||
// Make a templated GEMM
|
||||
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
|
||||
gemm_t::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
@@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
smem_d = block_m * (block_n + block_n_padding) * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
@@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k %
|
||||
block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
def fix_wave_saturate(x): return num_sms if x == 0 else x
|
||||
|
||||
def get_num_waves(bm, bn): return (ceil_div(
|
||||
ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
|
||||
def get_last_wave_util(bm, bn): return fix_wave_saturate(
|
||||
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
@@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# NOTES: the block sizes can not be too large, so at least one dim less than 128
|
||||
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
num_waves, best_num_waves = get_num_waves(
|
||||
block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
@@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
success |= block_n == best_block_n and block_m < best_block_m
|
||||
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
|
||||
success |= block_m != best_block_m and block_n > best_block_n
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (
|
||||
best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the longest one
|
||||
@@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||
stage_candidates = (4, 3)
|
||||
for num_stages in stage_candidates:
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
|
||||
best_smem_config = get_smem_config(
|
||||
num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_config[0] <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
@@ -158,8 +138,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Recompute the minimal number of SMs required
|
||||
# NOTES: less L2 cache usage and less GPU frequency drop
|
||||
num_waves = get_num_waves(best_block_m, best_block_n)
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) *
|
||||
ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(
|
||||
num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||
assert num_min_sms <= num_sms
|
||||
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
@@ -210,11 +192,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, k, 1, num_sms)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
|
||||
kwargs = {
|
||||
'GEMM_TYPE': GemmType.Normal,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'NUM_GROUPS': 1,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -223,14 +236,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user