Some lints and refactor

This commit is contained in:
Chenggang Zhao
2025-05-06 17:23:35 +08:00
parent 8aff6309d4
commit 981cc58932
18 changed files with 421 additions and 449 deletions

View File

@@ -3,8 +3,8 @@ import torch
from functools import lru_cache
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -122,7 +122,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
assert best_smem_config is not None
assert best_num_stages is not None
# Decide the number of TMA multicast and whether broadcast on A
# Decide the number of TMA multicasts and whether broadcast on A
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
@@ -155,13 +155,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
@@ -183,7 +183,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
@@ -201,11 +201,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k)
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n)
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
@@ -237,7 +237,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel