Unify ceil_divs

This commit is contained in:
Chenggang Zhao
2025-05-15 16:48:32 +08:00
parent 4373af2e82
commit 350989eef3
4 changed files with 33 additions and 35 deletions

View File

@@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
@@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation
num_sms = get_num_sms()