diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py index c1a1557..c6da56b 100644 --- a/deep_gemm/jit_kernels/utils.py +++ b/deep_gemm/jit_kernels/utils.py @@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA assert x.dim() in (2, 3) remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x x, remove_dim = x.unsqueeze(0), True - b, m, n = x.shape - aligned_m = get_tma_aligned_size(m, x.element_size()) + b = x.shape[0] # The last kernel gives a column-major TMA aligned layout if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: