This commit is contained in:
AcraeaTerpsicore
2025-02-26 18:37:22 +08:00
committed by GitHub
parent bc989405fe
commit 96b31fd6bb
8 changed files with 20 additions and 20 deletions

View File

@@ -2,7 +2,7 @@ import torch
from typing import Tuple
from .tuner import jit_tuner
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
@@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = cell_div(k, block_k) * 4
smem_scales_b = ceil_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
@@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ns = tuple(range(16, 129, 8))
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
# Decide block sizes by waves
best_block_m, best_block_n = None, None