Update docs

Zhean Xu 2025-05-13 15:30:32 +08:00
parent 919f55be9c
commit 6233709c67
3 changed files with 44 additions and 30 deletions


@@ -157,11 +157,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                          rhs: Tuple[torch.Tensor, torch.Tensor],
                          out: torch.Tensor) -> None:
     """
-    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
+    Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
+            this function will do a transposing with a set of slow PyTorch operations.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
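
A minimal usage sketch for the docstring updated above. The `deep_gemm` import path and the 128x128 RHS scale shape `[n / 128, k / 128]` are assumptions; only the LHS scale shape `[m, k / 128]` appears in this hunk.

    import torch
    import deep_gemm  # assumed import path for these kernels

    m, n, k = 4096, 4096, 8192

    # FP8 operands; RHS is stored transposed (shape [n, k]) per the requirements above.
    lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
    rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)

    # 1x128 LHS scales of shape [m, k / 128] (from the docstring); the RHS scale shape is assumed.
    lhs_scales = torch.ones(m, k // 128, dtype=torch.float32, device='cuda')
    rhs_scales = torch.ones(n // 128, k // 128, dtype=torch.float32, device='cuda')

    out = torch.empty(m, n, dtype=torch.bfloat16, device='cuda')
    deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)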


@@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                                               rhs: Tuple[torch.Tensor, torch.Tensor],
                                               out: torch.Tensor, m_indices: torch.Tensor) -> None:
     """
-    Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
-    On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
-        `get_m_alignment_for_contiguous_layout()` (128).
+    Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
+            this function will do a transposing with a set of slow PyTorch operations.
+        On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
+            `get_m_alignment_for_contiguous_layout()` (128).
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
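
A hedged usage sketch for the contiguous-grouped kernel. The `deep_gemm` import path, the per-group RHS shape `[num_groups, n, k]`, the RHS scale shape, and the row-to-group semantics of `m_indices` are assumptions not spelled out in this hunk.

    import torch
    import deep_gemm  # assumed import path for these kernels

    num_groups, m_per_group, k, n = 4, 512, 7168, 4096  # m_per_group is a multiple of the 128 alignment
    m_sum = num_groups * m_per_group

    lhs = torch.randn(m_sum, k, device='cuda').to(torch.float8_e4m3fn)
    lhs_scales = torch.ones(m_sum, k // 128, dtype=torch.float32, device='cuda')

    # One transposed RHS per group; the [num_groups, n, k] operand and 128x128 scale shapes are assumed.
    rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
    rhs_scales = torch.ones(num_groups, n // 128, k // 128, dtype=torch.float32, device='cuda')

    # Assumed semantics: m_indices[i] is the group index that row i of LHS/output belongs to.
    m_indices = torch.arange(num_groups, dtype=torch.int32, device='cuda').repeat_interleave(m_per_group)

    out = torch.empty(m_sum, n, dtype=torch.bfloat16, device='cuda')
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scales), (rhs, rhs_scales), out, m_indices)
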
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                                           rhs: Tuple[torch.Tensor, torch.Tensor],
                                           out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
     """
-    Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
-    Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
-        should be separately transposed.
+    Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
+            this function will do a transposing with a set of slow PyTorch operations.
+        Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
+            should be separately transposed.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
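
A hedged usage sketch for the masked-grouped kernel. The `deep_gemm` import path, the RHS and scale shapes, the per-group valid-row-count semantics of `masked_m`, and the role of `expected_m` as a sizing hint are assumptions beyond what this hunk shows.

    import torch
    import deep_gemm  # assumed import path for these kernels

    num_groups, m_max, k, n = 4, 1024, 7168, 4096
    expected_m = 512  # assumed: a hint about the typical number of valid rows per group

    lhs = torch.randn(num_groups, m_max, k, device='cuda').to(torch.float8_e4m3fn)
    lhs_scales = torch.ones(num_groups, m_max, k // 128, dtype=torch.float32, device='cuda')
    rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
    rhs_scales = torch.ones(num_groups, n // 128, k // 128, dtype=torch.float32, device='cuda')

    # Assumed semantics: masked_m[g] is the number of valid rows in group g; rows beyond it are ignored.
    masked_m = torch.tensor([512, 384, 768, 256], dtype=torch.int32, device='cuda')

    out = torch.empty(num_groups, m_max, n, dtype=torch.bfloat16, device='cuda')
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((lhs, lhs_scales), (rhs, rhs_scales), out, masked_m, expected_m)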


@@ -15,18 +15,22 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                rhs: Tuple[torch.Tensor, torch.Tensor],
                                out: Tuple[torch.Tensor, torch.Tensor]):
     """
-    Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
-    LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
+    Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
+            this function will do a transposing with a set of slow PyTorch operations.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
              the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, k / 128]`.
         rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
              the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, k / 128]`.
-        out: the FP32 output tensor of shape `[m, n]`, representing the result.
+        out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
     """
     lhs, lhs_scales = lhs
     rhs, rhs_scales = rhs
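
A hedged sketch of the accumulation behaviour this commit documents. Note that the signature line in this hunk types `out` as a tuple while the Arguments section describes a single FP32 `[m, n]` tensor; the sketch follows the Arguments section, and the `deep_gemm` import path is assumed.

    import torch
    import deep_gemm  # assumed import path for these kernels

    m, n, k = 4096, 4096, 8192

    lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
    lhs_scales = torch.ones(m, k // 128, dtype=torch.float32, device='cuda')
    rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
    rhs_scales = torch.ones(n, k // 128, dtype=torch.float32, device='cuda')

    # Per the updated docstring, results are accumulated into `out`, so start from zeros
    # (or from a previous partial weight gradient when splitting the reduction over micro-batches).
    out = torch.zeros(m, n, dtype=torch.float32, device='cuda')
    deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
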
@@ -131,10 +135,13 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                          batch_sizes: List[int]):
     """
     Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
-    This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
-    Each batch's LHS, RHS, and output tensors must be contiguous.
-    The RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
+        This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
+        Each batch's LHS, RHS, and output tensors must be contiguous.
+        The RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
 
     Arguments:
         lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
@@ -145,7 +152,7 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
             the second element is an FP32 scaling tensor for RHS with shape `[k / 128 for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
-        out: The FP32 output tensor of shape [num_batches, m, n], representing the result.
+        out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
         batch_sizes: A list of integers specifying the k-dimension for each batch.
     """
     lhs, lhs_scales = lhs[0].view(-1), lhs[1]
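
A hedged sketch for the k-grouped weight-gradient kernel. The `deep_gemm` import path, the flattened LHS layout, and both scale layouts (mirrored from the RHS description in the docstring) are assumptions; only the RHS flattened shape, the `[num_batches, m, n]` output, and `batch_sizes` are stated in these hunks.

    import torch
    import deep_gemm  # assumed import path for these kernels

    m, n = 4096, 4096
    batch_sizes = [256, 384, 512]  # k-dimension of each batch, per the docstring
    num_batches = len(batch_sizes)

    # Flattened operands: one m-by-k (resp. n-by-k) slab per batch, concatenated into a 1-D tensor.
    lhs = torch.randn(sum(m * k for k in batch_sizes), device='cuda').to(torch.float8_e4m3fn)
    rhs = torch.randn(sum(n * k for k in batch_sizes), device='cuda').to(torch.float8_e4m3fn)

    # 1x128 scaling factors; the concatenated [sum(k / 128), m] / [sum(k / 128), n] layout is an
    # assumption based on the RHS scale description in the docstring.
    lhs_scales = torch.ones(sum(k // 128 for k in batch_sizes), m, dtype=torch.float32, device='cuda')
    rhs_scales = torch.ones(sum(k // 128 for k in batch_sizes), n, dtype=torch.float32, device='cuda')

    # Output is accumulated per the updated docstring, so initialize it to zeros.
    out = torch.zeros(num_batches, m, n, dtype=torch.float32, device='cuda')
    deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt((lhs, lhs_scales), (rhs, rhs_scales), out, batch_sizes)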