mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Unify ceil_divs
This commit is contained in:
@@ -179,15 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert n > 0 and k > 0
|
assert n > 0 and k > 0
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||||
|
|
||||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||||
assert rhs_scales.is_contiguous()
|
assert rhs_scales.is_contiguous()
|
||||||
|
|
||||||
@@ -196,7 +196,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
return
|
return
|
||||||
|
|
||||||
# K must be aligned to 128
|
# K must be aligned to 128
|
||||||
aligned_k = (k + 127) // 128 * 128
|
aligned_k = ceil_div(k, 128) * 128
|
||||||
|
|
||||||
# Auto-tuning with compilation
|
# Auto-tuning with compilation
|
||||||
num_sms = get_num_sms()
|
num_sms = get_num_sms()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from .runtime import (
|
|||||||
FP8GemmRuntime, GemmType,
|
FP8GemmRuntime, GemmType,
|
||||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
|
||||||
|
|
||||||
|
|
||||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
|||||||
|
|
||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ == m__ and k == k_ and n == n_
|
assert m == m_ == m__ and k == k_ and n == n_
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
@@ -142,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
|||||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from .runtime import (
|
|||||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||||
from .gemm import get_best_configs
|
from .gemm import get_best_configs
|
||||||
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
||||||
|
|
||||||
|
|
||||||
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -40,41 +40,39 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert n > 0 and m > 0
|
assert n > 0 and m > 0
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
|
assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
|
||||||
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
|
assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.float
|
assert out.dtype == torch.float
|
||||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||||
|
|
||||||
# LHS and RHS scales must be transposed for TMA load
|
# LHS and RHS scales must be transposed for TMA load
|
||||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||||
if lhs_scales.shape == ((k + 127) // 128, m):
|
def get_valid_scales(scales: torch.Tensor, mn: int):
|
||||||
lhs_scales = lhs_scales.permute(1, 0)
|
if scales.shape == (ceil_div(k, 128), mn):
|
||||||
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
|
# For k-grouped GEMMs
|
||||||
else:
|
scales = scales.permute(1, 0)
|
||||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
|
||||||
assert lhs_scales.stride(0) == 1
|
else:
|
||||||
|
scales = get_col_major_tma_aligned_tensor(scales)
|
||||||
|
return scales
|
||||||
|
|
||||||
if rhs_scales.shape == ((k + 127) // 128, n):
|
lhs_scales = get_valid_scales(lhs_scales, m)
|
||||||
rhs_scales = rhs_scales.permute(1, 0)
|
rhs_scales = get_valid_scales(rhs_scales, n)
|
||||||
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
|
|
||||||
else:
|
|
||||||
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
|
|
||||||
assert rhs_scales.stride(0) == 1
|
|
||||||
|
|
||||||
# Do nothing if `k` is zero
|
# Do nothing if `k` is zero
|
||||||
if k == 0:
|
if k == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# K must be aligned to 128
|
# K must be aligned to 128
|
||||||
aligned_k = (k + 127) // 128 * 128
|
aligned_k = ceil_div(k, 128) * 128
|
||||||
|
|
||||||
# Auto-tuning with compilation
|
# Auto-tuning with compilation
|
||||||
num_sms = get_num_sms()
|
num_sms = get_num_sms()
|
||||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||||
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||||
num_last_stages = (k + 127) // 128 % num_stages
|
num_last_stages = ceil_div(k, 128) % num_stages
|
||||||
block_k = 128
|
block_k = 128
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
@@ -151,10 +149,10 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
k = batch_sizes[i]
|
k = batch_sizes[i]
|
||||||
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
||||||
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
||||||
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||||
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||||
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
|
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
|
||||||
|
|
||||||
lhs_offset += m * k
|
lhs_offset += m * k
|
||||||
rhs_offset += n * k
|
rhs_offset += n * k
|
||||||
scales_offset += (k + 127) // 128
|
scales_offset += ceil_div(k, 128)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = per_token_cast_to_fp8(x)
|
x_fp8 = per_token_cast_to_fp8(x)
|
||||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
|||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||||
@@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
|||||||
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
||||||
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
|
total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
|
||||||
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
||||||
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
||||||
|
|
||||||
@@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
|||||||
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
||||||
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
||||||
|
|
||||||
num_scales = (k + 127) // 128
|
num_scales = ceil_div(k, 128)
|
||||||
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
||||||
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user