Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-05 23:44:22 +00:00

Commit 6db7e1863b: Solve STSM bank conflict via padding and 3D TMA
Parent: c57699ac93
@@ -40,6 +40,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
 
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
+          uint32_t BLOCK_N_PADDING,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
@@ -50,11 +51,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 const __grid_constant__ CUtensorMap tensor_map_a,
                 const __grid_constant__ CUtensorMap tensor_map_b,
                 const __grid_constant__ CUtensorMap tensor_map_scales_a,
-                const __grid_constant__ CUtensorMap tensor_map_d) {
+                const __grid_constant__ std::pair<CUtensorMap, CUtensorMap> tensor_map_d) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
 
     // Types
     using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
@@ -62,7 +63,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Shared memory
     static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
-    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
+    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
@@ -82,7 +83,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
+        if constexpr (SHAPE_N >= BLOCK_N)
+            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.first));
+        if constexpr (SHAPE_N % BLOCK_N != 0)
+            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d.second));
     }
     __syncwarp();
 
@@ -141,8 +145,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     struct DivisibleK {};
     struct NotDivisibleK {};
     auto launch_k_iterations = [](const auto& func, int num_former_iters) {
-        constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
+        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
         constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
 
         // NOTES: for too-many branches (> 5), we disable this optimization
@@ -340,14 +344,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
                 __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
                 __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
+                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
             );
         }
         if constexpr (WGMMA::kNumAccum % 8 != 0) {
             SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                 __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
                 __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
+                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
             );
         }
         cute::tma_store_fence();
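Why the padded row stride above removes the STSM bank conflicts, as a rough standalone Python sketch: it assumes the usual 32 shared-memory banks of 4 bytes and uses BLOCK_N = 128 with 8 bf16 elements of padding as an example; the names BANKS, BANK_BYTES, ELEM_BYTES and start_bank are illustrative and not part of the commit.

# With an unpadded 256-byte stride every row of the D tile starts in bank 0,
# so stmatrix writes to the same columns of different rows collide; 16 bytes
# of padding rotate consecutive row starts across the banks.
BANKS, BANK_BYTES, ELEM_BYTES = 32, 4, 2

def start_bank(row: int, block_n: int, padding: int) -> int:
    stride_bytes = (block_n + padding) * ELEM_BYTES
    return (row * stride_bytes // BANK_BYTES) % BANKS

print([start_bank(r, 128, 0) for r in range(8)])  # [0, 0, 0, 0, 0, 0, 0, 0]
print([start_bank(r, 128, 8) for r in range(8)])  # [0, 4, 8, 12, 16, 20, 24, 28]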
@@ -355,8 +359,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
         // Use TMA store to write back to global memory
         if (threadIdx.x == 0) {
-            cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
-                                          scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            if (n_block_idx < SHAPE_N / BLOCK_N) {
+                // Except the last unaligned block
+                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
+                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            } else {
+                // The last unaligned block
+                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
+                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            }
             cute::tma_store_arrive();
             cute::tma_store_wait<0>();
         }
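The branch above splits the N dimension between the two descriptors: tile indices covering full BLOCK_N blocks are stored through tensor_map_d.first, and the trailing SHAPE_N % BLOCK_N columns, if any, through tensor_map_d.second. A small illustrative Python sketch; split_n and the sizes shape_n = 4096, block_n = 112 are example assumptions, not from the commit.

def split_n(shape_n: int, block_n: int):
    full_blocks = shape_n // block_n  # written via tensor_map_d.first
    remainder = shape_n % block_n     # written via tensor_map_d.second, if non-zero
    return full_blocks, remainder

print(split_n(4096, 112))  # (36, 64): 36 aligned blocks plus a 64-column tail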
@@ -371,6 +382,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
+          uint32_t BLOCK_N_PADDING,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
@@ -386,14 +398,17 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
-                    const CUtensorMap& tma_d_desc,
+                    const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
                     cudaStream_t stream,
                     int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
-        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
-                                      kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
+        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
+                                      BLOCK_M, BLOCK_N, BLOCK_K,
+                                      BLOCK_N_PADDING,
+                                      kNumGroups, kNumStages,
+                                      kNumTMAThreads, kNumMathThreadsPerGroup,
                                       kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
         DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
 
@@ -433,11 +448,26 @@ public:
     }
 
     template <typename T>
-    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
-        return make_2d_tma_desc(global_address, Layout::RowMajor,
-                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
-                                min(BLOCK_M, shape_m), BLOCK_N,
-                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
+    static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
+        // NOTES: must be row-major
+        auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
+        uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
+        uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
+
+        // `SHAPE_N % BLOCK_N` maybe not zero, dividing them into two parts
+        CUtensorMap aligned;
+        if constexpr (SHAPE_N >= BLOCK_N) {
+            uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
+            aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
+        }
+
+        CUtensorMap unaligned;
+        if constexpr (SHAPE_N % BLOCK_N != 0) {
+            uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
+            unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
+                                              gmem_dim, gmem_strides, smem_dim);
+        }
+        return {aligned, unaligned};
     }
 
     template <typename T>
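For intuition about the descriptors built above: the 3D view addresses D as (column-in-block, block, row) with element strides of 1, BLOCK_N and SHAPE_N, which is the same memory as the row-major [m, SHAPE_N] output. A minimal Python model of that mapping; SHAPE_N = 4096, BLOCK_N = 112 and the offset_* helpers are example assumptions, not from the commit.

SHAPE_N, BLOCK_N = 4096, 112

def offset_3d(col: int, blk: int, row: int) -> int:
    # strides of 1, BLOCK_N and SHAPE_N elements, matching gmem_strides above
    return col + blk * BLOCK_N + row * SHAPE_N

def offset_2d(row: int, n: int) -> int:
    # plain row-major [m, SHAPE_N] indexing of D
    return row * SHAPE_N + n

assert offset_3d(5, 3, 7) == offset_2d(7, 3 * BLOCK_N + 5)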
@@ -63,16 +63,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                   uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                   CUtensorMapSwizzle swizzle_type,
                                   PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
-    CUtensorMap tensor_map{};
-    constexpr uint32_t rank = 2;
-    uint64_t global_stride[rank - 1] = {stride_in_bytes};
-    uint32_t elem_strides[rank] = {1, 1};
+    CUtensorMap tensor_map = {};
+    uint64_t global_stride[1] = {stride_in_bytes};
+    uint32_t elem_strides[2] = {1, 1};
 
     if (encode_func == nullptr)
         encode_func = get_cuTensorMapEncodeTiled();
 
     auto result = encode_func(
-        &tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
+        &tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
         global_address, gmem_dim, global_stride, smem_dim, elem_strides,
         CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
         CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
@@ -81,6 +80,27 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }
 
+template <typename T>
+CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
+                                  uint64_t gmem_strides[2], uint32_t smem_dim[3],
+                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
+    CUtensorMap tensor_map = {};
+    uint32_t elem_strides[3] = {1, 1, 1};
+
+    if (encode_func == nullptr)
+        encode_func = get_cuTensorMapEncodeTiled();
+
+    auto result = encode_func(
+        &tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 3,
+        global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
+        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
+        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
+        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
+        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+    DG_HOST_ASSERT(result == CUDA_SUCCESS);
+    return tensor_map;
+}
+
 template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void
 tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
@@ -48,6 +48,6 @@ __device__ __host__ constexpr T ceil_div(T a, T b) {
 }
 
 template <typename T>
-__device__ __host__ constexpr T gcd(T a, T b) {
-    return b == 0 ? a : gcd(b, a % b);
+__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
+    return b == 0 ? a : constexpr_gcd(b, a % b);
 }
@@ -96,7 +96,7 @@ def put(path, data, is_binary=False):
 
 def build(name: str, arg_defs: tuple, code: str) -> Runtime:
     # Compiler flags
-    nvcc_flags = ['-std=c++20', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
+    nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                   '-gencode=arch=compute_90a,code=sm_90a',
                   '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                   # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
@@ -14,22 +14,24 @@ using namespace deep_gemm;
 constexpr auto N = {N}, K = {K};
 constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
+constexpr auto BLOCK_K = 128;
+constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
+using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
 
 // Launch kernel
-auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
-auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
-auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
-GemmType::run(out, rhs_scales, nullptr,
-              m,
-              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
+auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
+auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
+gemm_t::run(out, rhs_scales, nullptr,
+            m,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            stream, num_sms, smem_size);
 """
@@ -39,8 +41,16 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
     return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
 
 
+def get_block_n_padding_for_smem_d(block_n: int) -> int:
+    elem_size, requirement = 2, (4, 8)
+    bank_stride = (block_n * elem_size) // 4
+    padding = (requirement[0] - bank_stride) % requirement[1]
+    return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
+
+
 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
-    smem_d = block_m * block_n * 2
+    block_n_padding = get_block_n_padding_for_smem_d(block_n)
+    smem_d = block_m * (block_n + block_n_padding) * 2
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
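As a quick sanity check of get_block_n_padding_for_smem_d above: this standalone snippet just re-runs the helper from the diff for a few block sizes (the chosen block_n values are arbitrary examples). The target is a padded row stride that is congruent to 4 modulo 8 when measured in 4-byte banks.

def get_block_n_padding_for_smem_d(block_n: int) -> int:
    # Copied from the diff above; returns the padding in bf16 elements.
    elem_size, requirement = 2, (4, 8)
    bank_stride = (block_n * elem_size) // 4
    padding = (requirement[0] - bank_stride) % requirement[1]
    return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size

for block_n in (64, 96, 112, 120, 128):
    print(block_n, get_block_n_padding_for_smem_d(block_n))
# -> 8, 8, 8, 0, 8 elements respectively (i.e. 16 bytes of padding, except for 120)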
@@ -91,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
     best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
-    stage_candidates = (8, 7, 6, 5, 4)
+    stage_candidates = (8, 7, 6, 5, 4, 3)
     if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
         # Unrolling both stages and `num_former_iters` will cause large code size
-        stage_candidates = (4, )
+        stage_candidates = (4, 3)
     for num_stages in stage_candidates:
         best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
         if best_smem_size <= sm90_capacity:
@@ -119,7 +129,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
+    num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
     assert num_min_sms <= num_sms
 
     return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
@@ -177,6 +187,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
@@ -1,7 +1,7 @@
 import torch
 from typing import Tuple
 
-from .gemm import get_best_configs
+from .gemm import get_best_configs, get_block_n_padding_for_smem_d
 from .tuner import jit_tuner
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
 
@@ -14,22 +14,25 @@ using namespace deep_gemm;
 constexpr auto N = {N}, K = {K};
 constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
+constexpr auto BLOCK_K = 128;
+constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
+constexpr auto kNumGroups = {NUM_GROUPS};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated grouped GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
+using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
 
 // Launch kernel
-auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
-auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
-auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
-GemmType::run(out, rhs_scales, grouped_layout,
-              m,
-              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
+auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
+auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
+gemm_t::run(out, rhs_scales, grouped_layout,
+            m,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            stream, num_sms, smem_size);
 """
@@ -91,7 +94,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
                 torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
+        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
@@ -172,7 +177,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
                 torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
+        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],