From e7fff7ef0a40cdc7d57961b2f43244787e3f1b2d Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Thu, 13 Mar 2025 22:09:15 +0800
Subject: [PATCH] Update m_grouped_gemm.py

---
 deep_gemm/jit_kernels/m_grouped_gemm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 140e165..415fc67 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -51,7 +51,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
         out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
         m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
-            `m_indices[i]` records the group which the j-th row of the LHS belong to,
+            `m_indices[i]` records the group to which the i-th row of the LHS belongs,
             which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
             Values of `m_indices` in every-m-alignment-block must also be the same.
     """
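
Note (illustration only, not part of the patch): a minimal sketch of how an
`m_indices` tensor matching the corrected docstring might be built. The
per-group row counts `ms` and the 128-row m-alignment are assumptions for
the example, not values taken from this patch; in DeepGEMM the alignment
should be queried from the library rather than hard-coded.

    import torch

    # Hypothetical per-group row counts; each is assumed to be a multiple of
    # the m-alignment block (128 here), so that every alignment block of the
    # packed LHS contains rows from a single group, as the docstring requires.
    ms = [128, 256, 128]
    m_sum = sum(ms)

    # m_indices[i] == g means row i of the packed LHS belongs to group g,
    # i.e. row i will be multiplied with rhs[g].
    m_indices = torch.cat([
        torch.full((m,), g, dtype=torch.int)
        for g, m in enumerate(ms)
    ])
    assert m_indices.shape == (m_sum,)

Within each group, all rows share one index value, so every 128-row
alignment block is constant, satisfying the docstring's
every-m-alignment-block constraint.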