Refactor launch-related structures

This commit is contained in:
Chenggang Zhao 2025-05-15 16:14:21 +08:00
parent e2d6a107ef
commit 816b39053a
9 changed files with 199 additions and 396 deletions

View File

@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful:
- Post optimization
- `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
- Heuristic selection
- `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default
- `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default
- Testing
- `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default

View File

@ -18,7 +18,7 @@ namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kLastStages,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [&](const auto& func) {
if constexpr (kLastStages == 0) {
if constexpr (kNumLastStages == 0) {
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B
@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll

View File

@ -8,11 +8,10 @@ from torch.utils.cpp_extension import CUDA_HOME
class Runtime:
def __init__(self, path: str, args: List[str] = None) -> None:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.kernel = None
self.args = args
assert self.is_path_valid(self.path)
@staticmethod
@ -48,8 +47,10 @@ class Runtime:
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
illegal_names = ['vprintf', '__instantiate_kernel', '__internal']
check_illegal = lambda line: any([name in line for name in illegal_names])
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
if line.startswith('STT_FUNC') and not check_illegal(line)]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
# Load kernel from the library
@ -62,7 +63,7 @@ class Runtime:
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
# noinspection PyArgumentList
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
return self.launch(self.kernel, kwargs)
def __del__(self) -> None:
if self.lib is not None:

View File

@ -3,11 +3,11 @@ import torch
from functools import lru_cache
from typing import Tuple
from ..jit import build
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
from .tuner import jit_tuner
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@ -18,7 +18,6 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
def get_swizzle_mode(block_n: int) -> int:
# TODO: remove some candidates if slow
elem_size = 2
for mode_bytes in (128, 64, 32):
if (block_n * elem_size) % mode_bytes == 0:
@ -187,13 +186,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert out.dtype == torch.bfloat16
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
@ -208,29 +200,30 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
kwargs = {
# Templated arguments
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'M': m, 'N': n, 'K': aligned_k,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
# Runtime arguments
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
'NUM_SMS': num_sms,
@ -240,21 +233,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': aligned_k,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(**kwargs)
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
runtime(**kwargs)

View File

@ -1,12 +1,12 @@
import torch
from typing import Tuple
from ..jit import build
from .gemm import get_best_configs
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
from .tuner import jit_tuner
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@ -69,21 +69,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
# Templated arguments
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'M': m, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedContiguous,
# Runtime arguments
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms,
@ -93,25 +97,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(**kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
runtime(**kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -176,21 +168,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
# Templated arguments
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'M': m, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedMasked,
# Runtime arguments
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms,
@ -200,22 +196,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(**kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
runtime(**kwargs)

View File

@ -5,14 +5,10 @@ import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict, Tuple
from .utils import get_tma_aligned_size
from ..jit.runtime import Runtime
class Layout(enum.Enum):
RowMajor = 0
ColMajor = 1
class GemmType(enum.Enum):
Normal = 0
GroupedContiguous = 1
@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor,
gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
stride_in_bytes: cbd.cuuint64_t,
smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
def make_2d_tma_copy_desc(t: torch.Tensor,
gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
tensor_dtype = tmap_type_map[global_address.dtype]
tensor_dtype = tmap_type_map[t.dtype]
res, tensor_map = cbd.cuTensorMapEncodeTiled(
tensor_dtype,
2,
global_address.data_ptr(),
gmem_dim,
(stride_in_bytes, ),
smem_dim,
t.data_ptr(),
gmem_dims,
(gmem_outer_stride,),
smem_dims,
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
@ -86,90 +81,61 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
gmem_rows: int, gmem_cols: int, gmem_stride: int,
smem_rows: int, smem_cols: int,
def make_2d_tma_desc(t: torch.Tensor,
gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
smem_inner_dim: int, smem_outer_dim: int,
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m: int, shape_k: int,
def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
shape_m: int, shape_k: int, m_stride: int,
block_m: int, block_k: int,
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
a_stride = shape_k if a_stride == 0 else a_stride
return make_2d_tma_desc(global_address, Layout.RowMajor,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
block_m, block_k)
num_groups: int) -> cbd.CUtensorMap:
return make_2d_tma_desc(t,
shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
block_k, block_m)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_k: int, shape_n: int,
block_k: int, block_n: int,
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
b_stride = shape_k if b_stride == 0 else b_stride
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
shape_n: int, shape_k: int, n_stride: int,
block_n: int, block_k: int,
num_groups: int) -> cbd.CUtensorMap:
return make_2d_tma_desc(t,
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m: int, shape_n: int,
def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
shape_m: int, shape_n: int, m_stride: int,
block_m: int, block_n: int,
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
num_groups: int,
swizzle_mode: int) -> cbd.CUtensorMap:
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
d_stride = shape_n if d_stride == 0 else d_stride
return make_2d_tma_desc(global_address, Layout.RowMajor,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
return make_2d_tma_desc(t,
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
swizzle_type_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
shape_mn: int, shape_k: int,
block_mn: int, block_k: int,
num_groups: int) -> cbd.CUtensorMap:
# Make TMA aligned to 16 bytes
tma_alignment = 16 / global_address.element_size()
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
# Make TMA aligned to 16 bytes
tma_alignment = 16 / global_address.element_size()
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
return make_2d_tma_desc(t,
shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
block_mn, 1,
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
super().__init__(path)
@staticmethod
def generate(**kwargs) -> str:
@ -213,21 +179,16 @@ static void __instantiate_kernel() {{
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
stream: cbd.CUstream) -> cbd.CUresult:
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
if res != cbd.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
@ -237,23 +198,23 @@ static void __instantiate_kernel() {{
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = kwargs['STREAM']
arg_values = (
scales_b.data_ptr(),
grouped_layout.data_ptr(),
shape_m,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_d,
kwargs['SCALES_B'].data_ptr(),
kwargs['GROUPED_LAYOUT'].data_ptr(),
kwargs['M'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
)
arg_types = (
ctypes.c_void_p,
@ -269,20 +230,7 @@ static void __instantiate_kernel() {{
class FP8WGradGemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'NUM_TMA_MULTICAST',
'K',
'BLOCK_M',
'GMEM_D',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_SCALES_B',
'TENSOR_MAP_D',
'STREAM',
])
super().__init__(path)
@staticmethod
def generate(**kwargs) -> str:
@ -309,7 +257,7 @@ static void __instantiate_kernel() {{
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['NUM_STAGES']},
{kwargs['LAST_STAGES']},
{kwargs['NUM_LAST_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
@ -323,21 +271,16 @@ static void __instantiate_kernel() {{
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
tensor_map_d: cbd.CUtensorMap,
stream: cbd.CUstream) -> cbd.CUresult:
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
if res != cbd.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
@ -347,22 +290,22 @@ static void __instantiate_kernel() {{
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = kwargs['STREAM']
arg_values = (
shape_k,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_scales_b,
tensor_map_d,
kwargs['K'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_SCALES_B'],
kwargs['TENSOR_MAP_D'],
)
arg_types = (
ctypes.c_uint32,

View File

@ -1,82 +0,0 @@
import copy
import os
import torch
import cuda.bindings.driver as cbd
from typing import Any, Callable, Dict, Type, Tuple
from ..jit import build, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f'{keys}')
if signature in self.tuned:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT kernel {name} with keys {keys}')
return self.tuned[signature]
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
assert kwargs is not None
space = (dict(), ) if len(space) == 0 else space
kernels = []
for tuned_keys in space:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = runtime_cls.generate(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cbd.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g., insufficient shared memory capacity
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int,
device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
else:
elapsed_time = 0
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)
return best_runtime, best_keys
jit_tuner = JITTuner()

View File

@ -1,18 +1,18 @@
import torch
from typing import List, Tuple
from ..jit import build
from .runtime import (
FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: Tuple[torch.Tensor, torch.Tensor]):
out: torch.Tensor):
"""
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
Results will be accumulated into the output tensor.
@ -21,8 +21,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
@ -47,13 +47,6 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m):
@ -81,30 +74,30 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
last_stages = (k + 127) // 128 % num_stages
num_last_stages = (k + 127) // 128 % num_stages
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
tensor_map_scales_b = make_2d_tma_scales_b_desc(
GemmType.Normal, rhs_scales, n, k, block_n, block_k)
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)
kwargs = {
# Templated arguments
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'K': aligned_k,
'M': m, 'N': n, 'K': aligned_k,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_STAGES': num_stages,
'NUM_LAST_STAGES': num_last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
# Runtime arguments
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
@ -113,23 +106,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='wgrad_gemm_fp8_fp8_fp32_nt',
keys={'M': m, 'N': n,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages,
'LAST_STAGES': last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8WGradGemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
# Generate, build and run the kernel
code = FP8WGradGemmRuntime.generate(**kwargs)
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime)
runtime(**kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -144,16 +127,16 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
Each batch's LHS, RHS, and output tensors must be contiguous.
The RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
Arguments:
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
the second element is an FP32 scaling tensor for LHS with shape `[k / 128 for k in batch_sizes), m]`,
The second element is an FP32 scaling tensor for LHS with shape `[k / 128 for k in batch_sizes), m]`,
representing the per-128-channel scaling factors.
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
the second element is an FP32 scaling tensor for RHS with shape `[k / 128 for k in batch_sizes), n]`,
The second element is an FP32 scaling tensor for RHS with shape `[k / 128 for k in batch_sizes), n]`,
representing the per-128-channel scaling factors.
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
batch_sizes: A list of integers specifying the k-dimension for each batch.
@ -164,16 +147,14 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
for idx in range(num_batches):
k = batch_sizes[idx]
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
D = out[idx]
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
for i in range(num_batches):
k = batch_sizes[i]
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
lhs_offset += m * k
rhs_offset += n * k
scales_offset += (k + 127) // 128
scales_offset += (k + 127) // 128

View File

@ -2,6 +2,7 @@ import ctypes
import os
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict
from deep_gemm import jit
@ -12,12 +13,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'A',
'B',
'C',
'STREAM',
])
super().__init__(path)
@staticmethod
def generate(**kwargs) -> str:
@ -46,27 +42,25 @@ static void __instantiate_kernel() {{
# noinspection PyShadowingNames,PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cbd.CUstream) -> cbd.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
assert kwargs['A'].dim() == 1
config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128
config.gridDimX = (kwargs['A'].numel() + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
config.hStream = kwargs['STREAM']
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.numel(),
kwargs['A'].data_ptr(),
kwargs['B'].data_ptr(),
kwargs['C'].data_ptr(),
kwargs['A'].numel(),
)
arg_types = (
ctypes.c_void_p,