Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

Unify ceil_divs

Commit 350989eef3
Parent 4373af2e82
@@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Type and shape checks
     assert m == m_ and n == n_ and k == k_
     assert n > 0 and k > 0
-    assert lhs_scales.shape == (m, (k + 127) // 128)
-    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+    assert lhs_scales.shape == (m, ceil_div(k, 128))
+    assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.bfloat16
     assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

     # LHS scales must be transposed for TMA loads, but not for RHS scales
-    # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
+    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
     assert rhs_scales.is_contiguous()
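For reference, the `(k + 127) // 128` expressions removed throughout this diff are exactly ceiling division by 128, so the shared `ceil_div` helper being adopted can be sketched as below (a minimal sketch; the actual definition in deep_gemm's utils module may differ in details):

def ceil_div(x: int, y: int) -> int:
    # Ceiling of x / y for positive integers, e.g. ceil_div(300, 128) == 3.
    return (x + y - 1) // y

With y = 128 this reduces to (x + 127) // 128, matching every call site touched by this commit.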
@@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         return

     # K must be aligned to 128
-    aligned_k = (k + 127) // 128 * 128
+    aligned_k = ceil_div(k, 128) * 128

     # Auto-tuning with compilation
     num_sms = get_num_sms()
@@ -7,7 +7,7 @@ from .runtime import (
     FP8GemmRuntime, GemmType,
     make_2d_tma_a_desc, make_2d_tma_b_desc,
     make_2d_tma_d_desc, make_2d_tma_scales_desc)
-from .utils import get_col_major_tma_aligned_tensor, get_num_sms
+from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms


 def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten

     # Type and shape checks
     assert m == m_ == m__ and k == k_ and n == n_
-    assert lhs_scales.shape == (m, (k + 127) // 128)
-    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+    assert lhs_scales.shape == (m, ceil_div(k, 128))
+    assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.bfloat16
@@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
     assert num_groups == num_groups_ == num_groups__ == num_groups___
     assert m == m_ and n == n_ and k == k_
     assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
-    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+    assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
+    assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.bfloat16
@@ -7,7 +7,7 @@ from .runtime import (
     make_2d_tma_a_desc, make_2d_tma_b_desc,
     make_2d_tma_d_desc, make_2d_tma_scales_desc)
 from .gemm import get_best_configs
-from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
+from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size


 def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -40,41 +40,39 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Type and shape checks
     assert m == m_ and n == n_ and k == k_
     assert n > 0 and m > 0
-    assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
-    assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
+    assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
+    assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.float
     assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

     # LHS and RHS scales must be transposed for TMA load
-    # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
-    if lhs_scales.shape == ((k + 127) // 128, m):
-        lhs_scales = lhs_scales.permute(1, 0)
-        assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
-    else:
-        lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-        assert lhs_scales.stride(0) == 1
-
-    if rhs_scales.shape == ((k + 127) // 128, n):
-        rhs_scales = rhs_scales.permute(1, 0)
-        assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
-    else:
-        rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
-        assert rhs_scales.stride(0) == 1
+    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
+    def get_valid_scales(scales: torch.Tensor, mn: int):
+        if scales.shape == (ceil_div(k, 128), mn):
+            # For k-grouped GEMMs
+            scales = scales.permute(1, 0)
+            assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
+        else:
+            scales = get_col_major_tma_aligned_tensor(scales)
+        return scales
+
+    lhs_scales = get_valid_scales(lhs_scales, m)
+    rhs_scales = get_valid_scales(rhs_scales, n)

     # Do nothing if `k` is zero
     if k == 0:
         return

     # K must be aligned to 128
-    aligned_k = (k + 127) // 128 * 128
+    aligned_k = ceil_div(k, 128) * 128

     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
         m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
-    num_last_stages = (k + 127) // 128 % num_stages
+    num_last_stages = ceil_div(k, 128) % num_stages
     block_k = 128
     num_tma_threads = 128
     num_math_threads_per_group = 128
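To illustrate the layout handling that the new `get_valid_scales` closure performs, here is a standalone sketch with hypothetical shapes; it mimics only the branch logic and does not call the real TMA helpers:

import torch

def normalize_scales_sketch(scales: torch.Tensor, mn: int, k: int) -> torch.Tensor:
    # Accept scales as either (mn, num_k_blocks) or (num_k_blocks, mn) and
    # return an (mn, num_k_blocks) view; the transposed input (the k-grouped
    # path) becomes a non-contiguous view whose stride(1) equals mn.
    num_k_blocks = (k + 127) // 128
    if scales.shape == (num_k_blocks, mn):
        return scales.permute(1, 0)
    return scales

# Hypothetical usage: k = 384 gives 3 scale rows per column of 256 outputs.
s = torch.randn(3, 256)
assert normalize_scales_sketch(s, mn=256, k=384).shape == (256, 3)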
@@ -151,10 +149,10 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
         k = batch_sizes[i]
         lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
         rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
-        lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
-        rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
+        lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
+        rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
         wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])

         lhs_offset += m * k
         rhs_offset += n * k
-        scales_offset += (k + 127) // 128
+        scales_offset += ceil_div(k, 128)
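The slicing above walks flat LHS/RHS/scale buffers group by group; a minimal sketch of that offset arithmetic (a hypothetical helper, independent of the real kernel launch):

from typing import Iterator, List, Tuple

def k_grouped_offsets(m: int, n: int, batch_sizes: List[int]) -> Iterator[Tuple[int, int, int]]:
    # Yield (lhs_offset, rhs_offset, scales_offset) per group: LHS/RHS advance
    # by m*k and n*k elements, scales by ceil_div(k, 128) rows.
    lhs_offset = rhs_offset = scales_offset = 0
    for k in batch_sizes:
        yield lhs_offset, rhs_offset, scales_offset
        lhs_offset += m * k
        rhs_offset += n * k
        scales_offset += (k + 127) // 128

# Example: batch_sizes = [300, 128] -> scale offsets 0 and 3, since ceil_div(300, 128) == 3.
print(list(k_grouped_offsets(2, 2, [300, 128])))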
@@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:

     assert m % 4 == 0, f'TMA alignment error: {m}'
     x_fp8 = per_token_cast_to_fp8(x)
-    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
+    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
     for i in range(num_groups):
         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
@@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \

     assert m % 4 == 0, f'TMA alignment error: {m}'
     x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
-    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
+    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
     for i in range(num_groups):
         x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
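A quick worked example of the per-block scale shape used by these test constructors, with hypothetical sizes not taken from the test file:

# With n = 4096 and k = 7168 (both multiples of 128), each group's RHS scale
# tensor is (ceil_div(4096, 128), 7168 // 128) == (32, 56): one float32 scale
# per 128x128 block of y.
assert ((4096 + 127) // 128, 7168 // 128) == (32, 56)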
@@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
     x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
     y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)

-    total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
+    total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
     x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
     y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
@@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
         x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
         y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())

-        num_scales = (k + 127) // 128
+        num_scales = ceil_div(k, 128)
         x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
        y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)