mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-08 17:14:02 +00:00
commit
d5b974da2b
@ -5,7 +5,7 @@ from .jit_kernels import (
|
||||
gemm_fp8_fp8_bf16_nt,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
cell_div,
|
||||
ceil_div,
|
||||
set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
|
@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
|
||||
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
@ -54,14 +54,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
|
||||
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
@ -218,7 +218,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
@ -414,10 +414,10 @@ public:
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ enum class GemmType {
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
||||
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
@ -30,7 +30,7 @@ struct Scheduler {
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
|
||||
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
@ -79,7 +79,7 @@ struct Scheduler {
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
@ -43,6 +43,6 @@ do {
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T cell_div(T a, T b) {
|
||||
__device__ __host__ constexpr T ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .utils import (
|
||||
cell_div, set_num_sms, get_num_sms,
|
||||
ceil_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
)
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = cell_div(k, block_k) * 4
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
|
@ -29,7 +29,7 @@ def get_num_sms() -> int:
|
||||
return _num_sms
|
||||
|
||||
|
||||
def cell_div(x: int, y: int) -> int:
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
@ -71,7 +71,7 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cell_div(x, alignment) * alignment
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -17,7 +17,7 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
|
||||
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user