diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py index 703e1e2..c1a1557 100644 --- a/deep_gemm/jit_kernels/utils.py +++ b/deep_gemm/jit_kernels/utils.py @@ -102,4 +102,5 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: # Normal layout requires transposing aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x