From 919f55be9c77afdd593a0207c4bb42d7d7b99023 Mon Sep 17 00:00:00 2001
From: Zhean Xu
Date: Fri, 9 May 2025 12:55:46 +0800
Subject: [PATCH] Support unaligned n,k and gmem stride

---
 deep_gemm/jit_kernels/gemm.py | 19 ++++++++++++-------
 tests/test_core.py            |  9 ++++++---
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 64cda12..f515ab4 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -176,8 +176,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     n, k_ = rhs.shape
     m_, n_ = out.shape
 
-    assert n % 64 == 0 and k % 128 == 0
-
     # Type and shape checks
     assert m == m_ and n == n_ and k == k_
     assert n > 0 and k > 0
@@ -186,7 +184,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.bfloat16
-    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+    assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
+
+    lhs_stride = lhs.stride(0)
+    rhs_stride = rhs.stride(0)
+    out_stride = out.stride(0)
 
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
@@ -197,6 +199,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if m == 0:
         return
 
+    aligned_n = (n + 63) // 64 * 64
+    aligned_k = (k + 127) // 128 * 128
+
     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
@@ -206,11 +211,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     num_math_threads_per_group = 128
 
     tensor_map_a = make_2d_tma_a_desc(
-        GemmType.Normal, lhs, m, k, block_m, block_k, 1)
+        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
     tensor_map_b = make_2d_tma_b_desc(
-        GemmType.Normal, rhs, k, n, block_k, block_n, 1)
+        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
     tensor_map_d = make_2d_tma_d_desc(
-        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
+        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
     tensor_map_scales_a = make_2d_tma_scales_a_desc(
         GemmType.Normal, lhs_scales, m, k, block_m, block_k)
 
@@ -235,7 +240,7 @@
 
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'SWIZZLE_D_MODE': smem_config[1],
               'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
diff --git a/tests/test_core.py b/tests/test_core.py
index 052d42b..c45b511 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -13,11 +13,14 @@ from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2 and x.size(1) % 128 == 0
+    assert x.dim() == 2
     m, n = x.shape
+    pad_size = (128 - (n % 128)) % 128
+    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
     x_view = x.view(m, -1, 128)
     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
+    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -161,7 +164,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
 def test_gemm() -> None:
     print('Testing GEMM:')
     for m in (64, 128, 4096):
-        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
+        for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
            x_fp8, y_fp8, out, ref_out = construct(m, k, n)
            deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
            diff = calc_diff(out, ref_out)
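
Usage sketch (not part of the patch): the hunks above only relax the shape and contiguity checks and thread the row strides through to the TMA descriptors, so the effect is easiest to see from the caller's side. The snippet below is a minimal illustration assuming a Hopper GPU with DeepGEMM installed and the per_token_cast_to_fp8 / per_block_cast_to_fp8 helpers from tests/test_core.py importable as written; the k = 576 shape comes from the new test entry, while the strided output buffer (width 7680) and the import path are illustrative assumptions, and TMA may still restrict which byte strides are accepted.

import torch

import deep_gemm
# Assumed import path for this sketch; the helpers are defined in tests/test_core.py.
from tests.test_core import per_token_cast_to_fp8, per_block_cast_to_fp8

if __name__ == '__main__':
    torch.manual_seed(0)

    # k = 576 is not a multiple of 128; the removed `n % 64 == 0 and k % 128 == 0`
    # assertion would have rejected this shape before the patch.
    m, k, n = 128, 576, 7168
    x = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
    y = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)

    x_fp8 = per_token_cast_to_fp8(x)   # (FP8 data, per-row 1x128 scales); now pads k internally
    y_fp8 = per_block_cast_to_fp8(y)   # (FP8 data, per-128x128-block scales)

    # Write into a column slice of a wider buffer: out.stride(0) == 7680 != n, so the
    # old `out.is_contiguous()` check would have failed; only `stride(1) == 1` is
    # required now. (Hypothetical layout; TMA alignment rules still apply to the stride.)
    buffer = torch.empty(m, 7680, device='cuda', dtype=torch.bfloat16)
    out = buffer[:, :n]

    deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)

    ref = x.float() @ y.float().t()
    print('mean abs diff vs. BF16 reference:', (out.float() - ref).abs().mean().item())

Note that the JIT cache keys switch to the rounded-up aligned_n / aligned_k while the TMA descriptors receive the true n, k and the gmem strides, so unaligned shapes compile against the block-aligned template and the exact extents come from the descriptors.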