Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00

Commit 816b39053a (parent e2d6a107ef): Refactor launch-related structures
@@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful:
 - Post optimization
     - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
 - Heuristic selection
-    - `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default
+    - `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default
 - Testing
     - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default
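For context, a minimal sketch of how these switches are typically set from Python before invoking the library; only the variable names come from the list above, everything else (the import and the idea of calling a GEMM afterwards) is an illustrative assumption:

```python
import os

# Ask DeepGEMM to print the heuristically selected config for each GEMM shape
os.environ['DG_PRINT_HEURISTIC'] = '1'
# Keep the FFMA-interleaving post-optimization enabled (its default)
os.environ['DG_JIT_DISABLE_FFMA_INTERLEAVE'] = '0'

# Any DeepGEMM kernel launched after this point will honor the flags above;
# the import below is an illustrative placeholder, not part of the diff.
import deep_gemm  # noqa: F401
```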
@@ -18,7 +18,7 @@ namespace deep_gemm {

 template <uint32_t SHAPE_M, uint32_t SHAPE_N,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
-          uint32_t kNumStages, uint32_t kLastStages,
+          uint32_t kNumStages, uint32_t kNumLastStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
 __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
@@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
     struct DivisibleK {};
     struct NotDivisibleK {};
     auto launch_k_iterations = [&](const auto& func) {
-        if constexpr (kLastStages == 0) {
+        if constexpr (kNumLastStages == 0) {
             for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
                 func(k_iter, DivisibleK{});
         } else {
@@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
     while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
         launch_k_iterations([&](int k_iter, auto type) {
             constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
+            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
             DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

             // Assign TMA multicast number into A and B
@@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
         // Launch MMAs
         launch_k_iterations([&](int k_iter, auto type) {
             constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
+            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
             DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

             #pragma unroll
@@ -8,11 +8,10 @@ from torch.utils.cpp_extension import CUDA_HOME


 class Runtime:
-    def __init__(self, path: str, args: List[str] = None) -> None:
+    def __init__(self, path: str) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
-        self.args = args
         assert self.is_path_valid(self.path)

     @staticmethod
@@ -48,8 +47,10 @@ class Runtime:
         command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
         result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
         assert result.returncode == 0
+        illegal_names = ['vprintf', '__instantiate_kernel', '__internal']
+        check_illegal = lambda line: any([name in line for name in illegal_names])
         kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
-                        if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
+                        if line.startswith('STT_FUNC') and not check_illegal(line)]
         assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'

         # Load kernel from the library
@@ -62,7 +63,7 @@ class Runtime:
             print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')

         # noinspection PyArgumentList
-        return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
+        return self.launch(self.kernel, kwargs)

     def __del__(self) -> None:
         if self.lib is not None:
@@ -3,11 +3,11 @@ import torch
 from functools import lru_cache
 from typing import Tuple

+from ..jit import build
 from .runtime import (
     FP8GemmRuntime, GemmType,
     make_2d_tma_a_desc, make_2d_tma_b_desc,
-    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
-from .tuner import jit_tuner
+    make_2d_tma_d_desc, make_2d_tma_scales_desc)
 from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -18,7 +18,6 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int


 def get_swizzle_mode(block_n: int) -> int:
-    # TODO: remove some candidates if slow
     elem_size = 2
     for mode_bytes in (128, 64, 32):
         if (block_n * elem_size) % mode_bytes == 0:
@@ -187,13 +186,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     assert out.dtype == torch.bfloat16
     assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

-    lhs_stride = lhs.stride(0)
-    rhs_stride = rhs.stride(0)
-    out_stride = out.stride(0)
-
-    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
-    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
-
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
@@ -208,29 +200,30 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],

     # Auto-tuning with compilation
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
-        m, n, k, 1, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
     block_k = 128
     num_tma_threads = 128
     num_math_threads_per_group = 128

-    tensor_map_a = make_2d_tma_a_desc(
-        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
-    tensor_map_b = make_2d_tma_b_desc(
-        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
-    tensor_map_d = make_2d_tma_d_desc(
-        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
-    tensor_map_scales_a = make_2d_tma_scales_a_desc(
-        GemmType.Normal, lhs_scales, m, k, block_m, block_k)
+    tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
+    tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
+    tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
+    tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)

     kwargs = {
         # Templated arguments
         'GEMM_TYPE': GemmType.Normal,
         'NUM_TMA_THREADS': num_tma_threads,
         'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
-        'M': m,
+        'M': m, 'N': n, 'K': aligned_k,
         'NUM_GROUPS': 1,
-        'BLOCK_K': block_k,
-        'GMEM_D': out,
+        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
+        'SWIZZLE_D_MODE': smem_config[1],
+        'BLOCK_N_PADDING': smem_config[2],
+        'NUM_STAGES': num_stages,
+        'NUM_TMA_MULTICAST': tma_multicast_config[0],
+        'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
         # Runtime arguments
         'SCALES_B': rhs_scales,
         'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
         'NUM_SMS': num_sms,
@@ -240,21 +233,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
         'TENSOR_MAP_D': tensor_map_d,
         'STREAM': torch.cuda.current_stream().cuda_stream,
+        'DEVICE_INDEX': out.device.index
     }

-    runtime, best_keys = jit_tuner.compile_and_tune(
-        name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': aligned_k,
-              'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'SWIZZLE_D_MODE': smem_config[1],
-              'BLOCK_N_PADDING': smem_config[2],
-              'NUM_STAGES': num_stages,
-              'NUM_TMA_MULTICAST': tma_multicast_config[0],
-              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8GemmRuntime,
-    )
-
-    # Run the kernel
-    runtime(**best_keys, **kwargs)
+    # Generate, build and run the kernel
+    code = FP8GemmRuntime.generate(**kwargs)
+    runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime(**kwargs)
@@ -1,12 +1,12 @@
 import torch
 from typing import Tuple

+from ..jit import build
 from .gemm import get_best_configs
 from .runtime import (
     FP8GemmRuntime, GemmType,
     make_2d_tma_a_desc, make_2d_tma_b_desc,
-    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
-from .tuner import jit_tuner
+    make_2d_tma_d_desc, make_2d_tma_scales_desc)
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@@ -69,21 +69,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
     num_tma_threads = 128
     num_math_threads_per_group = 128

-    tensor_map_a = make_2d_tma_a_desc(
-        GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
-    tensor_map_b = make_2d_tma_b_desc(
-        GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
-    tensor_map_d = make_2d_tma_d_desc(
-        GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
-    tensor_map_scales_a = make_2d_tma_scales_a_desc(
-        GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
+    tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
+    tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups)
+    tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
+    tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)

     kwargs = {
         # Templated arguments
         'NUM_TMA_THREADS': num_tma_threads,
         'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
-        'M': m,
-        'BLOCK_K': block_k,
-        'GMEM_D': out,
+        'M': m, 'N': n, 'K': k,
+        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
+        'SWIZZLE_D_MODE': smem_config[1],
+        'BLOCK_N_PADDING': smem_config[2],
+        'NUM_GROUPS': num_groups,
+        'NUM_STAGES': num_stages,
+        'NUM_TMA_MULTICAST': tma_multicast_config[0],
+        'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+        'GEMM_TYPE': GemmType.GroupedContiguous,
         # Runtime arguments
         'SCALES_B': rhs_scales,
         'GROUPED_LAYOUT': m_indices,
         'NUM_SMS': num_sms,
@@ -93,25 +97,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
         'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
         'TENSOR_MAP_D': tensor_map_d,
         'STREAM': torch.cuda.current_stream().cuda_stream,
+        'DEVICE_INDEX': out.device.index
     }

-    runtime, best_keys = jit_tuner.compile_and_tune(
-        name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'SWIZZLE_D_MODE': smem_config[1],
-              'BLOCK_N_PADDING': smem_config[2],
-              'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages,
-              'NUM_TMA_MULTICAST': tma_multicast_config[0],
-              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
-              'GEMM_TYPE': GemmType.GroupedContiguous},
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8GemmRuntime,
-    )
-
-    # Run the kernel
-    runtime(**best_keys, **kwargs)
+    # Generate, build and run the kernel
+    code = FP8GemmRuntime.generate(**kwargs)
+    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime(**kwargs)


 def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -176,21 +168,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
     num_tma_threads = 128
     num_math_threads_per_group = 128

-    tensor_map_a = make_2d_tma_a_desc(
-        GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
-    tensor_map_b = make_2d_tma_b_desc(
-        GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
-    tensor_map_d = make_2d_tma_d_desc(
-        GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
-    tensor_map_scales_a = make_2d_tma_scales_a_desc(
-        GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
+    tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups)
+    tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups)
+    tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
+    tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)

     kwargs = {
         # Templated arguments
         'NUM_TMA_THREADS': num_tma_threads,
         'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
-        'M': m,
-        'BLOCK_K': block_k,
-        'GMEM_D': out,
+        'M': m, 'N': n, 'K': k,
+        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
+        'SWIZZLE_D_MODE': smem_config[1],
+        'BLOCK_N_PADDING': smem_config[2],
+        'NUM_GROUPS': num_groups,
+        'NUM_STAGES': num_stages,
+        'NUM_TMA_MULTICAST': tma_multicast_config[0],
+        'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+        'GEMM_TYPE': GemmType.GroupedMasked,
         # Runtime arguments
         'SCALES_B': rhs_scales,
         'GROUPED_LAYOUT': masked_m,
         'NUM_SMS': num_sms,
@@ -200,22 +196,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
         'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
         'TENSOR_MAP_D': tensor_map_d,
         'STREAM': torch.cuda.current_stream().cuda_stream,
+        'DEVICE_INDEX': out.device.index
     }

-    runtime, best_keys = jit_tuner.compile_and_tune(
-        name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'SWIZZLE_D_MODE': smem_config[1],
-              'BLOCK_N_PADDING': smem_config[2],
-              'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages,
-              'NUM_TMA_MULTICAST': tma_multicast_config[0],
-              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
-              'GEMM_TYPE': GemmType.GroupedMasked},
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8GemmRuntime,
-    )
-
-    # Run the kernel
-    runtime(**best_keys, **kwargs)
+    # Generate, build and run the kernel
+    code = FP8GemmRuntime.generate(**kwargs)
+    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime(**kwargs)
@@ -5,14 +5,10 @@ import torch
 import cuda.bindings.driver as cbd
 from typing import Any, Dict, Tuple

+from .utils import get_tma_aligned_size
 from ..jit.runtime import Runtime


-class Layout(enum.Enum):
-    RowMajor = 0
-    ColMajor = 1
-
-
 class GemmType(enum.Enum):
     Normal = 0
     GroupedContiguous = 1
@@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int
     return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads


-def make_2d_tma_copy_desc(global_address: torch.Tensor,
-                          gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
-                          stride_in_bytes: cbd.cuuint64_t,
-                          smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
+def make_2d_tma_copy_desc(t: torch.Tensor,
+                          gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
+                          smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
                           swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
-    tensor_dtype = tmap_type_map[global_address.dtype]
+    tensor_dtype = tmap_type_map[t.dtype]
     res, tensor_map = cbd.cuTensorMapEncodeTiled(
         tensor_dtype,
         2,
-        global_address.data_ptr(),
-        gmem_dim,
-        (stride_in_bytes, ),
-        smem_dim,
+        t.data_ptr(),
+        gmem_dims,
+        (gmem_outer_stride,),
+        smem_dims,
         (cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
         cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
         swizzle_type,
@@ -86,90 +81,61 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
     return tensor_map


-def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
-                     gmem_rows: int, gmem_cols: int, gmem_stride: int,
-                     smem_rows: int, smem_cols: int,
+def make_2d_tma_desc(t: torch.Tensor,
+                     gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
+                     smem_inner_dim: int, smem_outer_dim: int,
                      swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
-    if layout == Layout.RowMajor:
-        gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
-        smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
-        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
-    else:
-        gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
-        smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
-        return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
+    gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
+    smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
+    return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)


-def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
-                       shape_m: int, shape_k: int,
+def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
+                       shape_m: int, shape_k: int, m_stride: int,
                        block_m: int, block_k: int,
-                       num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
-    a_stride = shape_k if a_stride == 0 else a_stride
-    return make_2d_tma_desc(global_address, Layout.RowMajor,
-                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
-                            block_m, block_k)
+                       num_groups: int) -> cbd.CUtensorMap:
+    return make_2d_tma_desc(t,
+                            shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
+                            block_k, block_m)


-def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
-                       shape_k: int, shape_n: int,
-                       block_k: int, block_n: int,
-                       num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
-    b_stride = shape_k if b_stride == 0 else b_stride
-    return make_2d_tma_desc(global_address, Layout.ColMajor,
-                            shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
+def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
+                       shape_n: int, shape_k: int, n_stride: int,
+                       block_n: int, block_k: int,
+                       num_groups: int) -> cbd.CUtensorMap:
+    return make_2d_tma_desc(t,
+                            shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
                             block_k, block_n)


-def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
-                       shape_m: int, shape_n: int,
+def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
+                       shape_m: int, shape_n: int, m_stride: int,
                        block_m: int, block_n: int,
-                       num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
+                       num_groups: int,
+                       swizzle_mode: int) -> cbd.CUtensorMap:
     # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
     # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
-    d_stride = shape_n if d_stride == 0 else d_stride
-    return make_2d_tma_desc(global_address, Layout.RowMajor,
-                            shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
-                            block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
+    return make_2d_tma_desc(t,
+                            shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
+                            block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
                             swizzle_type_map[swizzle_mode])


-def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
+def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
+                            shape_mn: int, shape_k: int,
+                            block_mn: int, block_k: int,
+                            num_groups: int) -> cbd.CUtensorMap:
     # Make TMA aligned to 16 bytes
-    tma_alignment = 16 / global_address.element_size()
-    shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
-
-    return make_2d_tma_desc(global_address, Layout.ColMajor,
-                            shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
-                            block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
-
-
-def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
-    # Make TMA aligned to 16 bytes
-    tma_alignment = 16 / global_address.element_size()
-    shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
-
-    return make_2d_tma_desc(global_address, Layout.ColMajor,
-                            shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
-                            block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
+    shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
+    return make_2d_tma_desc(t,
+                            shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
+                            block_mn, 1,
+                            cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)


 class FP8GemmRuntime(Runtime):
     def __init__(self, path: str) -> None:
-        super().__init__(path, [
-            'NUM_TMA_MULTICAST',
-            'M',
-            'BLOCK_M',
-            'GMEM_D',
-            'SCALES_B',
-            'GROUPED_LAYOUT',
-            'NUM_SMS',
-            'SMEM_SIZE',
-            'TENSOR_MAP_A',
-            'TENSOR_MAP_B',
-            'TENSOR_MAP_SCALES_A',
-            'TENSOR_MAP_D',
-            'STREAM',
-        ])
+        super().__init__(path)

     @staticmethod
     def generate(**kwargs) -> str:
@@ -213,21 +179,16 @@ static void __instantiate_kernel() {{

     # noinspection PyMethodOverriding
     @staticmethod
-    def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
-               block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
-               grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
-               tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
-               tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
-               stream: cbd.CUstream) -> cbd.CUresult:
+    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
         num_tma_threads = 128
         num_math_threads_per_group = 128

-        res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
-        if res != cbd.CUresult.CUDA_SUCCESS:
-            raise Exception(f'Failed to set max dynamic shared memory size: {res}')
+        result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                                          kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
+        assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'

         attr_val = cbd.CUlaunchAttributeValue()
-        attr_val.clusterDim.x = num_tma_multicast
+        attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
         attr_val.clusterDim.y = 1
         attr_val.clusterDim.z = 1
         attr = cbd.CUlaunchAttribute()
@@ -237,23 +198,23 @@ static void __instantiate_kernel() {{
         config = cbd.CUlaunchConfig()
         config.numAttrs = 1
         config.attrs = [attr]
-        config.gridDimX = num_sms
+        config.gridDimX = kwargs['NUM_SMS']
         config.gridDimY = 1
         config.gridDimZ = 1
-        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
+        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
         config.blockDimY = 1
         config.blockDimZ = 1
-        config.sharedMemBytes = smem_size
-        config.hStream = stream
+        config.sharedMemBytes = kwargs['SMEM_SIZE']
+        config.hStream = kwargs['STREAM']

         arg_values = (
-            scales_b.data_ptr(),
-            grouped_layout.data_ptr(),
-            shape_m,
-            tensor_map_a,
-            tensor_map_b,
-            tensor_map_scales_a,
-            tensor_map_d,
+            kwargs['SCALES_B'].data_ptr(),
+            kwargs['GROUPED_LAYOUT'].data_ptr(),
+            kwargs['M'],
+            kwargs['TENSOR_MAP_A'],
+            kwargs['TENSOR_MAP_B'],
+            kwargs['TENSOR_MAP_SCALES_A'],
+            kwargs['TENSOR_MAP_D'],
         )
         arg_types = (
             ctypes.c_void_p,
@@ -269,20 +230,7 @@ static void __instantiate_kernel() {{

 class FP8WGradGemmRuntime(Runtime):
     def __init__(self, path: str) -> None:
-        super().__init__(path, [
-            'NUM_TMA_MULTICAST',
-            'K',
-            'BLOCK_M',
-            'GMEM_D',
-            'NUM_SMS',
-            'SMEM_SIZE',
-            'TENSOR_MAP_A',
-            'TENSOR_MAP_B',
-            'TENSOR_MAP_SCALES_A',
-            'TENSOR_MAP_SCALES_B',
-            'TENSOR_MAP_D',
-            'STREAM',
-        ])
+        super().__init__(path)

     @staticmethod
     def generate(**kwargs) -> str:
@@ -309,7 +257,7 @@ static void __instantiate_kernel() {{
                 {kwargs['BLOCK_N']},
                 {kwargs['BLOCK_K']},
                 {kwargs['NUM_STAGES']},
-                {kwargs['LAST_STAGES']},
+                {kwargs['NUM_LAST_STAGES']},
                 {kwargs['NUM_TMA_THREADS']},
                 {kwargs['NUM_MATH_THREADS_PER_GROUP']},
                 {kwargs['NUM_TMA_MULTICAST']},
@@ -323,21 +271,16 @@ static void __instantiate_kernel() {{

     # noinspection PyMethodOverriding
     @staticmethod
-    def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
-               block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
-               tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
-               tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
-               tensor_map_d: cbd.CUtensorMap,
-               stream: cbd.CUstream) -> cbd.CUresult:
+    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
         num_tma_threads = 128
         num_math_threads_per_group = 128

-        res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
-        if res != cbd.CUresult.CUDA_SUCCESS:
-            raise Exception(f'Failed to set max dynamic shared memory size: {res}')
+        result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                                          kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
+        assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'

         attr_val = cbd.CUlaunchAttributeValue()
-        attr_val.clusterDim.x = num_tma_multicast
+        attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
         attr_val.clusterDim.y = 1
         attr_val.clusterDim.z = 1
         attr = cbd.CUlaunchAttribute()
@@ -347,22 +290,22 @@ static void __instantiate_kernel() {{
         config = cbd.CUlaunchConfig()
         config.numAttrs = 1
         config.attrs = [attr]
-        config.gridDimX = num_sms
+        config.gridDimX = kwargs['NUM_SMS']
         config.gridDimY = 1
         config.gridDimZ = 1
-        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
+        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
         config.blockDimY = 1
         config.blockDimZ = 1
-        config.sharedMemBytes = smem_size
-        config.hStream = stream
+        config.sharedMemBytes = kwargs['SMEM_SIZE']
+        config.hStream = kwargs['STREAM']

         arg_values = (
-            shape_k,
-            tensor_map_a,
-            tensor_map_b,
-            tensor_map_scales_a,
-            tensor_map_scales_b,
-            tensor_map_d,
+            kwargs['K'],
+            kwargs['TENSOR_MAP_A'],
+            kwargs['TENSOR_MAP_B'],
+            kwargs['TENSOR_MAP_SCALES_A'],
+            kwargs['TENSOR_MAP_SCALES_B'],
+            kwargs['TENSOR_MAP_D'],
         )
         arg_types = (
             ctypes.c_uint32,
@@ -1,82 +0,0 @@
-import copy
-import os
-import torch
-import cuda.bindings.driver as cbd
-from typing import Any, Callable, Dict, Type, Tuple
-
-from ..jit import build, Runtime
-
-
-class JITTuner:
-    def __init__(self) -> None:
-        self.tuned = {}
-
-    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
-                         kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
-        # NOTES: we always assume the space, template and GPU devices will not change
-        # NOTES: the function must have no accumulated side effects
-        keys = {k: keys[k] for k in sorted(keys.keys())}
-        signature = (name, f'{keys}')
-        if signature in self.tuned:
-            if int(os.getenv('DG_JIT_DEBUG', 0)):
-                print(f'Using cached JIT kernel {name} with keys {keys}')
-            return self.tuned[signature]
-
-        if int(os.getenv('DG_JIT_DEBUG', 0)):
-            print(f'Auto-tuning JIT kernel {name} with keys {keys}')
-
-        assert signature not in self.tuned
-        assert kwargs is not None
-        space = (dict(), ) if len(space) == 0 else space
-
-        kernels = []
-        for tuned_keys in space:
-            assert isinstance(tuned_keys, dict)
-            full_keys = copy.deepcopy(keys)
-            full_keys.update(tuned_keys)
-            code = runtime_cls.generate(**kwargs, **full_keys)
-            kernels.append((build(name, code, runtime_cls), full_keys))
-
-        # TODO: fix tuning with space > 1
-        best_runtime, best_time, best_keys = None, None, None
-        for runtime, tuned_keys in kernels:
-            if len(space) > 1:
-                # Check kernel validity
-                return_code = runtime(**tuned_keys, **kwargs)
-                if return_code != cbd.CUresult.CUDA_SUCCESS:
-                    # Pass illegal kernels, e.g., insufficient shared memory capacity
-                    if int(os.getenv('DG_JIT_DEBUG', 0)):
-                        print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
-                    continue
-
-                # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
-                start_event = torch.cuda.Event(enable_timing=True)
-                end_event = torch.cuda.Event(enable_timing=True)
-                torch.empty(int(256e6 // 4), dtype=torch.int,
-                            device='cuda').zero_()
-                torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
-                    (8192, 8192), dtype=torch.float, device='cuda')
-                start_event.record()
-                for i in range(20):
-                    assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
-                end_event.record()
-                end_event.synchronize()
-                elapsed_time = start_event.elapsed_time(end_event)
-            else:
-                elapsed_time = 0
-
-            # Compare if better
-            if best_time is None or elapsed_time < best_time:
-                best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
-                if int(os.getenv('DG_JIT_DEBUG', 0)):
-                    print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
-        assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
-
-        # Cache the best runtime and return
-        if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
-            print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
-        self.tuned[signature] = (best_runtime, best_keys)
-        return best_runtime, best_keys
-
-
-jit_tuner = JITTuner()
@@ -1,18 +1,18 @@
 import torch
 from typing import List, Tuple

+from ..jit import build
 from .runtime import (
     FP8WGradGemmRuntime, GemmType,
     make_2d_tma_a_desc, make_2d_tma_b_desc,
-    make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
+    make_2d_tma_d_desc, make_2d_tma_scales_desc)
 from .gemm import get_best_configs
-from .tuner import jit_tuner
 from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size


 def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                rhs: Tuple[torch.Tensor, torch.Tensor],
-                               out: Tuple[torch.Tensor, torch.Tensor]):
+                               out: torch.Tensor):
     """
     Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
     Results will be accumulated into the output tensor.
@@ -21,8 +21,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
     The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
     RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
+    The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
+    If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations.

     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
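As a usage sketch of the API this docstring describes: the call below follows the argument layout given here and the shape checks visible later in this diff; the concrete sizes, the FP8 casts, and the package-level import are assumptions for illustration only.

```python
import torch
from deep_gemm import wgrad_gemm_fp8_fp8_fp32_nt  # assumed package-level export

m, n, k = 4096, 4096, 8192
# FP8 operands, contiguous in dimension 1 (stride(1) == 1)
lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
# Per-128-channel FP32 scaling factors along k; if they are not already in the
# TMA-aligned transposed format, the function falls back to slow PyTorch ops
lhs_scales = torch.rand(m, (k + 127) // 128, dtype=torch.float, device='cuda')
rhs_scales = torch.rand(n, (k + 127) // 128, dtype=torch.float, device='cuda')
# FP32 output; results are accumulated into it
out = torch.zeros(m, n, dtype=torch.float, device='cuda')

wgrad_gemm_fp8_fp8_fp32_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
```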
@@ -47,13 +47,6 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     assert out.dtype == torch.float
     assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

-    lhs_stride = lhs.stride(0)
-    rhs_stride = rhs.stride(0)
-    out_stride = out.stride(0)
-
-    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
-    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
-
     # LHS and RHS scales must be transposed for TMA load
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     if lhs_scales.shape == ((k + 127) // 128, m):
@@ -81,30 +74,30 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
         m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
-    last_stages = (k + 127) // 128 % num_stages
+    num_last_stages = (k + 127) // 128 % num_stages
     block_k = 128
     num_tma_threads = 128
     num_math_threads_per_group = 128

-    tensor_map_a = make_2d_tma_a_desc(
-        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
-    tensor_map_b = make_2d_tma_b_desc(
-        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
-    tensor_map_d = make_2d_tma_d_desc(
-        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
-    tensor_map_scales_a = make_2d_tma_scales_a_desc(
-        GemmType.Normal, lhs_scales, m, k, block_m, block_k)
-    tensor_map_scales_b = make_2d_tma_scales_b_desc(
-        GemmType.Normal, rhs_scales, n, k, block_n, block_k)
+    tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
+    tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
+    tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
+    tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
+    tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)

     kwargs = {
         # Templated arguments
         'GEMM_TYPE': GemmType.Normal,
         'NUM_TMA_THREADS': num_tma_threads,
         'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
-        'K': aligned_k,
+        'M': m, 'N': n, 'K': aligned_k,
         'NUM_GROUPS': 1,
-        'BLOCK_K': block_k,
-        'GMEM_D': out,
+        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
+        'NUM_STAGES': num_stages,
+        'NUM_LAST_STAGES': num_last_stages,
+        'NUM_TMA_MULTICAST': tma_multicast_config[0],
+        'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
         # Runtime arguments
         'NUM_SMS': num_sms,
         'SMEM_SIZE': smem_config[0],
         'TENSOR_MAP_A': tensor_map_a,
@@ -113,23 +106,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
         'TENSOR_MAP_D': tensor_map_d,
         'STREAM': torch.cuda.current_stream().cuda_stream,
+        'DEVICE_INDEX': out.device.index
     }

-    runtime, best_keys = jit_tuner.compile_and_tune(
-        name='wgrad_gemm_fp8_fp8_fp32_nt',
-        keys={'M': m, 'N': n,
-              'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'NUM_STAGES': num_stages,
-              'LAST_STAGES': last_stages,
-              'NUM_TMA_MULTICAST': tma_multicast_config[0],
-              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8WGradGemmRuntime,
-    )
-
-    # Run the kernel
-    runtime(**best_keys, **kwargs)
+    # Generate, build and run the kernel
+    code = FP8WGradGemmRuntime.generate(**kwargs)
+    runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime)
+    runtime(**kwargs)


 def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -144,16 +127,16 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
     Each batch's LHS, RHS, and output tensors must be contiguous.
     The RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
+    The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.

     Arguments:
-        lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
+        lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
             and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
-            the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
+            The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
             representing the per-128-channel scaling factors.
-        rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
+        rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
-            the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
+            The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
         out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
         batch_sizes: A list of integers specifying the k-dimension for each batch.
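A hedged sketch of feeding this grouped API, mirroring the flattened layout the docstring describes and the per-batch slicing shown in the next hunk; the concrete sizes, the package-level import, and the positional argument order are illustrative assumptions:

```python
import torch
from deep_gemm import k_grouped_wgrad_gemm_fp8_fp8_fp32_nt  # assumed export

m, n = 4096, 4096
batch_sizes = [1024, 2048]  # per-batch k dimensions (illustrative)

# Flattened FP8 data: [sum(m * k)] for LHS and [sum(n * k)] for RHS
lhs = torch.randn(sum(m * k for k in batch_sizes), device='cuda').to(torch.float8_e4m3fn)
rhs = torch.randn(sum(n * k for k in batch_sizes), device='cuda').to(torch.float8_e4m3fn)
# Scaling factors concatenated over batches along dim 0: [sum(ceil(k/128)), m] and [sum(ceil(k/128)), n]
lhs_scales = torch.rand(sum((k + 127) // 128 for k in batch_sizes), m, dtype=torch.float, device='cuda')
rhs_scales = torch.rand(sum((k + 127) // 128 for k in batch_sizes), n, dtype=torch.float, device='cuda')
# FP32 output of shape [num_batches, m, n], accumulated in place
out = torch.zeros(len(batch_sizes), m, n, dtype=torch.float, device='cuda')

# Argument order follows the docstring: lhs, rhs, out, batch_sizes (assumed)
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt((lhs, lhs_scales), (rhs, rhs_scales), out, batch_sizes)
```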
@@ -164,16 +147,14 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],

     lhs_offset, rhs_offset, scales_offset = 0, 0, 0

-    for idx in range(num_batches):
-        k = batch_sizes[idx]
-        A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
-        B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
-        A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
-        B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
-        D = out[idx]
-
-        wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
+    for i in range(num_batches):
+        k = batch_sizes[i]
+        lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
+        rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
+        lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
+        rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
+        wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])

         lhs_offset += m * k
         rhs_offset += n * k
-        scales_offset += (k + 127) // 128
+        scales_offset += (k + 127) // 128
@@ -2,6 +2,7 @@ import ctypes
 import os
 import torch
 import cuda.bindings.driver as cbd
+from typing import Any, Dict

 from deep_gemm import jit
@@ -12,12 +13,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')

 class VectorAddRuntime(jit.Runtime):
     def __init__(self, path: str) -> None:
-        super().__init__(path, [
-            'A',
-            'B',
-            'C',
-            'STREAM',
-        ])
+        super().__init__(path)

     @staticmethod
     def generate(**kwargs) -> str:
@@ -46,27 +42,25 @@ static void __instantiate_kernel() {{

     # noinspection PyShadowingNames,PyMethodOverriding
     @staticmethod
-    def launch(kernel: cbd.CUkernel,
-               a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
-               stream: cbd.CUstream) -> cbd.CUresult:
-        assert a.shape == b.shape == c.shape
-        assert a.device == b.device == c.device
-        assert a.dim() == 1
+    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
+        assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
+        assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
+        assert kwargs['A'].dim() == 1

         config = cbd.CUlaunchConfig()
-        config.gridDimX = (a.numel() + 127) // 128
+        config.gridDimX = (kwargs['A'].numel() + 127) // 128
         config.gridDimY = 1
         config.gridDimZ = 1
         config.blockDimX = 128
         config.blockDimY = 1
         config.blockDimZ = 1
-        config.hStream = stream
+        config.hStream = kwargs['STREAM']

         arg_values = (
-            a.data_ptr(),
-            b.data_ptr(),
-            c.data_ptr(),
-            a.numel(),
+            kwargs['A'].data_ptr(),
+            kwargs['B'].data_ptr(),
+            kwargs['C'].data_ptr(),
+            kwargs['A'].numel(),
         )
         arg_types = (
             ctypes.c_void_p,