mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-02 00:03:59 +00:00
Merge pull request #65 from Z-NAVY/main
Fix get_col_major_tma_aligned_tensor to handle 2-dimensional inputs
This commit is contained in:
commit
e1c070fbef
@ -89,11 +89,14 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
|||||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||||
assert x.dim() in (2, 3)
|
assert x.dim() in (2, 3)
|
||||||
remove_dim = False
|
remove_dim = False
|
||||||
|
m, n = x.shape[-2], x.shape[-1]
|
||||||
|
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||||
if x.dim() == 2:
|
if x.dim() == 2:
|
||||||
|
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||||
|
return x
|
||||||
x, remove_dim = x.unsqueeze(0), True
|
x, remove_dim = x.unsqueeze(0), True
|
||||||
|
|
||||||
b, m, n = x.shape
|
b = x.shape[0]
|
||||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
|
||||||
|
|
||||||
# The last kernel gives a column-major TMA aligned layout
|
# The last kernel gives a column-major TMA aligned layout
|
||||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||||
|
Loading…
Reference in New Issue
Block a user