import torch

_num_sms = None


def set_num_sms(num_sms: int) -> None:
    """
    Set the maximum SM count for all GEMM kernels to use.

    Arguments:
        num_sms: the desired maximum SM count for all GEMM kernels to use.
    """
    global _num_sms
    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
    _num_sms = num_sms


def get_num_sms() -> int:
    """
    Get the current maximum limit of SM count for all GEMM kernels to use.
    If the count has never been specified, the function returns the number of device SMs.

    Returns:
        Current maximum limit of SM count for all GEMM kernels to use.
    """
    global _num_sms
    if _num_sms is None:
        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
    return _num_sms


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Arguments:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y


def get_m_alignment_for_contiguous_layout() -> int:
    """
    When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
    Since each GEMM block deals with exactly one sub-matrix of the RHS, the batch sizes above must be aligned
        to the GEMM block shape.

    Returns:
        Group-level alignment requirement for the grouped contiguous layout, which is always 128.
    """
    return 128


def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory addresses accessed via TMA must be 16-byte aligned.
    Since we use a column-major layout for the LHS scaling tensor,
        its M axis needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size (in bytes) of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Return the input tensor in a TMA-aligned, transposed (column-major) format. `torch.transpose` will be
        called if necessary. If the input tensor is already in column-major layout and 16-byte aligned along
        the M axis (thus meeting the requirement of the LHS scaling tensor in DeepGEMM), this function does
        nothing.

    Arguments:
        x: usually the LHS scaling tensor in GEMM.

    Returns:
        The LHS scaling tensor in TMA-aligned, transposed format.
    """
    # NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dim() in (2, 3)
    remove_dim = False
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True

    b, m, n = x.shape
    aligned_m = get_tma_aligned_size(m, x.element_size())

    # The producing kernel may already have given a column-major, TMA-aligned layout
    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
        return x.squeeze(0) if remove_dim else x

    # A normal (row-major) layout requires transposing
    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x[:, :m, :] = x
    # Slicing after the copy keeps the padded M stride (`aligned_m`) in the result
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
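

# ---------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical `__main__` guard, not part of the
# module's public API): shows how the alignment helpers relate for an FP32
# LHS scaling tensor. Assumes a CUDA device for the tensor part.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # For 4-byte (FP32) elements, 16-byte TMA alignment pads M to a multiple of 16 // 4 = 4.
    assert ceil_div(9, 4) == 3
    assert get_tma_aligned_size(9, 4) == 12

    # The grouped contiguous layout requires group sizes to be multiples of 128 along M.
    assert get_m_alignment_for_contiguous_layout() == 128

    if torch.cuda.is_available():
        # Without an explicit `set_num_sms` call, the budget defaults to the device SM count.
        assert get_num_sms() == torch.cuda.get_device_properties(device='cuda').multi_processor_count

        # A (9, 7) FP32 scaling tensor in the default row-major layout.
        scales = torch.randn(9, 7, device='cuda', dtype=torch.float32)
        aligned = get_col_major_tma_aligned_tensor(scales)

        # Values are unchanged; the result is column-major with the M stride padded to 12.
        assert torch.equal(aligned, scales)
        assert aligned.stride() == (1, 12)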