mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Support unaligned n,k and gmem stride
This commit is contained in:
@@ -176,8 +176,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
@@ -186,7 +184,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
lhs_stride = lhs.stride(0)
|
||||
rhs_stride = rhs.stride(0)
|
||||
out_stride = out.stride(0)
|
||||
|
||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
@@ -197,6 +199,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
aligned_n = (n + 63) // 64 * 64
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
@@ -206,11 +211,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
|
||||
@@ -235,7 +240,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_STAGES': num_stages,
|
||||
|
||||
Reference in New Issue
Block a user