From ded740f736570c000b7063d7774f7c18b5d1ecd6 Mon Sep 17 00:00:00 2001 From: Liang <44948473+soundOfDestiny@users.noreply.github.com> Date: Tue, 4 Mar 2025 11:26:23 +0800 Subject: [PATCH] Fix documentation of `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` in m_grouped_gemm.py --- deep_gemm/jit_kernels/m_grouped_gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 28f838a..97fb636 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -54,7 +54,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten `m_indices[i]` records the group which the j-th row of the LHS belong to, which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. Values of `m_indices` in every-m-alignment-block must also be the same. - `-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block. """ lhs, lhs_scales = lhs rhs, rhs_scales = rhs