Merge pull request #16 from AcraeaTerpsicore/patch-1

Fix typos
Chenggang Zhao 2025-02-27 10:34:12 +08:00 committed by GitHub
commit d5b974da2b
8 changed files with 20 additions and 20 deletions
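
The change renames the misspelled helper cell_div to ceil_div everywhere it appears: the CUDA headers, the JIT kernel wrappers, the package exports, and the tests. The helper itself is unchanged and simply performs ceiling division, as its definitions in the diffs below show. A minimal Python sketch of that behavior, for reference only (not part of this commit):

def ceil_div(x: int, y: int) -> int:
    # Round the quotient up instead of truncating, e.g. when counting tiles.
    return (x + y - 1) // y

assert ceil_div(7, 128) == 1     # a partial 128-wide tile still counts as one
assert ceil_div(128, 128) == 1
assert ceil_div(129, 128) == 2   # one extra element needs a second tile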


@@ -5,7 +5,7 @@ from .jit_kernels import (
     gemm_fp8_fp8_bf16_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    cell_div,
+    ceil_div,
     set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout


@@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");

     // Types
     using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@@ -54,14 +54,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
-    static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
+    static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
     static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);

     // Configs
     constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
     constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
     constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
-    constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
+    constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
     const uint32_t lane_idx = get_lane_id();

@@ -218,7 +218,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         // Load B scales with math warp-groups
         // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
         if (threadIdx.x >= 32) {
-            auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
+            auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
             auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
             #pragma unroll
             for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
@@ -414,10 +414,10 @@ public:
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
         constexpr uint32_t kAlignment = 16 / sizeof(T);
-        shape_m = cell_div(shape_m, kAlignment) * kAlignment;
+        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;

         return make_2d_tma_desc(global_address, Layout::ColMajor,
-                                shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
+                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                 CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
     }


@@ -13,7 +13,7 @@ enum class GemmType {
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
           uint32_t kNumGroups, uint32_t kNumTMAMulticast,
-          uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
+          uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
           uint32_t kNumNBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
@@ -30,7 +30,7 @@ struct Scheduler {

     __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                   int* grouped_layout = nullptr) {
-        num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
+        num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
             num_blocks = num_aligned_m_blocks * kNumNBlocks;
         } else if (kGemmType == GemmType::GroupedContiguous) {
@@ -79,7 +79,7 @@ struct Scheduler {
                 return false;

             // Within current group
-            num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
+            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
             auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
             if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                 break;


@@ -43,6 +43,6 @@ do {
 #endif

 template <typename T>
-__device__ __host__ constexpr T cell_div(T a, T b) {
+__device__ __host__ constexpr T ceil_div(T a, T b) {
     return (a + b - 1) / b;
 }


@@ -4,7 +4,7 @@ from .m_grouped_gemm import (
     m_grouped_gemm_fp8_fp8_bf16_nt_masked
 )
 from .utils import (
-    cell_div, set_num_sms, get_num_sms,
+    ceil_div, set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout
 )


@@ -2,7 +2,7 @@ import torch
 from typing import Tuple

 from .tuner import jit_tuner
-from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
+from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout

 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
@@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
-    smem_scales_b = cell_div(k, block_k) * 4
+    smem_scales_b = ceil_div(k, block_k) * 4
     smem_barrier = num_stages * 8 * 2

     smem_size = 0
@@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     block_ns = tuple(range(16, 129, 8))

     fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
+    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
+    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

     # Decide block sizes by waves
     best_block_m, best_block_n = None, None
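
For context on the wave arithmetic above, an illustrative worked example with assumed values (not taken from the commit):

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

# Hypothetical shapes and SM count, chosen only to show the formulas.
m, n, bm, bn, num_groups, num_sms = 4096, 4096, 128, 128, 1, 132
num_blocks = ceil_div(m, bm) * ceil_div(n, bn) * num_groups   # 32 * 32 * 1 = 1024
num_waves = ceil_div(num_blocks, num_sms)                     # ceil(1024 / 132) = 8
last_wave_util = num_blocks % num_sms or num_sms              # 100 blocks in the final wave

fix_wave_saturate only rewrites a remainder of 0 to num_sms, i.e. it reports a fully saturated last wave instead of an empty one.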


@@ -29,7 +29,7 @@ def get_num_sms() -> int:
     return _num_sms


-def cell_div(x: int, y: int) -> int:
+def ceil_div(x: int, y: int) -> int:
     """
     Perform ceiling division of two integers.

@@ -71,7 +71,7 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
     tma_alignment_bytes = 16
     assert tma_alignment_bytes % element_size == 0
     alignment = tma_alignment_bytes // element_size
-    return cell_div(x, alignment) * alignment
+    return ceil_div(x, alignment) * alignment


 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:


@@ -3,7 +3,7 @@ import torch
 from typing import Tuple

 import deep_gemm
-from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
+from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor


 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -17,7 +17,7 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
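
The changed line in per_block_cast_to_fp8 pads both dimensions up to the next multiple of 128 so the following view into 128x128 tiles divides exactly. A small standalone sketch of that round-up idiom, with a hypothetical shape chosen only for illustration:

import torch

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

m, n = 200, 300                                                           # hypothetical input shape
x = torch.randn(m, n)
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128))  # padded to 256 x 384
x_padded[:m, :n] = x
tiles = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)              # shape (2, 128, 3, 128)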