Minor fix util function

This commit is contained in:
Chenggang Zhao 2025-02-28 09:46:38 +08:00
parent 6e10cba207
commit b69f630b91

View File

@ -102,4 +102,5 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
# Normal layout requires transposing
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x