mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Unify ceil_divs
This commit is contained in:
@@ -7,7 +7,7 @@ from .runtime import (
|
||||
FP8GemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ == m__ and k == k_ and n == n_
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
@@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
|
||||
Reference in New Issue
Block a user