Support unaligned n,k and gmem stride

This commit is contained in:
Zhean Xu 2025-05-09 12:55:46 +08:00
parent adf5de0244
commit 919f55be9c
2 changed files with 18 additions and 10 deletions

View File

@ -176,8 +176,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape n, k_ = rhs.shape
m_, n_ = out.shape m_, n_ = out.shape
assert n % 64 == 0 and k % 128 == 0
# Type and shape checks # Type and shape checks
assert m == m_ and n == n_ and k == k_ assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0 assert n > 0 and k > 0
@ -186,7 +184,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16 assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# LHS scales must be transposed for TMA loads, but not for RHS scales # LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
@ -197,6 +199,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
if m == 0: if m == 0:
return return
aligned_n = (n + 63) // 64 * 64
aligned_k = (k + 127) // 128 * 128
# Auto-tuning with compilation # Auto-tuning with compilation
num_sms = get_num_sms() num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
@ -206,11 +211,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_math_threads_per_group = 128 num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc( tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k, 1) GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
tensor_map_b = make_2d_tma_b_desc( tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n, 1) GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
tensor_map_d = make_2d_tma_d_desc( tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1]) GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc( tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k) GemmType.Normal, lhs_scales, m, k, block_m, block_k)
@ -235,7 +240,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
runtime, best_keys = jit_tuner.compile_and_tune( runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt', name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1], 'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2], 'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages, 'NUM_STAGES': num_stages,

View File

@ -13,11 +13,14 @@ from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0 assert x.dim() == 2
m, n = x.shape m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128) x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -161,7 +164,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
def test_gemm() -> None: def test_gemm() -> None:
print('Testing GEMM:') print('Testing GEMM:')
for m in (64, 128, 4096): for m in (64, 128, 4096):
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n) x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out) diff = calc_diff(out, ref_out)