This commit is contained in:
AcraeaTerpsicore
2025-02-26 18:37:22 +08:00
committed by GitHub
parent bc989405fe
commit 96b31fd6bb
8 changed files with 20 additions and 20 deletions

View File

@@ -4,7 +4,7 @@ from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .utils import (
cell_div, set_num_sms, get_num_sms,
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout
)

View File

@@ -2,7 +2,7 @@ import torch
from typing import Tuple
from .tuner import jit_tuner
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
@@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = cell_div(k, block_k) * 4
smem_scales_b = ceil_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
@@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ns = tuple(range(16, 129, 8))
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
# Decide block sizes by waves
best_block_m, best_block_n = None, None

View File

@@ -29,7 +29,7 @@ def get_num_sms() -> int:
return _num_sms
def cell_div(x: int, y: int) -> int:
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
@@ -71,7 +71,7 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return cell_div(x, alignment) * alignment
return ceil_div(x, alignment) * alignment
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: