Weight gradient kernels for dense and MoE models (#95)

* Init weight gradient kernels.

* Support unaligned n,k and gmem stride

* Update docs

* Several cleanups

* Remove restrictions on N

* Add stride(0) assertions

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Zhean Xu
2025-05-14 14:47:58 +08:00
committed by GitHub
parent d75b218b7b
commit 04278f6dee
12 changed files with 911 additions and 72 deletions

View File

@@ -3,6 +3,10 @@ from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
)
from .utils import (
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,

View File

@@ -34,17 +34,22 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
assert block_k == 128
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
smem_d = block_m * (block_n + block_n_padding) * 2
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = ceil_div(k, block_k) * 4
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
smem_barrier = num_stages * 8 * 2
smem_size = 0
@@ -52,8 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k %
block_n == 0 else 2), 8) * 8
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@@ -64,13 +69,18 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
@lru_cache(maxsize=None)
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
if not is_grouped_contiguous:
block_ms = (64, 128, 256)
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@@ -110,7 +120,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break
@@ -145,11 +155,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
@@ -164,8 +177,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
@@ -174,7 +185,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
@@ -185,6 +203,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
if m == 0:
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
@@ -194,11 +215,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
@@ -223,7 +244,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
keys={'N': n, 'K': aligned_k,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,

View File

@@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, m_indices: torch.Tensor) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
Requirements:
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,

View File

@@ -87,45 +87,48 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
gmem_rows: int, gmem_cols: int,
gmem_rows: int, gmem_cols: int, gmem_stride: int,
smem_rows: int, smem_cols: int,
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m: int, shape_k: int,
block_m: int, block_k: int,
num_groups: int) -> cbd.CUtensorMap:
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
a_stride = shape_k if a_stride == 0 else a_stride
return make_2d_tma_desc(global_address, Layout.RowMajor,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_k: int, shape_n: int,
block_k: int, block_n: int,
num_groups: int) -> cbd.CUtensorMap:
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
b_stride = shape_k if b_stride == 0 else b_stride
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1),
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m: int, shape_n: int,
block_m: int, block_n: int,
num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap:
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
d_stride = shape_n if d_stride == 0 else d_stride
return make_2d_tma_desc(global_address, Layout.RowMajor,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
swizzle_type_map[swizzle_mode])
@@ -136,10 +139,20 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
# Make TMA aligned to 16 bytes
tma_alignment = 16 / global_address.element_size()
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
@@ -254,3 +267,111 @@ static void __instantiate_kernel() {{
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
class FP8WGradGemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'NUM_TMA_MULTICAST',
'K',
'BLOCK_M',
'GMEM_D',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_SCALES_B',
'TENSOR_MAP_D',
'STREAM',
])
@staticmethod
def generate(**kwargs) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_wgrad_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
{kwargs['M']},
{kwargs['N']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['NUM_STAGES']},
{kwargs['LAST_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
>);
}};
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 WGrad GEMM code:\n{code}')
return code
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
tensor_map_d: cbd.CUtensorMap,
stream: cbd.CUstream) -> cbd.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
if res != cbd.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
arg_values = (
shape_k,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_scales_b,
tensor_map_d,
)
arg_types = (
ctypes.c_uint32,
None,
None,
None,
None,
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -0,0 +1,179 @@
import torch
from typing import List, Tuple
from .runtime import (
FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: Tuple[torch.Tensor, torch.Tensor]):
"""
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
Results will be accumulated into the output tensor.
Requirements:
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
n, k_ = rhs.shape
m_, n_ = out.shape
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and m > 0
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m):
lhs_scales = lhs_scales.permute(1, 0)
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
else:
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert lhs_scales.stride(0) == 1
if rhs_scales.shape == ((k + 127) // 128, n):
rhs_scales = rhs_scales.permute(1, 0)
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
else:
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
assert rhs_scales.stride(0) == 1
# Do nothing if `k` is zero
if k == 0:
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
last_stages = (k + 127) // 128 % num_stages
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
tensor_map_scales_b = make_2d_tma_scales_b_desc(
GemmType.Normal, rhs_scales, n, k, block_n, block_k)
kwargs = {
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'K': aligned_k,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='wgrad_gemm_fp8_fp8_fp32_nt',
keys={'M': m, 'N': n,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages,
'LAST_STAGES': last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8WGradGemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
batch_sizes: List[int]):
"""
Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
Results will be accumulated into the output tensor.
Requirements:
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
Each batch's LHS, RHS, and output tensors must be contiguous.
The RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
Arguments:
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
representing the per-128-channel scaling factors.
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
representing the per-128-channel scaling factors.
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
batch_sizes: A list of integers specifying the k-dimension for each batch.
"""
lhs, lhs_scales = lhs[0].view(-1), lhs[1]
rhs, rhs_scales = rhs[0].view(-1), rhs[1]
num_batches, m, n = out.shape
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
for idx in range(num_batches):
k = batch_sizes[idx]
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
D = out[idx]
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
lhs_offset += m * k
rhs_offset += n * k
scales_offset += (k + 127) // 128