mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Unify ceil_divs
This commit is contained in:
@@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
@@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
return
|
||||
|
||||
# K must be aligned to 128
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
aligned_k = ceil_div(k, 128) * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
|
||||
Reference in New Issue
Block a user