Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-21 22:24:13 +00:00
fix 2D path
Commit 5b9bfa6057, parent 5e4badc577
@@ -11,7 +11,11 @@ def set_num_sms(num_sms: int) -> None:
         num_sms: the desired maximum SM count for all GEMM kernels to use.
     """
     global _num_sms
-    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+    assert (
+        0
+        < num_sms
+        <= torch.cuda.get_device_properties(device="cuda").multi_processor_count
+    )
     _num_sms = num_sms
 
 
@@ -25,7 +29,7 @@ def get_num_sms() -> int:
     """
     global _num_sms
     if _num_sms is None:
-        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+        _num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
     return _num_sms
 
 
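The two hunks above only re-wrap long lines; the behaviour of set_num_sms / get_num_sms is unchanged: the setter bounds the value by the device's SM count, and the getter lazily falls back to that count when nothing has been set. A minimal usage sketch (assuming both helpers are exported at the deep_gemm package level, as the project README describes):

    import torch
    import deep_gemm

    # Upper bound enforced by the assert in set_num_sms.
    total_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count

    # Leave some SMs free for other kernels (e.g. communication) and cap the GEMMs.
    deep_gemm.set_num_sms(max(1, total_sms - 20))
    assert deep_gemm.get_num_sms() == max(1, total_sms - 20)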
@@ -90,6 +94,10 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     assert x.dim() in (2, 3)
     remove_dim = False
     if x.dim() == 2:
+        m, n = x.shape
+        aligned_m = get_tma_aligned_size(m, x.element_size())
+        if x.stride() == (1, aligned_m):
+            return x
         x, remove_dim = x.unsqueeze(0), True
 
     b, m, n = x.shape
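The four added lines are the actual "2D path" fix: a 2-D input whose strides already match the TMA-aligned column-major layout, (1, aligned_m), is now returned immediately instead of being unsqueezed to 3-D and routed through the padded-copy path. A standalone sketch of that layout check, assuming get_tma_aligned_size rounds the row count up so the column stride is a multiple of 16 bytes:

    import torch

    def tma_aligned_size(m: int, element_size: int, alignment_bytes: int = 16) -> int:
        # Assumed behaviour of deep_gemm's get_tma_aligned_size: round m up to a
        # multiple of (16 bytes / element size) elements.
        alignment = alignment_bytes // element_size
        return (m + alignment - 1) // alignment * alignment

    def is_col_major_tma_aligned_2d(x: torch.Tensor) -> bool:
        # Hypothetical helper mirroring the new early-return condition above.
        assert x.dim() == 2
        m, _ = x.shape
        return x.stride() == (1, tma_aligned_size(m, x.element_size()))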
@@ -100,7 +108,9 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
         return x.squeeze(0) if remove_dim else x
 
     # Normal layout requires transposing
-    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x = torch.transpose(
+        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2
+    )
     aligned_x[:, :m, :] = x
     aligned_x = aligned_x[:, :m, :]
     return aligned_x.squeeze(0) if remove_dim else aligned_x
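With the fix, passing an already-aligned 2-D tensor through get_col_major_tma_aligned_tensor is a no-op rather than an extra copy. A hedged end-to-end sketch (assumes the function is exported at the deep_gemm package level and that a CUDA device is available):

    import torch
    import deep_gemm

    m, n = 128, 7
    # Column-major 2-D tensor whose leading dimension is already 16-byte aligned:
    # for float32 the alignment is 4 elements and m = 128 is a multiple of 4,
    # so the expected stride is (1, m).
    x = torch.empty((n, m), device='cuda', dtype=torch.float32).t()
    assert x.stride() == (1, m)

    y = deep_gemm.get_col_major_tma_aligned_tensor(x)
    # The new 2-D early return hands back the same storage, no copy is made.
    assert y.data_ptr() == x.data_ptr()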