mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -3,40 +3,11 @@ import torch
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto BLOCK_K = 128;
|
||||
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
|
||||
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
|
||||
constexpr auto kNumGroups = 1;
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
|
||||
|
||||
// Make a templated GEMM
|
||||
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
|
||||
gemm_t::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
@@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
smem_d = block_m * (block_n + block_n_padding) * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
@@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k %
|
||||
block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
def fix_wave_saturate(x): return num_sms if x == 0 else x
|
||||
|
||||
def get_num_waves(bm, bn): return (ceil_div(
|
||||
ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
|
||||
def get_last_wave_util(bm, bn): return fix_wave_saturate(
|
||||
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
@@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# NOTES: the block sizes can not be too large, so at least one dim less than 128
|
||||
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
num_waves, best_num_waves = get_num_waves(
|
||||
block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
@@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
success |= block_n == best_block_n and block_m < best_block_m
|
||||
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
|
||||
success |= block_m != best_block_m and block_n > best_block_n
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (
|
||||
best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the longest one
|
||||
@@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||
stage_candidates = (4, 3)
|
||||
for num_stages in stage_candidates:
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
|
||||
best_smem_config = get_smem_config(
|
||||
num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_config[0] <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
@@ -158,8 +138,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Recompute the minimal number of SMs required
|
||||
# NOTES: less L2 cache usage and less GPU frequency drop
|
||||
num_waves = get_num_waves(best_block_m, best_block_n)
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) *
|
||||
ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(
|
||||
num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||
assert num_min_sms <= num_sms
|
||||
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
@@ -210,11 +192,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, k, 1, num_sms)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
|
||||
kwargs = {
|
||||
'GEMM_TYPE': GemmType.Normal,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'NUM_GROUPS': 1,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -223,14 +236,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
@@ -1,41 +1,12 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto BLOCK_K = 128;
|
||||
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
|
||||
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
|
||||
constexpr auto kNumGroups = {NUM_GROUPS};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
|
||||
gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': m_indices,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': 'GroupedContiguous'},
|
||||
'GEMM_TYPE': GemmType.GroupedContiguous},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
|
||||
# Extra checks for TMA store
|
||||
if num_groups > 1 and m > block_m:
|
||||
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': masked_m,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
@@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': 'GroupedMasked'},
|
||||
'GEMM_TYPE': GemmType.GroupedMasked},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
@@ -3,15 +3,16 @@ import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from ..jit import build, generate, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
@@ -26,7 +27,7 @@ class JITTuner:
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
assert args is not None
|
||||
assert kwargs is not None
|
||||
space = (dict(), ) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
@@ -34,30 +35,31 @@ class JITTuner:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
code = generate(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code), full_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(*args)
|
||||
if return_code != 0:
|
||||
return_code = runtime(**tuned_keys, **kwargs)
|
||||
if return_code != cuda.CUresult.CUDA_SUCCESS:
|
||||
# Pass illegal kernels, e.g. insufficient shared memory capacity
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
print(
|
||||
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int,
|
||||
device='cuda').zero_()
|
||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
|
||||
(8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
@@ -68,14 +70,16 @@ class JITTuner:
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
print(
|
||||
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
print(
|
||||
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = (best_runtime, best_keys)
|
||||
return best_runtime, best_keys
|
||||
|
||||
|
||||
jit_tuner = JITTuner()
|
||||
|
||||
Reference in New Issue
Block a user