Fix get_col_major_tma_aligned_tensor to handle 2-dimensional inputs

This commit is contained in:
z-navy 2025-03-13 22:15:16 +08:00
parent bd2a775528
commit 3f92607b98

View File

@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b, m, n = x.shape
aligned_m = get_tma_aligned_size(m, x.element_size())
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: