Unify ceil_divs

This commit is contained in:
Chenggang Zhao 2025-05-15 16:48:32 +08:00
parent 4373af2e82
commit 350989eef3
4 changed files with 33 additions and 35 deletions

View File

@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation
num_sms = get_num_sms()

View File

@ -7,7 +7,7 @@ from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Type and shape checks
assert m == m_ == m__ and k == k_ and n == n_
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16

View File

@ -7,7 +7,7 @@ from .runtime import (
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .gemm import get_best_configs
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -40,41 +40,39 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and m > 0
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m):
lhs_scales = lhs_scales.permute(1, 0)
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
else:
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert lhs_scales.stride(0) == 1
if rhs_scales.shape == ((k + 127) // 128, n):
rhs_scales = rhs_scales.permute(1, 0)
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
else:
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
assert rhs_scales.stride(0) == 1
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
def get_valid_scales(scales: torch.Tensor, mn: int):
if scales.shape == (ceil_div(k, 128), mn):
# For k-grouped GEMMs
scales = scales.permute(1, 0)
assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
else:
scales = get_col_major_tma_aligned_tensor(scales)
return scales
lhs_scales = get_valid_scales(lhs_scales, m)
rhs_scales = get_valid_scales(rhs_scales, n)
# Do nothing if `k` is zero
if k == 0:
return
# K must be aligned to 128
aligned_k = (k + 127) // 128 * 128
aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
num_last_stages = (k + 127) // 128 % num_stages
num_last_stages = ceil_div(k, 128) % num_stages
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
@ -151,10 +149,10 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
k = batch_sizes[i]
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
lhs_offset += m * k
rhs_offset += n * k
scales_offset += (k + 127) // 128
scales_offset += ceil_div(k, 128)

View File

@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
num_scales = (k + 127) // 128
num_scales = ceil_div(k, 128)
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)