mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-08 15:22:12 +00:00
fix 2D path
This commit is contained in:
parent
5e4badc577
commit
5b9bfa6057
@ -11,7 +11,11 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
assert (
|
||||
0
|
||||
< num_sms
|
||||
<= torch.cuda.get_device_properties(device="cuda").multi_processor_count
|
||||
)
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
@ -25,7 +29,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
_num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
@ -48,7 +52,7 @@ def get_m_alignment_for_contiguous_layout():
|
||||
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
|
||||
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
|
||||
with GEMM block shape.
|
||||
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
@ -90,6 +94,10 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
if x.dim() == 2:
|
||||
m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.stride() == (1, aligned_m):
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
@ -100,7 +108,9 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x = torch.transpose(
|
||||
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2
|
||||
)
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
|
Loading…
Reference in New Issue
Block a user