From b69f630b914c5c98f28f5ed8741d14e1952f2404 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 28 Feb 2025 09:46:38 +0800 Subject: [PATCH] Minor fix util function --- deep_gemm/jit_kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py index 703e1e2..c1a1557 100644 --- a/deep_gemm/jit_kernels/utils.py +++ b/deep_gemm/jit_kernels/utils.py @@ -102,4 +102,5 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: # Normal layout requires transposing aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x