fix 2D path

Author: YLGH
Date:   2025-03-10 17:09:05 +00:00
Parent: 5e4badc577
Commit: 5b9bfa6057


@@ -11,7 +11,11 @@ def set_num_sms(num_sms: int) -> None:
         num_sms: the desired maximum SM count for all GEMM kernels to use.
     """
     global _num_sms
-    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+    assert (
+        0
+        < num_sms
+        <= torch.cuda.get_device_properties(device="cuda").multi_processor_count
+    )
     _num_sms = num_sms
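
This hunk only rewraps the range check into a parenthesized, Black-style assert; the accepted values are unchanged. For context, a minimal sketch of the check it enforces, assuming the setter is importable as below (the `deep_gemm` import path is an assumption):

    import torch
    from deep_gemm import set_num_sms  # assumed import path

    max_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
    set_num_sms(max_sms)          # fine: the upper bound is inclusive
    # set_num_sms(max_sms + 1)    # would raise AssertionError: out of range
    # set_num_sms(0)              # would also raise: the bound is strict below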
@@ -25,7 +29,7 @@ def get_num_sms() -> int:
     """
     global _num_sms
     if _num_sms is None:
-        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+        _num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
     return _num_sms
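
The change here is quote style only. As context for the pair of helpers, `get_num_sms` lazily initializes the global to the device's full SM count on first use; a short sketch under the same import assumption:

    import torch
    from deep_gemm import get_num_sms, set_num_sms  # assumed import path

    # First call with no override set: falls back to the device's SM count.
    assert get_num_sms() == torch.cuda.get_device_properties(device="cuda").multi_processor_count

    set_num_sms(64)  # subsequent GEMM kernels use at most 64 SMs
    assert get_num_sms() == 64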
@@ -90,6 +94,10 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     assert x.dim() in (2, 3)
     remove_dim = False
     if x.dim() == 2:
+        m, n = x.shape
+        aligned_m = get_tma_aligned_size(m, x.element_size())
+        if x.stride() == (1, aligned_m):
+            return x
         x, remove_dim = x.unsqueeze(0), True
     b, m, n = x.shape
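
This is the "2D path" fix named in the commit title: a 2-D tensor that is already column-major with a TMA-aligned leading dimension is now returned directly. Before this change, such a tensor was unsqueezed to (1, m, n), and the unsqueezed view could fail the batched stride check below and take the copy path unnecessarily. A hedged sketch of an input that now hits the no-copy early return (both import paths are assumptions):

    import torch
    from deep_gemm import get_col_major_tma_aligned_tensor  # assumed import path
    from deep_gemm.jit_kernels.utils import get_tma_aligned_size  # assumed import path

    m, n = 100, 64
    aligned_m = get_tma_aligned_size(m, 4)  # 4 = element size of float32 in bytes

    # Padded column-major storage, viewing only the top m rows: the view's
    # stride is (1, aligned_m), exactly what the new early return checks.
    buf = torch.empty((n, aligned_m), device="cuda", dtype=torch.float32).t()
    x = buf[:m, :]
    y = get_col_major_tma_aligned_tensor(x)
    assert y.data_ptr() == x.data_ptr()  # returned as-is, no copy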
@@ -100,7 +108,9 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
         return x.squeeze(0) if remove_dim else x
 
     # Normal layout requires transposing
-    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x = torch.transpose(
+        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2
+    )
     aligned_x[:, :m, :] = x
     aligned_x = aligned_x[:, :m, :]
     return aligned_x.squeeze(0) if remove_dim else aligned_x
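
The last hunk only rewraps the `torch.empty(...)` transpose; behavior is identical. One detail worth spelling out: narrowing `aligned_x` back to m rows after the copy changes the view's sizes but not its strides, so the result stays column-major with the TMA-aligned leading dimension. A small self-contained sketch with illustrative sizes:

    import torch

    b, m, n, aligned_m = 2, 100, 64, 128  # illustrative sizes; aligned_m >= m

    x = torch.randn((b, m, n))

    # Same construction as the diff: padded storage viewed column-major.
    aligned_x = torch.transpose(
        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2
    )
    aligned_x[:, :m, :] = x          # copy the payload into the top m rows
    aligned_x = aligned_x[:, :m, :]  # narrow the view; strides are unchanged
    assert aligned_x.shape == (b, m, n)
    assert aligned_x.stride() == (n * aligned_m, 1, aligned_m)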