mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-10 13:35:27 +00:00
Fix get_col_major_tma_aligned_tensor to handle 2-dimensional inputs
This commit is contained in:
parent
bd2a775528
commit
3f92607b98
@@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.dim() == 2:
|
||||
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
b = x.shape[0]
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
|
Loading…
Reference in New Issue
Block a user