Unify ceil_divs

This commit is contained in:
Chenggang Zhao
2025-05-15 16:48:32 +08:00
parent 4373af2e82
commit 350989eef3
4 changed files with 33 additions and 35 deletions

View File

@@ -7,7 +7,7 @@ from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Type and shape checks
assert m == m_ == m__ and k == k_ and n == n_
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
@@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16