Merge pull request #65 from Z-NAVY/main

Fix get_col_major_tma_aligned_tensor to handle 2-dimensional inputs
This commit is contained in:
Liang 2025-03-14 13:50:08 +08:00 committed by GitHub
commit e1c070fbef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b, m, n = x.shape
aligned_m = get_tma_aligned_size(m, x.element_size())
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: