mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-17 01:31:15 +00:00
commit d5b974da2b: the diff below renames the misspelled ceiling-division helper cell_div to ceil_div across the CUDA headers, the JIT kernel Python modules, and the tests.
@@ -5,7 +5,7 @@ from .jit_kernels import (
     gemm_fp8_fp8_bf16_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    cell_div,
+    ceil_div,
     set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout
@@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");

     // Types
     using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@@ -54,14 +54,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
-    static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
+    static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
     static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);

     // Configs
     constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
     constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
     constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
-    constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
+    constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
     const uint32_t lane_idx = get_lane_id();

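As a side note on what the two renamed call sites compute, here is a small Python sketch; SHAPE_K = 7168, BLOCK_K = 128, and kNumStages = 8 are hypothetical values chosen for illustration, not taken from this diff:

def ceil_div(a: int, b: int) -> int:
    # Same identity as the CUDA helper: round a / b up to the next integer.
    return (a + b - 1) // b

SHAPE_K, BLOCK_K, kNumStages = 7168, 128, 8             # hypothetical shape and pipeline depth
SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K)             # 56 per-128-channel scale groups along K
kFullKOfAllStages = kNumStages * BLOCK_K                # 1024 elements of K covered by one full pipeline pass
kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages)   # 7 outer K iterations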
@@ -218,7 +218,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         // Load B scales with math warp-groups
         // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
         if (threadIdx.x >= 32) {
-            auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
+            auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
             auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
             #pragma unroll
             for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
@@ -414,10 +414,10 @@ public:
     static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
         // Make TMA aligned to 16 bytes
         constexpr uint32_t kAlignment = 16 / sizeof(T);
-        shape_m = cell_div(shape_m, kAlignment) * kAlignment;
+        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;

         return make_2d_tma_desc(global_address, Layout::ColMajor,
-                                shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
+                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                 CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
     }

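For intuition, a minimal Python sketch of the 16-byte TMA alignment applied to shape_m above, assuming T is a 4-byte float and shape_m = 1023 (both assumptions for illustration; the hunk itself does not fix them):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

kAlignment = 16 // 4                                     # 16 bytes / sizeof(float) = 4 elements
shape_m = 1023                                           # hypothetical row count
aligned_m = ceil_div(shape_m, kAlignment) * kAlignment   # rounds up to 1024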
@@ -13,7 +13,7 @@ enum class GemmType {
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
           uint32_t kNumGroups, uint32_t kNumTMAMulticast,
-          uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
+          uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
           uint32_t kNumNBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
@@ -30,7 +30,7 @@ struct Scheduler {

     __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                   int* grouped_layout = nullptr) {
-        num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
+        num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
             num_blocks = num_aligned_m_blocks * kNumNBlocks;
         } else if (kGemmType == GemmType::GroupedContiguous) {
@@ -79,7 +79,7 @@ struct Scheduler {
                 return false;

             // Within current group
-            num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
+            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
             auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
             if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                 break;
@@ -43,6 +43,6 @@ do {
 #endif

 template <typename T>
-__device__ __host__ constexpr T cell_div(T a, T b) {
+__device__ __host__ constexpr T ceil_div(T a, T b) {
     return (a + b - 1) / b;
 }
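The renamed helper is plain ceiling division; a minimal Python equivalent for illustration only (the repository's device/host version is the C++ template in the hunk above):

def ceil_div(a: int, b: int) -> int:
    # (a + b - 1) // b rounds the quotient up for positive integers.
    return (a + b - 1) // b

assert ceil_div(7168, 128) == 56
assert ceil_div(7169, 128) == 57
assert ceil_div(128, 128) == 1   # whenever BLOCK_N <= BLOCK_K, ceil_div(BLOCK_N, BLOCK_K) == 1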
@@ -4,7 +4,7 @@ from .m_grouped_gemm import (
     m_grouped_gemm_fp8_fp8_bf16_nt_masked
 )
 from .utils import (
-    cell_div, set_num_sms, get_num_sms,
+    ceil_div, set_num_sms, get_num_sms,
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout
 )
@@ -2,7 +2,7 @@ import torch
 from typing import Tuple

 from .tuner import jit_tuner
-from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
+from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout

 # C++ code templates
 includes = ('"deep_gemm/fp8_gemm.cuh"', )
@@ -42,7 +42,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
-    smem_scales_b = cell_div(k, block_k) * 4
+    smem_scales_b = ceil_div(k, block_k) * 4
     smem_barrier = num_stages * 8 * 2

     smem_size = 0
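A quick worked example of the shared-memory accounting in this hunk, in Python; the sizes k = 7168, block_m = block_n = block_k = 128, num_stages = 8 are hypothetical and not prescribed by the diff:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

k, block_m, block_n, block_k, num_stages = 7168, 128, 128, 128, 8
smem_a_per_stage = block_m * block_k          # 16384 bytes of FP8 A tile per stage
smem_scales_a_per_stage = block_m * 4         # 512 bytes of FP32 A scales per stage
smem_b_per_stage = block_n * block_k          # 16384 bytes of FP8 B tile per stage
smem_scales_b = ceil_div(k, block_k) * 4      # 224 bytes of FP32 B scales, one per 128 channels of K
smem_barrier = num_stages * 8 * 2             # 128 bytes of barrier storage, two 8-byte barriers per stage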
@@ -65,8 +65,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     block_ns = tuple(range(16, 129, 8))

     fix_wave_saturate = lambda x: num_sms if x == 0 else x
-    get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
-    get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
+    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
+    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

     # Decide block sizes by waves
     best_block_m, best_block_n = None, None
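To illustrate the wave arithmetic after the rename, a small Python sketch; m, n, num_groups, num_sms, and the candidate block shape are all hypothetical values chosen for illustration:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, n, num_groups, num_sms = 4096, 7168, 1, 132                 # hypothetical problem and SM count
bm, bn = 128, 112                                              # one candidate block shape
num_blocks = ceil_div(m, bm) * ceil_div(n, bn) * num_groups    # 32 * 64 = 2048 thread blocks
num_waves = ceil_div(num_blocks, num_sms)                      # 16 waves over the SMs
last_wave_util = num_blocks % num_sms or num_sms               # 68 blocks in the last wave (mirrors fix_wave_saturate)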
@@ -29,7 +29,7 @@ def get_num_sms() -> int:
     return _num_sms


-def cell_div(x: int, y: int) -> int:
+def ceil_div(x: int, y: int) -> int:
     """
     Perform ceiling division of two integers.

@@ -71,7 +71,7 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
     tma_alignment_bytes = 16
     assert tma_alignment_bytes % element_size == 0
     alignment = tma_alignment_bytes // element_size
-    return cell_div(x, alignment) * alignment
+    return ceil_div(x, alignment) * alignment


 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
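The same rounding in a standalone Python sketch that mirrors this hunk's body; the name tma_aligned_size and the 2-byte element size (e.g. bf16) are assumptions for illustration:

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def tma_aligned_size(x: int, element_size: int) -> int:
    # Round x up so that x * element_size is a multiple of 16 bytes.
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment

assert tma_aligned_size(1000, 2) == 1000   # already a multiple of 8 elements
assert tma_aligned_size(1001, 2) == 1008   # padded up to the next multiple of 8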
@@ -3,7 +3,7 @@ import torch
 from typing import Tuple

 import deep_gemm
-from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
+from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor


 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -17,7 +17,7 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
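As a usage note on the padded shape computed above, a short Python sketch; the 128x128 block size comes from the hunk, while the concrete m, n and the random input are illustrative:

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, n = 300, 500                                               # hypothetical matrix shape
x = torch.randn(m, n)
padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype)
padded[:m, :n] = x                                            # (384, 512) buffer, zero-padded right and bottom
blocks = padded.view(-1, 128, padded.size(1) // 128, 128)     # expose 128x128 blocks
amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True)    # one max magnitude per block, as in the hunk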