From a6ced6f2079003934c7993457ac1585859043de8 Mon Sep 17 00:00:00 2001
From: Zhean Xu
Date: Wed, 14 May 2025 14:41:39 +0800
Subject: [PATCH] Add stride(0) assertions

---
 deep_gemm/jit_kernels/gemm.py       | 3 +++
 deep_gemm/jit_kernels/wgrad_gemm.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index ce14171..c782f28 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -191,6 +191,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
+
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py
index 9eb0373..7dd5fc5 100644
--- a/deep_gemm/jit_kernels/wgrad_gemm.py
+++ b/deep_gemm/jit_kernels/wgrad_gemm.py
@@ -51,6 +51,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
+
     # LHS and RHS scales must be transposed for TMA load
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     if lhs_scales.shape == ((k + 127) // 128, m):
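
Note on the assertions above: PyTorch's `stride(0)` is counted in elements, not bytes,
while the requirement stated in the added comments is 16-byte alignment. The moduli
therefore differ by dtype: the FP8 LHS/RHS are 1 byte per element (% 16), the BF16
output of `gemm_fp8_fp8_bf16_nt` is 2 bytes per element (% 8), and the FP32 output of
`wgrad_gemm_fp8_fp8_fp32_nt` is 4 bytes per element (% 4). A minimal sketch of the
equivalent byte-level check follows; `assert_stride0_16b_aligned` is a hypothetical
helper written for illustration, not part of DeepGEMM.

import torch

def assert_stride0_16b_aligned(t: torch.Tensor) -> None:
    # Hypothetical helper: check that stride(0), measured in bytes, is a
    # multiple of 16. The patch above hardcodes the equivalent element-count
    # moduli per dtype instead (FP8 -> 16, BF16 -> 8, FP32 -> 4).
    stride_bytes = t.stride(0) * t.element_size()
    assert stride_bytes % 16 == 0, (
        f'stride(0) = {t.stride(0)} elements ({stride_bytes} bytes) '
        f'is not 16-byte aligned'
    )

# Example: a BF16 tensor with stride(0) == 8 elements (16 bytes) passes,
# while stride(0) == 4 elements (8 bytes) would trip the assertion.
out = torch.empty((128, 8), dtype=torch.bfloat16)
assert_stride0_16b_aligned(out)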