Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

Update docs

parent 919f55be9c
commit 6233709c67
@@ -157,8 +157,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                          rhs: Tuple[torch.Tensor, torch.Tensor],
                          out: torch.Tensor) -> None:
     """
-    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
-    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+    Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
+    LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+    The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
     this function will do a transposing with a set of slow PyTorch operations.
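For context, a minimal call sketch matching the updated `gemm_fp8_fp8_bf16_nt` docstring above. Only the signature, dtypes, and scaling granularity come from this diff; the `deep_gemm` module name, the concrete sizes, and the exact scale-tensor shapes (`[m, ⌈k/128⌉]` for LHS, `[⌈n/128⌉, ⌈k/128⌉]` for RHS) are assumptions.

```python
# Hypothetical usage sketch for gemm_fp8_fp8_bf16_nt, based on the signature and
# requirements shown in the hunk above. Scale-tensor shapes are assumptions
# derived from the stated 1x128 (LHS) / 128x128 (RHS) scaling granularity.
import torch
import deep_gemm  # assumed module name

m, n, k = 128, 4096, 4096  # n and k assumed to be multiples of 128

# FP8 operands, contiguous in dimension 1 (stride(1) == 1) as required.
lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)  # "nt": RHS stored as [n, k]

lhs_scales = torch.ones(m, k // 128, device='cuda', dtype=torch.float32)         # 1x128 LHS scales
rhs_scales = torch.ones(n // 128, k // 128, device='cuda', dtype=torch.float32)  # 128x128 RHS scales

out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

# Per the docstring, a non-TMA-aligned LHS scale tensor triggers a slow
# PyTorch transpose fallback inside the function.
deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
```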
@@ -14,7 +14,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                                                rhs: Tuple[torch.Tensor, torch.Tensor],
                                                out: torch.Tensor, m_indices: torch.Tensor) -> None:
     """
-    Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+    Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
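A hedged sketch for the grouped (contiguous) variant. The signature comes from the hunk above; the `[num_groups, n, k]` RHS layout and the reading of `m_indices` as a per-row group index are assumptions about the grouped format, not statements from this diff.

```python
# Hypothetical sketch for m_grouped_gemm_fp8_fp8_bf16_nt_contiguous. The grouped
# tensor layouts and the meaning of m_indices below are assumptions.
import torch
import deep_gemm  # assumed module name

num_groups, m_per_group, n, k = 4, 128, 4096, 4096
m_sum = num_groups * m_per_group

# All groups' LHS rows stacked contiguously along dim 0, with 1x128 FP32 scales.
lhs = torch.randn(m_sum, k, device='cuda').to(torch.float8_e4m3fn)
lhs_scales = torch.ones(m_sum, k // 128, device='cuda', dtype=torch.float32)

# Assumed layout: one [n, k] FP8 weight per group, with 128x128 FP32 scales.
rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
rhs_scales = torch.ones(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float32)

out = torch.empty(m_sum, n, device='cuda', dtype=torch.bfloat16)

# Assumption: m_indices[i] names the group whose RHS multiplies LHS row i.
m_indices = torch.arange(num_groups, device='cuda', dtype=torch.int32).repeat_interleave(m_per_group)

deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scales), (rhs, rhs_scales), out, m_indices)
```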
@@ -116,7 +118,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                                            rhs: Tuple[torch.Tensor, torch.Tensor],
                                            out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
     """
-    Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+    Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+
+    Requirements:
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
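And a similarly hedged sketch for the masked variant; beyond the signature, the batched `[num_groups, max_m, ...]` layouts and the semantics of `masked_m` / `expected_m` are assumptions.

```python
# Hypothetical sketch for m_grouped_gemm_fp8_fp8_bf16_nt_masked. All layouts and
# parameter semantics other than the signature are assumptions.
import torch
import deep_gemm  # assumed module name

num_groups, max_m, n, k = 4, 256, 4096, 4096

lhs = torch.randn(num_groups, max_m, k, device='cuda').to(torch.float8_e4m3fn)
lhs_scales = torch.ones(num_groups, max_m, k // 128, device='cuda', dtype=torch.float32)
rhs = torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn)
rhs_scales = torch.ones(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float32)
out = torch.empty(num_groups, max_m, n, device='cuda', dtype=torch.bfloat16)

# Assumption: masked_m[g] is the number of valid rows in group g, and expected_m
# is an average-row-count hint used when selecting a kernel configuration.
masked_m = torch.full((num_groups,), 128, device='cuda', dtype=torch.int32)
expected_m = 128

deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((lhs, lhs_scales), (rhs, rhs_scales), out, masked_m, expected_m)
```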
@@ -15,8 +15,12 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                rhs: Tuple[torch.Tensor, torch.Tensor],
                                out: Tuple[torch.Tensor, torch.Tensor]):
     """
-    Do a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
     LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
+    The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
     this function will do a transposing with a set of slow PyTorch operations.
@@ -26,7 +30,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
         rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
             the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
-        out: the FP32 output tensor of shape `[m, n]`, representing the result.
+        out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
     """
     lhs, lhs_scales = lhs
     rhs, rhs_scales = rhs
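To make the wgrad semantics concrete, here is a pure-PyTorch reference of what the two hunks above describe: dequantize both FP8 operands with their 1x128 scales, multiply against the transposed RHS, and accumulate into the FP32 output. This is a semantic sketch under the stated shapes, not DeepGEMM's kernel.

```python
# Reference semantics for wgrad_gemm_fp8_fp8_fp32_nt per the docstring above:
# LHS [m, k] FP8 with [m, ceil(k/128)] FP32 scales, RHS [n, k] FP8 with
# [n, ceil(k/128)] FP32 scales, and an FP32 output [m, n] that is accumulated into.
import torch

def wgrad_reference(lhs: torch.Tensor, lhs_scales: torch.Tensor,
                    rhs: torch.Tensor, rhs_scales: torch.Tensor,
                    out: torch.Tensor) -> torch.Tensor:
    m, k = lhs.shape
    # Expand the per-1x128 scales back to per-element factors before multiplying.
    lhs_f32 = lhs.to(torch.float32) * lhs_scales.repeat_interleave(128, dim=1)[:, :k]
    rhs_f32 = rhs.to(torch.float32) * rhs_scales.repeat_interleave(128, dim=1)[:, :k]
    out += lhs_f32 @ rhs_f32.T  # "nt": RHS is stored [n, k]; results accumulate
    return out
```

Because the kernel accumulates, the output should be zero-initialized before the first call whenever a fresh result is wanted.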
@@ -131,6 +135,9 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                           batch_sizes: List[int]):
     """
     Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
+    Results will be accumulated into the output tensor.
+
+    Requirements:
     This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
     Each batch's LHS, RHS, and output tensors must be contiguous.
     The RHS and RHS scaling factors are required to be transposed.
@@ -145,7 +152,7 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
             the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
-        out: The FP32 output tensor of shape [num_batches, m, n], representing the result.
+        out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
         batch_sizes: A list of integers specifying the k-dimension for each batch.
     """
     lhs, lhs_scales = lhs[0].view(-1), lhs[1]
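Likewise, a pure-PyTorch reference for the k-grouped wgrad described above: the flattened LHS/RHS are split per batch (each with its own k), and each batch's product accumulates into `out[i]`. FP8 dequantization scaling is omitted for brevity, and the assumption that each flattened batch is a row-major `[m, k]` / `[n, k]` block goes beyond what the diff states.

```python
# Reference semantics for the k-grouped wgrad GEMM: per-batch [m, k_i] x [n, k_i]^T
# products accumulated into out[i]. Scaling is omitted; layouts are assumptions.
import torch
from typing import List

def k_grouped_wgrad_reference(lhs_flat: torch.Tensor, rhs_flat: torch.Tensor,
                              out: torch.Tensor, batch_sizes: List[int]) -> torch.Tensor:
    num_batches, m, n = out.shape
    lhs_off = rhs_off = 0
    for i, k in enumerate(batch_sizes):  # each batch has its own k-dimension
        lhs_i = lhs_flat[lhs_off:lhs_off + m * k].view(m, k).to(torch.float32)
        rhs_i = rhs_flat[rhs_off:rhs_off + n * k].view(n, k).to(torch.float32)
        out[i] += lhs_i @ rhs_i.T  # accumulated, per the updated docstring
        lhs_off += m * k
        rhs_off += n * k
    return out
```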