Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
Support unaligned n,k and gmem stride
This commit is contained in:
  parent adf5de0244
  commit 919f55be9c
@@ -176,8 +176,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     n, k_ = rhs.shape
     m_, n_ = out.shape
 
-    assert n % 64 == 0 and k % 128 == 0
-
     # Type and shape checks
     assert m == m_ and n == n_ and k == k_
     assert n > 0 and k > 0
@@ -186,7 +184,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
     assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
     assert out.dtype == torch.bfloat16
-    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+    assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
 
+    lhs_stride = lhs.stride(0)
+    rhs_stride = rhs.stride(0)
+    out_stride = out.stride(0)
+
     # LHS scales must be transposed for TMA loads, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
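Note on the relaxed check above: only the innermost dimension now has to be unit-stride, so row-strided views pass. A minimal illustration with plain torch tensors (shapes are arbitrary):

    import torch

    t = torch.empty(8, 16)[:, :12]   # column slice: rows sit 16 elements apart
    print(t.is_contiguous())         # False -> the old assert would have rejected it
    print(t.stride(1) == 1)          # True  -> the new assert accepts it
    print(t.stride(0))               # 16, the gmem row stride passed on below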
@@ -197,6 +199,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if m == 0:
         return
 
+    aligned_n = (n + 63) // 64 * 64
+    aligned_k = (k + 127) // 128 * 128
+
     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
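The two added lines round n and k up to the block granularities the kernel is specialized for; a small standalone check of the ceil-to-multiple idiom (the `align` helper and the sample values are only illustrative):

    def align(x: int, a: int) -> int:
        # Same idiom as above: round x up to the next multiple of a.
        return (x + a - 1) // a * a

    assert align(576, 128) == 640     # k = 576 from the new test case compiles as K = 640
    assert align(7168, 64) == 7168    # already-aligned sizes are unchanged
    assert align(2111, 64) == 2112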
@@ -206,11 +211,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     num_math_threads_per_group = 128
 
     tensor_map_a = make_2d_tma_a_desc(
-        GemmType.Normal, lhs, m, k, block_m, block_k, 1)
+        GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
     tensor_map_b = make_2d_tma_b_desc(
-        GemmType.Normal, rhs, k, n, block_k, block_n, 1)
+        GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
     tensor_map_d = make_2d_tma_d_desc(
-        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
+        GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
     tensor_map_scales_a = make_2d_tma_scales_a_desc(
         GemmType.Normal, lhs_scales, m, k, block_m, block_k)
 
@@ -235,7 +240,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'SWIZZLE_D_MODE': smem_config[1],
               'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
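Taken together, the wrapper no longer demands n % 64 == 0, k % 128 == 0, or fully contiguous tensors. A minimal usage sketch under those assumptions (requires a Hopper GPU with DeepGEMM installed; `construct` and `calc_diff` are the test helpers from the hunks below, the shapes follow the new test case, and the tolerance is illustrative):

    import torch
    import deep_gemm

    m, k, n = 128, 576, 7168                        # k = 576 is not a multiple of 128
    x_fp8, y_fp8, out, ref_out = construct(m, k, n)

    # Write into a row-strided view of a wider bf16 buffer instead of `out`:
    wide = torch.empty(m, n + 64, device='cuda', dtype=torch.bfloat16)
    strided_out = wide[:, :n]                       # stride(0) == n + 64, stride(1) == 1
    deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, strided_out)
    assert calc_diff(strided_out, ref_out) < 1e-3   # illustrative tolerance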
@@ -13,11 +13,14 @@ from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2 and x.size(1) % 128 == 0
+    assert x.dim() == 2
     m, n = x.shape
+    pad_size = (128 - (n % 128)) % 128
+    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
     x_view = x.view(m, -1, 128)
     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
+    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
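A quick sanity sketch of the new padding path in the helper above (the shape is illustrative; float8 casting on CPU depends on the PyTorch build, so a CUDA tensor may be needed):

    import torch

    x = torch.randn(4, 300)                # 300 % 128 != 0 -> pad_size == 84
    fp8, scales = per_token_cast_to_fp8(x)
    assert fp8.shape == (4, 300)           # padding is stripped before returning
    assert scales.shape == (4, 3)          # one scale per 128-wide group, incl. the padded tail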
@@ -161,7 +164,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
 def test_gemm() -> None:
     print('Testing GEMM:')
     for m in (64, 128, 4096):
-        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
+        for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
             x_fp8, y_fp8, out, ref_out = construct(m, k, n)
             deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
             diff = calc_diff(out, ref_out)