Add stride(0) assertions

This commit is contained in:
Zhean Xu 2025-05-14 14:41:39 +08:00
parent 279eb03190
commit a6ced6f207
2 changed files with 6 additions and 0 deletions

View File

@ -191,6 +191,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes (strides are in elements: 16 fp8 elements = 16 bytes, 8 bf16 elements = 16 bytes)
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)

View File

@ -51,6 +51,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes (strides are in elements: 16 fp8 elements = 16 bytes, 4 fp32 elements = 16 bytes)
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m):