Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
Add stride(0) assertions
commit a6ced6f207
parent 279eb03190
@@ -191,6 +191,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
+
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
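The two moduli encode the same 16-byte alignment requirement in element units: the FP8 operands use 1 byte per element (stride(0) % 16), while the BF16 output uses 2 bytes per element (stride(0) % 8). A minimal sketch of that arithmetic, assuming PyTorch >= 2.1 for `torch.dtype.itemsize`; `required_stride_alignment` is a hypothetical helper, not DeepGEMM API:

import torch

def required_stride_alignment(dtype: torch.dtype) -> int:
    # 16-byte alignment, converted from bytes to elements of `dtype`
    return 16 // dtype.itemsize

assert required_stride_alignment(torch.float8_e4m3fn) == 16  # FP8 LHS/RHS -> % 16
assert required_stride_alignment(torch.bfloat16) == 8        # BF16 output -> % 8
assert required_stride_alignment(torch.float32) == 4         # FP32 output -> % 4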
@@ -51,6 +51,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
+
     # LHS and RHS scales must be transposed for TMA load
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     if lhs_scales.shape == ((k + 127) // 128, m):
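Here the output is FP32 (4 bytes per element), so stride(0) % 4 again amounts to 16-byte alignment. For a caller whose row stride is not naturally aligned (e.g. an odd inner dimension), one way to satisfy the new assertions is to allocate a padded buffer and slice back to the logical width, so stride(0) keeps the padded value. A minimal sketch under that assumption; `pad_row_stride` is a hypothetical helper, not part of DeepGEMM, and the FP8 copy assumes a PyTorch build with FP8 dtype support:

import torch

def pad_row_stride(x: torch.Tensor, align_elems: int = 16) -> torch.Tensor:
    # Copy `x` into a buffer whose row length is a multiple of `align_elems`,
    # then slice back: the view keeps the logical shape but the padded stride(0).
    rows, cols = x.shape
    padded = (cols + align_elems - 1) // align_elems * align_elems
    if padded == cols:
        return x
    buf = x.new_zeros(rows, padded)
    buf[:, :cols] = x
    return buf[:, :cols]

lhs = torch.randn(128, 130).to(torch.float8_e4m3fn)  # stride(0) == 130, unaligned
lhs = pad_row_stride(lhs)                            # stride(0) == 144
assert lhs.stride(0) % 16 == 0                       # passes the commit's new check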