Refactor runtime

This commit is contained in:
Chenggang Zhao
2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -3,8 +3,10 @@ import torch
from functools import lru_cache
from typing import Tuple
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)