mirror of https://github.com/deepseek-ai/DeepGEMM, commit 4377c4dc57
@@ -51,7 +51,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
         out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
         m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
-            `m_indices[i]` records the group which the j-th row of the LHS belong to,
+            `m_indices[i]` records the group which the i-th row of the LHS belongs to,
             which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
             Values of `m_indices` in every-m-alignment-block must also be the same.
     """
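For context, the docstring above describes DeepGEMM's contiguous-layout grouped GEMM API. Below is a minimal sketch of how a call might be assembled. The function name, its `(lhs, rhs, out, m_indices)` arguments, and the RHS scale shape come from the hunk itself; the LHS scale shape `[m_sum, ⌈k / 128⌉]`, the all-ones dummy scales, the concrete sizes, and `deep_gemm.get_m_alignment_for_contiguous_layout()` are assumptions based on the rest of the library, not part of this commit.

# Minimal sketch, not from the commit: wiring up m_indices for the
# contiguous grouped GEMM described in the docstring above.
import torch
import deep_gemm

num_groups, m_per_group, n, k = 4, 256, 4096, 7168
alignment = deep_gemm.get_m_alignment_for_contiguous_layout()
assert m_per_group % alignment == 0  # keep every m-alignment block inside one group
m_sum = num_groups * m_per_group

# LHS: FP8 data plus FP32 scales; the [m_sum, k / 128] scale shape is an
# assumption (only the RHS scale shape appears in this hunk), and all-ones
# scales are dummies. Real code may also need
# deep_gemm.get_col_major_tma_aligned_tensor(...) on the LHS scales.
lhs = (torch.randn(m_sum, k, device='cuda').to(torch.float8_e4m3fn),
       torch.ones(m_sum, k // 128, device='cuda', dtype=torch.float))
# RHS: FP8 data plus FP32 128x128 scales of shape [num_groups, n / 128, k / 128],
# matching the docstring.
rhs = (torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn),
       torch.ones(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float))
out = torch.empty(m_sum, n, device='cuda', dtype=torch.bfloat16)

# m_indices[i] is the group of row i: constant within every alignment block,
# as the docstring requires, so the i-th row multiplies against rhs[m_indices[i]].
m_indices = torch.arange(num_groups, device='cuda', dtype=torch.int).repeat_interleave(m_per_group)

deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)

The per-block constancy requirement exists because the kernel tiles the M dimension in alignment-sized blocks, so each block can read a single group's RHS; keeping `m_per_group` a multiple of the alignment (as in the sketch) satisfies it trivially.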