From 3f92607b981fe6ae8537dc9d7b2b9972751bc280 Mon Sep 17 00:00:00 2001
From: z-navy
Date: Thu, 13 Mar 2025 22:15:16 +0800
Subject: [PATCH] Fix get_col_major_tma_aligned_tensor to handle 2-dimensional inputs

---
 deep_gemm/jit_kernels/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index c1a1557..c6da56b 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
     assert x.dim() in (2, 3)
     remove_dim = False
+    m, n = x.shape[-2], x.shape[-1]
+    aligned_m = get_tma_aligned_size(m, x.element_size())
     if x.dim() == 2:
+        if x.stride(0) == 1 and x.stride(1) == aligned_m:
+            return x
         x, remove_dim = x.unsqueeze(0), True
 
-    b, m, n = x.shape
-    aligned_m = get_tma_aligned_size(m, x.element_size())
+    b = x.shape[0]
 
     # The last kernel gives a column-major TMA aligned layout
     if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: