From 0b5d353dba2ab84e02170c6afe6aeee95cfc3efb Mon Sep 17 00:00:00 2001 From: dxh Date: Fri, 7 Mar 2025 10:58:37 +0800 Subject: [PATCH] tensor alignment fix --- deep_gemm/jit_kernels/gemm.py | 77 +------------------------ deep_gemm/jit_kernels/m_grouped_gemm.py | 7 ++- deep_gemm/jit_kernels/utils.py | 31 +++++----- 3 files changed, 23 insertions(+), 92 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 1ba413a..9cce52f 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -2,6 +2,7 @@ import torch from typing import Tuple from .tuner import jit_tuner +from .config import config_cache from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout # C++ code templates @@ -30,80 +31,6 @@ GemmType::run(out, rhs_scales, nullptr, stream, num_sms, smem_size); """ - -def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: - if num_tma_multicast == 1: - return True - return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 - - -def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: - smem_d = block_m * block_n * 2 - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b = ceil_div(k, block_k) * 4 - smem_barrier = num_stages * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - return smem_size - - -def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]: - if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - block_ms = (64 if m <= 64 else 128, ) - else: - block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) - - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) - - # Decide block sizes by waves - best_block_m, best_block_n = None, None - for block_m in block_ms: - for block_n in block_ns: - success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) - if best_block_m is None or best_block_n is None: - success = True - elif num_waves < best_num_waves: - success = True - elif num_waves == best_num_waves: - # Check last wave utilization - util = get_last_wave_util(block_m, block_n) - best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) - assert best_block_m is not None and best_block_n is not None - - # Always pick the longest one - # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): - best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: - best_num_stages = num_stages - break - assert best_num_stages is not None - - # Decide the number of TMA multicast - best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 - - return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size - - def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor) -> None: @@ -151,7 +78,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) + block_m, block_n, num_stages, num_tma_multicast, smem_size = config_cache.compute_and_cache(m, n, k, 1, num_sms) args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 97fb636..ad3447a 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,8 +1,9 @@ import torch from typing import Tuple -from .gemm import get_best_configs + from .tuner import jit_tuner +from .config import config_cache from .utils import get_col_major_tma_aligned_tensor, get_num_sms # C++ code templates @@ -84,7 +85,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, + block_m, block_n, num_stages, num_tma_multicast, smem_size = config_cache.compute_and_cache(m, n, k, 1, num_sms, is_grouped_contiguous=True) args = (lhs, lhs_scales, rhs, rhs_scales, out, m_indices, m, num_groups, @@ -158,7 +159,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) + block_m, block_n, num_stages, num_tma_multicast, smem_size = config_cache.compute_and_cache(expected_m, n, k, num_groups, num_sms) # Extra checks for TMA store if num_groups > 1 and m > block_m: diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py index c1a1557..4701d9a 100644 --- a/deep_gemm/jit_kernels/utils.py +++ b/deep_gemm/jit_kernels/utils.py @@ -88,19 +88,22 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: """ # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA assert x.dim() in (2, 3) - remove_dim = False if x.dim() == 2: - x, remove_dim = x.unsqueeze(0), True + m, n = x.shape + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + aligned_x = torch.transpose(torch.empty((n, aligned_m), device=x.device), 0, 1) + aligned_x[:m, :] = x + aligned_x = aligned_x[:m, :] + return aligned_x + elif x.dim() == 3: + b, m, n = x.shape + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + return x + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device), 1,2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x - b, m, n = x.shape - aligned_m = get_tma_aligned_size(m, x.element_size()) - - # The last kernel gives a column-major TMA aligned layout - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: - return x.squeeze(0) if remove_dim else x - - # Normal layout requires transposing - aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) - aligned_x[:, :m, :] = x - aligned_x = aligned_x[:, :m, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x