Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)
Add stride(0) assertions
commit a6ced6f207 (parent 279eb03190)
@@ -191,6 +191,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
+
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
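The moduli differ per tensor because stride(0) counts elements, not bytes: the FP8 operands are 1 byte per element (hence % 16) and the BF16 output is 2 bytes per element (hence % 8), so all three checks enforce the same 16-byte row alignment. A minimal byte-level sketch of the equivalent condition, using only the standard torch tensor API (the helper name `row_stride_16b_aligned` is our own, and the FP8 dtype requires a PyTorch build that ships it):

import torch

def row_stride_16b_aligned(t: torch.Tensor) -> bool:
    # stride(0) is measured in elements; element_size() converts it to bytes
    return (t.stride(0) * t.element_size()) % 16 == 0

lhs = torch.empty(128, 7168, dtype=torch.float8_e4m3fn)  # 1 B/elem -> stride(0) % 16 == 0
out = torch.empty(128, 4096, dtype=torch.bfloat16)       # 2 B/elem -> stride(0) % 8 == 0
assert row_stride_16b_aligned(lhs) and row_stride_16b_aligned(out)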
@@ -51,6 +51,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs_stride = rhs.stride(0)
     out_stride = out.stride(0)
 
+    # The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
+    assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
+
     # LHS and RHS scales must be transposed for TMA load
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     if lhs_scales.shape == ((k + 127) // 128, m):
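The wgrad variant applies the same 16-byte budget, with the divisor adjusted for the FP32 output (4 elements × 4 bytes = 16 bytes). For a contiguous tensor, stride(0) is simply the row length, so a row length that is not a multiple of the required element count fails the check; a small illustration (shapes are arbitrary):

import torch

ok  = torch.empty(64, 4096, dtype=torch.float32)  # stride(0) = 4096 -> 16384 B per row, aligned
bad = torch.empty(64, 4095, dtype=torch.float32)  # stride(0) = 4095 -> not 16-byte aligned
assert ok.stride(0) % 4 == 0
assert bad.stride(0) % 4 != 0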