From 6233709c67d50ec4583d33c54dd3b1a4bec9b5f2 Mon Sep 17 00:00:00 2001
From: Zhean Xu
Date: Tue, 13 May 2025 15:30:32 +0800
Subject: [PATCH] Update docs

---
 deep_gemm/jit_kernels/gemm.py           | 13 ++++++----
 deep_gemm/jit_kernels/m_grouped_gemm.py | 32 ++++++++++++++-----------
 deep_gemm/jit_kernels/wgrad_gemm.py     | 29 +++++++++++++---------
 3 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index f515ab4..16cde20 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -157,11 +157,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                          rhs: Tuple[torch.Tensor, torch.Tensor],
                          out: torch.Tensor) -> None:
     """
-    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
+    Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match this requirement,
+        this function will transpose it with a set of slow PyTorch operations.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 24a2183..e8c1922 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -14,13 +14,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                                               rhs: Tuple[torch.Tensor, torch.Tensor],
                                               out: torch.Tensor, m_indices: torch.Tensor) -> None:
     """
-    Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
-    On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
-    `get_m_alignment_for_contiguous_layout()` (128).
+    Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match this requirement,
+        this function will transpose it with a set of slow PyTorch operations.
+        On the M axis, inputs are grouped into several batches, whose batch sizes must be aligned to
+        `get_m_alignment_for_contiguous_layout()` (128).
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -116,13 +118,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                                           rhs: Tuple[torch.Tensor, torch.Tensor],
                                           out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
     """
-    Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
-    Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
-    should be separately transposed.
+    Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+        LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match this requirement,
+        this function will transpose it with a set of slow PyTorch operations.
+        Moreover, this alignment requirement differs from the contiguous-format kernel's, as we require each batch
+        to be transposed separately.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py
index 157647f..4fad99a 100644
--- a/deep_gemm/jit_kernels/wgrad_gemm.py
+++ b/deep_gemm/jit_kernels/wgrad_gemm.py
@@ -15,18 +15,22 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                rhs: Tuple[torch.Tensor, torch.Tensor],
                                out: Tuple[torch.Tensor, torch.Tensor]):
     """
-    Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
-    LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
-    RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
+    Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
+        LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+        The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
+        RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format; if your input does not match this requirement,
+        this function will transpose them with a set of slow PyTorch operations.
 
     Arguments:
         lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
         rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
-        out: the FP32 output tensor of shape `[m, n]`, representing the result.
+        out: the FP32 output tensor of shape `[m, n]`, into which the result is accumulated.
     """
     lhs, lhs_scales = lhs
     rhs, rhs_scales = rhs
@@ -131,10 +135,13 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                          batch_sizes: List[int]):
     """
     Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
-    This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
-    Each batch's LHS, RHS, and output tensors must be contiguous.
-    The RHS and RHS scaling factors are required to be transposed.
-    The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
+    Results will be accumulated into the output tensor.
+    This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
+
+    Requirements:
+        Each batch's LHS, RHS, and output tensors must be contiguous.
+        The RHS and RHS scaling factors are required to be transposed.
+        The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
 
     Arguments:
         lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
@@ -145,7 +152,7 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
             the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
-        out: The FP32 output tensor of shape [num_batches, m, n], representing the result.
+        out: The FP32 output tensor of shape [num_batches, m, n], into which results are accumulated.
         batch_sizes: A list of integers specifying the k-dimension for each batch.
     """
     lhs, lhs_scales = lhs[0].view(-1), lhs[1]
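
Below the patch, a few usage sketches for the documented kernels. First, `gemm_fp8_fp8_bf16_nt`. This is a minimal sketch, not part of the patch: the two quantization helpers are illustrative stand-ins for the similar utilities in the repo's tests, they assume all shapes are multiples of 128, and the kernels are assumed to be exposed at the `deep_gemm` package top level.

    import torch
    import deep_gemm

    def per_token_cast_to_fp8(x: torch.Tensor):
        # 1x128 scaling: one FP32 scale per 128 contiguous elements along dim 1.
        m, k = x.shape
        v = x.view(m, k // 128, 128).float()
        amax = v.abs().amax(dim=2, keepdim=True).clamp(min=1e-4)
        return (v * (448.0 / amax)).to(torch.float8_e4m3fn).view(m, k), (amax / 448.0).squeeze(2)

    def per_block_cast_to_fp8(x: torch.Tensor):
        # 128x128 scaling: one FP32 scale per 128x128 block.
        m, k = x.shape
        v = x.view(m // 128, 128, k // 128, 128).float()
        amax = v.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
        return (v * (448.0 / amax)).to(torch.float8_e4m3fn).view(m, k), (amax / 448.0).view(m // 128, k // 128)

    m, k, n = 4096, 7168, 2048
    x = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
    y = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)
    out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
    # RHS is passed in its transposed [n, k] layout, as the docstring requires; naive
    # (non-TMA-aligned) LHS scales still work but trigger the slow PyTorch fallback.
    deep_gemm.gemm_fp8_fp8_bf16_nt(per_token_cast_to_fp8(x), per_block_cast_to_fp8(y), out)
    # out now holds x @ y.T in BF16.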
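
For `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous`, a sketch of the contiguous grouped layout, reusing the assumed helpers above. `m_indices[i]` selects the RHS group for row `i` of the flattened LHS, and each group's row count is already a multiple of the 128-row alignment:

    num_groups, m_per_group, k, n = 4, 256, 7168, 4096
    m_sum = num_groups * m_per_group
    x = torch.randn(m_sum, k, device='cuda', dtype=torch.bfloat16)
    w = torch.randn(num_groups, n, k, device='cuda', dtype=torch.bfloat16)

    # Quantize each group's [n, k] weight separately with 128x128 blocks.
    w_fp8 = torch.empty(num_groups, n, k, device='cuda', dtype=torch.float8_e4m3fn)
    w_scales = torch.empty(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float)
    for g in range(num_groups):
        w_fp8[g], w_scales[g] = per_block_cast_to_fp8(w[g])

    out = torch.empty(m_sum, n, device='cuda', dtype=torch.bfloat16)
    # Rows are pre-grouped, so each group's rows are contiguous and 128-aligned.
    m_indices = torch.arange(num_groups, device='cuda', dtype=torch.int32).repeat_interleave(m_per_group)
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(per_token_cast_to_fp8(x), (w_fp8, w_scales), out, m_indices)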
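
For `m_grouped_gemm_fp8_fp8_bf16_nt_masked`, a sketch of the masked layout (same assumed helpers; the `masked_m` values and `expected_m` hint are made up for illustration). Each batch is quantized on its own, matching the per-batch transposition requirement in the docstring:

    num_groups, m_max, k, n = 4, 512, 7168, 4096
    x = torch.randn(num_groups, m_max, k, device='cuda', dtype=torch.bfloat16)
    w = torch.randn(num_groups, n, k, device='cuda', dtype=torch.bfloat16)

    lhs_fp8 = torch.empty(num_groups, m_max, k, device='cuda', dtype=torch.float8_e4m3fn)
    lhs_scales = torch.empty(num_groups, m_max, k // 128, device='cuda', dtype=torch.float)
    rhs_fp8 = torch.empty(num_groups, n, k, device='cuda', dtype=torch.float8_e4m3fn)
    rhs_scales = torch.empty(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float)
    for g in range(num_groups):
        lhs_fp8[g], lhs_scales[g] = per_token_cast_to_fp8(x[g])
        rhs_fp8[g], rhs_scales[g] = per_block_cast_to_fp8(w[g])

    out = torch.empty(num_groups, m_max, n, device='cuda', dtype=torch.bfloat16)
    # Only the first masked_m[g] rows of group g are computed; expected_m is a CPU-side
    # hint of the typical valid row count, used to pick a kernel configuration.
    masked_m = torch.tensor([384, 128, 512, 256], device='cuda', dtype=torch.int32)
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        (lhs_fp8, lhs_scales), (rhs_fp8, rhs_scales), out, masked_m, expected_m=320)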
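
Finally, for `wgrad_gemm_fp8_fp8_fp32_nt`, a sketch of the accumulation contract the updated docs call out (same assumed helper; note both operands use 1x128 scaling here, unlike the forward GEMM):

    m, n, k_chunk, num_chunks = 7168, 4096, 8192, 2
    out = torch.zeros(m, n, device='cuda', dtype=torch.float)   # accumulated into, so start from zero
    for _ in range(num_chunks):
        dy = torch.randn(m, k_chunk, device='cuda', dtype=torch.bfloat16)
        act = torch.randn(n, k_chunk, device='cuda', dtype=torch.bfloat16)
        # Each call adds dy @ act.T to out, which is what makes chunked/split-K
        # weight-gradient accumulation possible.
        deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(per_token_cast_to_fp8(dy), per_token_cast_to_fp8(act), out)

The k-grouped variant, `k_grouped_wgrad_gemm_fp8_fp8_fp32_nt`, keeps the same accumulation semantics but takes all batches' data flattened into single tensors plus `batch_sizes`, writing into a `[num_batches, m, n]` output.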