# Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-04-13)
import torch

_num_sms = None


def set_num_sms(num_sms: int) -> None:
    """
    Set the maximum SM count for all GEMM kernels to use.

    Arguments:
        num_sms: the desired maximum SM count for all GEMM kernels to use.
    """
    global _num_sms
    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
    _num_sms = num_sms


def get_num_sms() -> int:
    """
    Get the current maximum limit of SM count for all GEMM kernels to use.
    If the count is never specified, the function will return the number of device SMs.

    Returns:
        Current maximum limit of SM count for all GEMM kernels to use.
    """
    global _num_sms
    if _num_sms is None:
        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
    return _num_sms
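
# Example usage (illustrative, not part of the original file; assumes a CUDA
# device with at least 64 SMs):
#
#     set_num_sms(64)              # cap all GEMM kernels at 64 SMs
#     assert get_num_sms() == 64
#
# If set_num_sms() is never called, get_num_sms() falls back to the device's
# full multiprocessor count reported by torch.cuda.get_device_properties().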


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Arguments:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y
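
# Example (illustrative): ceil_div(10, 3) == 4, while ceil_div(9, 3) == 3.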


def get_m_alignment_for_contiguous_layout() -> int:
    """
    When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
    Since we deal with exactly one sub-matrix of the RHS for each GEMM block, the batch sizes above should
    align well with the GEMM block shape.

    Returns:
        Group-level alignment requirement for grouped contiguous layout, which is always 128.
    """
    return 128
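
# Example (illustrative sketch, not from the original file): each group's row
# count is typically padded up to this alignment before concatenating the groups
# along the M axis, e.g.
#
#     alignment = get_m_alignment_for_contiguous_layout()   # 128
#     padded_m = ceil_div(m, alignment) * alignment          # for some group size m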


def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory addresses used by TMA must be 16-byte aligned.
    Since we use a column-major layout for the LHS scaling tensor,
    the M axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size (in bytes) of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment
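
# Example (illustrative): for FP32 scales (element_size == 4), the 16-byte TMA
# alignment corresponds to 4 elements along M, so:
#
#     get_tma_aligned_size(1024, 4)   # -> 1024 (already a multiple of 4)
#     get_tma_aligned_size(1026, 4)   # -> 1028 (padded up to the next multiple of 4)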


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Return a TMA-aligned, transposed (column-major) format of the input tensor. `torch.transpose` will be
    called if necessary. If the input tensor is already in column-major layout and 16-byte aligned along
    the M axis (thus meeting the requirement of the LHS scaling tensor in DeepGEMM), this function will do
    nothing.

    Arguments:
        x: usually the LHS scaling tensor in GEMM.

    Returns:
        The LHS scaling tensor in TMA-aligned transposed format.
    """
    # NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dim() in (2, 3)
    remove_dim = False
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True

    b, m, n = x.shape
    aligned_m = get_tma_aligned_size(m, x.element_size())

    # The last kernel already gives a column-major, TMA-aligned layout
    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
        return x.squeeze(0) if remove_dim else x

    # Normal (row-major) layout requires transposing into a padded column-major buffer
    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x[:, :m, :] = x
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
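
# Example usage (illustrative sketch, not from the original file; shapes are
# arbitrary and a CUDA device is assumed):
#
#     # Row-major FP32 scaling factors for the LHS, e.g. one scale per 128-channel block.
#     lhs_scales = torch.rand(4096, 512, device='cuda', dtype=torch.float)
#     aligned = get_col_major_tma_aligned_tensor(lhs_scales)
#     assert aligned.shape == lhs_scales.shape
#     assert aligned.stride(0) == 1   # column-major: unit stride along the M axis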