mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Weight gradient kernels for dense and MoE models (#95)
* Init weight gradient kernels. * Support unaligned n,k and gmem stride * Update docs * Several cleanups * Remove restrictions on N * Add stride(0) assertions --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -3,6 +3,10 @@ from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .wgrad_gemm import (
|
||||
wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
|
||||
)
|
||||
from .utils import (
|
||||
ceil_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
|
||||
@@ -34,17 +34,22 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
|
||||
|
||||
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
|
||||
assert block_k == 128
|
||||
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
smem_d = block_m * (block_n + block_n_padding) * 2
|
||||
# NOTES: `scales_b` in a total manner or per-stage manner
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
|
||||
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
@@ -52,8 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k %
|
||||
block_n == 0 else 2), 8) * 8
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@@ -64,13 +69,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
|
||||
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
|
||||
if not is_grouped_contiguous:
|
||||
block_ms = (64, 128, 256)
|
||||
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
@@ -110,7 +120,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
|
||||
for num_stages in stage_candidates:
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
|
||||
if best_smem_config[0] <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
@@ -145,11 +155,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
@@ -164,8 +177,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
@@ -174,7 +185,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
lhs_stride = lhs.stride(0)
|
||||
rhs_stride = rhs.stride(0)
|
||||
out_stride = out.stride(0)
|
||||
|
||||
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
|
||||
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
|
||||
|
||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
@@ -185,6 +203,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# K must be aligned to 128
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
@@ -194,11 +215,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
|
||||
@@ -223,7 +244,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
keys={'N': n, 'K': aligned_k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_STAGES': num_stages,
|
||||
|
||||
@@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
|
||||
@@ -87,45 +87,48 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
||||
|
||||
|
||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
|
||||
gmem_rows: int, gmem_cols: int,
|
||||
gmem_rows: int, gmem_cols: int, gmem_stride: int,
|
||||
smem_rows: int, smem_cols: int,
|
||||
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
||||
if layout == Layout.RowMajor:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
|
||||
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
||||
else:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
|
||||
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_k: int,
|
||||
block_m: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
|
||||
a_stride = shape_k if a_stride == 0 else a_stride
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
|
||||
block_m, block_k)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_k: int, shape_n: int,
|
||||
block_k: int, block_n: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
|
||||
b_stride = shape_k if b_stride == 0 else b_stride
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1),
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
|
||||
block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_n: int,
|
||||
block_m: int, block_n: int,
|
||||
num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap:
|
||||
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
|
||||
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
|
||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
d_stride = shape_n if d_stride == 0 else d_stride
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
|
||||
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
|
||||
swizzle_type_map[swizzle_mode])
|
||||
|
||||
@@ -136,10 +139,20 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
|
||||
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
|
||||
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
tma_alignment = 16 / global_address.element_size()
|
||||
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
|
||||
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
@@ -254,3 +267,111 @@ static void __instantiate_kernel() {{
|
||||
None,
|
||||
)
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
|
||||
|
||||
class FP8WGradGemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'K',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_SCALES_B',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_wgrad_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
|
||||
{kwargs['M']},
|
||||
{kwargs['N']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['LAST_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
|
||||
>);
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 WGrad GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
|
||||
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
|
||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
|
||||
tensor_map_d: cbd.CUtensorMap,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
|
||||
arg_values = (
|
||||
shape_k,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_scales_b,
|
||||
tensor_map_d,
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
|
||||
179
deep_gemm/jit_kernels/wgrad_gemm.py
Normal file
179
deep_gemm/jit_kernels/wgrad_gemm.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from .runtime import (
|
||||
FP8WGradGemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
||||
|
||||
|
||||
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: Tuple[torch.Tensor, torch.Tensor]):
|
||||
"""
|
||||
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
Results will be accumulated into the output tensor.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
|
||||
out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and m > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
|
||||
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.float
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
lhs_stride = lhs.stride(0)
|
||||
rhs_stride = rhs.stride(0)
|
||||
out_stride = out.stride(0)
|
||||
|
||||
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
|
||||
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
|
||||
|
||||
# LHS and RHS scales must be transposed for TMA load
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
if lhs_scales.shape == ((k + 127) // 128, m):
|
||||
lhs_scales = lhs_scales.permute(1, 0)
|
||||
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
|
||||
else:
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert lhs_scales.stride(0) == 1
|
||||
|
||||
if rhs_scales.shape == ((k + 127) // 128, n):
|
||||
rhs_scales = rhs_scales.permute(1, 0)
|
||||
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
|
||||
else:
|
||||
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
|
||||
assert rhs_scales.stride(0) == 1
|
||||
|
||||
# Do nothing if `k` is zero
|
||||
if k == 0:
|
||||
return
|
||||
|
||||
# K must be aligned to 128
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||
last_stages = (k + 127) // 128 % num_stages
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
tensor_map_scales_b = make_2d_tma_scales_b_desc(
|
||||
GemmType.Normal, rhs_scales, n, k, block_n, block_k)
|
||||
|
||||
kwargs = {
|
||||
'GEMM_TYPE': GemmType.Normal,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'K': aligned_k,
|
||||
'NUM_GROUPS': 1,
|
||||
'BLOCK_K': block_k,
|
||||
'GMEM_D': out,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
}
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='wgrad_gemm_fp8_fp8_fp32_nt',
|
||||
keys={'M': m, 'N': n,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages,
|
||||
'LAST_STAGES': last_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8WGradGemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(**best_keys, **kwargs)
|
||||
|
||||
|
||||
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
batch_sizes: List[int]):
|
||||
"""
|
||||
Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
Results will be accumulated into the output tensor.
|
||||
|
||||
Requirements:
|
||||
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
|
||||
Each batch's LHS, RHS, and output tensors must be contiguous.
|
||||
The RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
|
||||
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
|
||||
the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
|
||||
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
|
||||
the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
|
||||
batch_sizes: A list of integers specifying the k-dimension for each batch.
|
||||
"""
|
||||
lhs, lhs_scales = lhs[0].view(-1), lhs[1]
|
||||
rhs, rhs_scales = rhs[0].view(-1), rhs[1]
|
||||
num_batches, m, n = out.shape
|
||||
|
||||
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
|
||||
|
||||
for idx in range(num_batches):
|
||||
k = batch_sizes[idx]
|
||||
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
||||
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
||||
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
||||
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
||||
D = out[idx]
|
||||
|
||||
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
|
||||
|
||||
lhs_offset += m * k
|
||||
rhs_offset += n * k
|
||||
scales_offset += (k + 127) // 128
|
||||
Reference in New Issue
Block a user