mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-23 07:24:19 +00:00
Minor fix util function
This commit is contained in:
parent
6e10cba207
commit
b69f630b91
@ -102,4 +102,5 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
|||||||
# Normal layout requires transposing
|
# Normal layout requires transposing
|
||||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||||
aligned_x[:, :m, :] = x
|
aligned_x[:, :m, :] = x
|
||||||
|
aligned_x = aligned_x[:, :m, :]
|
||||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||||
|
Loading…
Reference in New Issue
Block a user