From 96b31fd6bbe3eea736999776d50c706efab7fb36 Mon Sep 17 00:00:00 2001
From: AcraeaTerpsicore
Date: Wed, 26 Feb 2025 18:37:22 +0800
Subject: [PATCH] fix typo

---
 deep_gemm/__init__.py                     |  2 +-
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 12 ++++++------
 deep_gemm/include/deep_gemm/scheduler.cuh |  6 +++---
 deep_gemm/include/deep_gemm/utils.cuh     |  2 +-
 deep_gemm/jit_kernels/__init__.py         |  2 +-
 deep_gemm/jit_kernels/gemm.py             |  8 ++++----
 deep_gemm/jit_kernels/utils.py            |  4 ++--
 tests/test_core.py                        |  4 ++--
 8 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index 27932b0..15b22ca 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -5,7 +5,7 @@ from .jit_kernels import (
     gemm_fp8_fp8_bf16_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    cell_div,
+    ceil_div,
     set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout
diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index bf5249e..711649c 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
 
     // Types
     using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@@ -54,14 +54,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
-    static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
+    static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
     static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
 
     // Configs
     constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
     constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M);
     constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
-    constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
+    constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
     const uint32_t lane_idx = get_lane_id();
 
@@ -218,7 +218,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         // Load B scales with math warp-groups
         // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
         if (threadIdx.x >= 32) {
-            auto num_previous_lines = scheduler.get_global_idx(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
+            auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
             auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
             #pragma unroll
             for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
@@ -414,10 +414,10 @@ public:
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
         constexpr uint32_t kAlignment = 16 / sizeof(T);
-        shape_m = cell_div(shape_m, kAlignment) * kAlignment;
+        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
 
         return make_2d_tma_desc(global_address, Layout::ColMajor,
-                                shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
+                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                 CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
     }
 
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 5e1c211..329fbb0 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -13,7 +13,7 @@ enum class GemmType {
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
           uint32_t kNumGroups, uint32_t kNumTMAMulticast,
-          uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
+          uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
           uint32_t kNumNBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
@@ -30,7 +30,7 @@ struct Scheduler {
 
     __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                   int* grouped_layout = nullptr) {
-        num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
+        num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
             num_blocks = num_aligned_m_blocks * kNumNBlocks;
         } else if (kGemmType == GemmType::GroupedContiguous) {
@@ -79,7 +79,7 @@ struct Scheduler {
                 return false;
 
             // Within current group
-            num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
+            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
             auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
             if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                 break;
diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh
index 608945d..0005907 100644
--- a/deep_gemm/include/deep_gemm/utils.cuh
+++ b/deep_gemm/include/deep_gemm/utils.cuh
@@ -43,6 +43,6 @@ do {
 #endif
 
 template <typename T>
-__device__ __host__ constexpr T cell_div(T a, T b) {
+__device__ __host__ constexpr T ceil_div(T a, T b) {
     return (a + b - 1) / b;
 }
diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py
index d4c9aba..2a8624b 100644
--- a/deep_gemm/jit_kernels/__init__.py
+++ b/deep_gemm/jit_kernels/__init__.py
@@ -4,7 +4,7 @@ from .m_grouped_gemm import (
     m_grouped_gemm_fp8_fp8_bf16_nt_masked
 )
 from .utils import (
-    cell_div, set_num_sms, get_num_sms,
+    ceil_div, set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout
 )
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 0251b8c..1860b43 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -2,7 +2,7 @@ import torch
 from typing import Tuple
 
 from .tuner import jit_tuner
-from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
+from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
 
 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
@@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
-    smem_scales_b = cell_div(k, block_k) * 4
+    smem_scales_b = ceil_div(k, block_k) * 4
     smem_barrier = num_stages * 8 * 2
 
     smem_size = 0
@@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     block_ns = tuple(range(16, 129, 8))
 
     fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
+    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
+    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
 
     # Decide block sizes by waves
     best_block_m, best_block_n = None, None
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index 8ae50c9..703e1e2 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -29,7 +29,7 @@ def get_num_sms() -> int:
     return _num_sms
 
 
-def cell_div(x: int, y: int) -> int:
+def ceil_div(x: int, y: int) -> int:
     """
     Perform ceiling division of two integers.
 
@@ -71,7 +71,7 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
     tma_alignment_bytes = 16
     assert tma_alignment_bytes % element_size == 0
     alignment = tma_alignment_bytes // element_size
-    return cell_div(x, alignment) * alignment
+    return ceil_div(x, alignment) * alignment
 
 
 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
diff --git a/tests/test_core.py b/tests/test_core.py
index a227c3a..68d9b79 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -3,7 +3,7 @@ import torch
 from typing import Tuple
 
 import deep_gemm
-from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
+from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -17,7 +17,7 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
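
For reference, the renamed helper performs plain ceiling division in both the CUDA and Python code touched above. A minimal Python sketch of the behavior the patched call sites rely on; the sizes below are illustrative and not taken from the patch:

    def ceil_div(x: int, y: int) -> int:
        # Smallest integer >= x / y for positive y, mirroring deep_gemm/jit_kernels/utils.py above.
        return (x + y - 1) // y

    # Pad an (m, n) matrix up to multiples of the 128-wide FP8 scaling block,
    # as per_block_cast_to_fp8 does in tests/test_core.py.
    m, n = 300, 513
    padded_m, padded_n = ceil_div(m, 128) * 128, ceil_div(n, 128) * 128
    assert (padded_m, padded_n) == (384, 640)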