diff --git a/README.md b/README.md
index 6266863..a55311e 100644
--- a/README.md
+++ b/README.md
@@ -128,6 +128,7 @@ The library also provides some environment variables, which may be useful:
 - `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
 - `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default
+- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), override the C++ standard passed to NVCC; useful for older GCC host compilers
 - `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
 - `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
 - `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index d9ab480..41e563e 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -40,6 +40,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
 template
                 tensor_map_d) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
+    DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
 
     // Types
     using WGMMA = typename FP8MMASelector::type;
@@ -62,7 +63,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     // Shared memory
     static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
-    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
+    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
@@ -82,7 +83,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a));
         cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b));
         cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a));
-        cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d));
+        if constexpr (SHAPE_N >= BLOCK_N)
+            cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d.first));
+        if constexpr (SHAPE_N % BLOCK_N != 0)
+            cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d.second));
     }
     __syncwarp();
@@ -141,8 +145,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     struct DivisibleK {};
     struct NotDivisibleK {};
     auto launch_k_iterations = [](const auto& func, int num_former_iters) {
-        constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
+        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
         constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
 
         // NOTES: for too-many branches (> 5), we disable this optimization
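For reference, the `kShouldOptimize` guard above has a direct Python-side counterpart in `get_best_configs` (the `math.gcd` check further down in this diff). A minimal standalone sketch of the same condition, with the helper name and `max_branches` bound chosen here for illustration only:

```python
import math

BLOCK_K = 128

# Mirror of `kShouldOptimize`: the per-`num_former_iters` unrolling is only used when
# B scales are non-uniform (BLOCK_K % BLOCK_N != 0) and the number of distinct
# branches, BLOCK_K / gcd(BLOCK_K, BLOCK_N), stays small.
def should_unroll_former_iters(block_n: int, max_branches: int = 4) -> bool:
    non_uniform_scale_b = BLOCK_K % block_n != 0
    num_branches = BLOCK_K // math.gcd(BLOCK_K, block_n)
    return non_uniform_scale_b and num_branches <= max_branches

for block_n in (96, 112, 128):
    print(block_n, should_unroll_former_iters(block_n))
# 96  -> True   (128 / gcd(128, 96) = 4 branches)
# 112 -> False  (128 / gcd(128, 112) = 8, too many branches)
# 128 -> False  (B scales are uniform)
```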
@@ -340,14 +344,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
                 __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
                 __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
+                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
             );
         }
         if constexpr (WGMMA::kNumAccum % 8 != 0) {
             SM90_U32x2_STSM_N::copy(
                 __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
                 __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
+                smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
             );
         }
         cute::tma_store_fence();
@@ -355,8 +359,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
         // Use TMA store to write back to global memory
         if (threadIdx.x == 0) {
-            cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
-                                          scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            if (n_block_idx < SHAPE_N / BLOCK_N) {
+                // All blocks except the last, possibly unaligned one
+                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
+                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            } else {
+                // The last, possibly unaligned block
+                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
+                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+            }
             cute::tma_store_arrive();
             cute::tma_store_wait<0>();
         }
@@ -371,6 +382,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 template
@@ -386,14 +398,17 @@ public:
                     const CUtensorMap& tma_a_desc,
                     const CUtensorMap& tma_b_desc,
                     const CUtensorMap& tma_scales_a_desc,
-                    const CUtensorMap& tma_d_desc,
+                    const std::pair<CUtensorMap, CUtensorMap>& tma_d_desc,
                     cudaStream_t stream,
                     int num_sms, uint32_t smem_size) {
         // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
         constexpr uint32_t kNumTMAThreads = 128;
         constexpr uint32_t kNumMathThreadsPerGroup = 128;
-        auto kernel = fp8_gemm_kernel;
         DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
@@ -433,11 +448,26 @@ public:
     }
 
     template <typename T>
-    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
-        return make_2d_tma_desc(global_address, Layout::RowMajor,
-                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
-                                min(BLOCK_M, shape_m), BLOCK_N,
-                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
+    static std::pair<CUtensorMap, CUtensorMap> make_3d_tma_d_desc(T* global_address, uint32_t shape_m) {
+        // NOTES: must be row-major
+        auto m = shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1);
+        uint64_t gmem_strides[2] = {BLOCK_N * sizeof(T), SHAPE_N * sizeof(T)};
+        uint32_t smem_dim[3] = {BLOCK_N + BLOCK_N_PADDING, 1, BLOCK_M};
+
+        // `SHAPE_N % BLOCK_N` may be non-zero, so split the store into two descriptors
+        CUtensorMap aligned;
+        if constexpr (SHAPE_N >= BLOCK_N) {
+            uint64_t gmem_dim[3] = {BLOCK_N, SHAPE_N / BLOCK_N, m};
+            aligned = make_3d_tma_copy_desc(global_address, gmem_dim, gmem_strides, smem_dim);
+        }
+
+        CUtensorMap unaligned;
+        if constexpr (SHAPE_N % BLOCK_N != 0) {
+            uint64_t gmem_dim[3] = {SHAPE_N % BLOCK_N, 1, m};
+            unaligned = make_3d_tma_copy_desc(global_address + (SHAPE_N / BLOCK_N) * BLOCK_N,
+                                              gmem_dim, gmem_strides, smem_dim);
+        }
+        return {aligned, unaligned};
     }
 
     template
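The `make_3d_tma_d_desc` change above is the host-side half of the split D store: one descriptor covers the `SHAPE_N / BLOCK_N` full tiles and a second one covers the `SHAPE_N % BLOCK_N` tail, matching the `n_block_idx < SHAPE_N / BLOCK_N` branch in the TMA store. A small Python sketch of how a block index maps onto the two descriptors (the helper and the example shape are illustrative, not library code):

```python
def d_store_descriptor(shape_n: int, block_n: int, n_block_idx: int):
    # Returns which descriptor a block uses, its starting column, and its width.
    num_full_blocks = shape_n // block_n
    if n_block_idx < num_full_blocks:
        return 'aligned', n_block_idx * block_n, block_n               # full BLOCK_N-wide tile
    return 'unaligned', num_full_blocks * block_n, shape_n % block_n   # tail tile

for idx in (0, 15, 16):
    print(idx, d_store_descriptor(2112, 128, idx))
# 0  -> ('aligned', 0, 128)
# 15 -> ('aligned', 1920, 128)
# 16 -> ('unaligned', 2048, 64)
```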
diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh
index c938c4d..d3db1b6 100644
--- a/deep_gemm/include/deep_gemm/tma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/tma_utils.cuh
@@ -63,16 +63,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                   uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                   CUtensorMapSwizzle swizzle_type,
                                   PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
-    CUtensorMap tensor_map{};
-    constexpr uint32_t rank = 2;
-    uint64_t global_stride[rank - 1] = {stride_in_bytes};
-    uint32_t elem_strides[rank] = {1, 1};
+    CUtensorMap tensor_map = {};
+    uint64_t global_stride[1] = {stride_in_bytes};
+    uint32_t elem_strides[2] = {1, 1};
 
     if (encode_func == nullptr)
         encode_func = get_cuTensorMapEncodeTiled();
 
     auto result = encode_func(
-        &tensor_map, get_CUtensorMapDataType::type>(), rank,
+        &tensor_map, get_CUtensorMapDataType>(), 2,
         global_address, gmem_dim, global_stride, smem_dim, elem_strides,
         CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
         CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
@@ -81,6 +80,27 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }
 
+template <typename T>
+CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
+                                  uint64_t gmem_strides[2], uint32_t smem_dim[3],
+                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
+    CUtensorMap tensor_map = {};
+    uint32_t elem_strides[3] = {1, 1, 1};
+
+    if (encode_func == nullptr)
+        encode_func = get_cuTensorMapEncodeTiled();
+
+    auto result = encode_func(
+        &tensor_map, get_CUtensorMapDataType>(), 3,
+        global_address, gmem_dim, gmem_strides, smem_dim, elem_strides,
+        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
+        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
+        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
+        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+    DG_HOST_ASSERT(result == CUDA_SUCCESS);
+    return tensor_map;
+}
+
 template
 __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh
index fe2c016..9b93af5 100644
--- a/deep_gemm/include/deep_gemm/utils.cuh
+++ b/deep_gemm/include/deep_gemm/utils.cuh
@@ -48,6 +48,6 @@ __device__ __host__ constexpr T ceil_div(T a, T b) {
 }
 
 template <typename T>
-__device__ __host__ constexpr T gcd(T a, T b) {
-    return b == 0 ? a : gcd(b, a % b);
+__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
+    return b == 0 ? a : constexpr_gcd(b, a % b);
 }
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index fec2eb9..3cf20e3 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -96,7 +96,8 @@ def put(path, data, is_binary=False):
 
 def build(name: str, arg_defs: tuple, code: str) -> Runtime:
     # Compiler flags
-    nvcc_flags = ['-std=c++20', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
+    cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
+    nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                   '-gencode=arch=compute_90a,code=sm_90a',
                   '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                   # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
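The new environment variable only changes the `-std=` flag handed to NVCC; the default stays C++20. A standalone sketch of the override, mirroring the `cpp_standard` line above (setting the variable in-process here is just for illustration):

```python
import os

# e.g., drop to C++17 for an older host GCC that cannot handle C++20
os.environ['DG_NVCC_OVERRIDE_CPP_STANDARD'] = '17'

cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
print(f'-std=c++{cpp_standard}')  # -> -std=c++17
```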
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 65b44ff..cec83fb 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -14,22 +14,24 @@ using namespace deep_gemm;
 constexpr auto N = {N}, K = {K};
 constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
+constexpr auto BLOCK_K = 128;
+constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated GEMM
-using GemmType = Gemm;
+using gemm_t = Gemm;
 
 // Launch kernel
-auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
-auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
-auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
-GemmType::run(out, rhs_scales, nullptr,
-              m,
-              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
+auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
+auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
+gemm_t::run(out, rhs_scales, nullptr,
+            m,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            stream, num_sms, smem_size);
 """
@@ -39,8 +41,16 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
     return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
 
 
+def get_block_n_padding_for_smem_d(block_n: int) -> int:
+    elem_size, requirement = 2, (4, 8)
+    bank_stride = (block_n * elem_size) // 4
+    padding = (requirement[0] - bank_stride) % requirement[1]
+    return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
+
+
 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
-    smem_d = block_m * block_n * 2
+    block_n_padding = get_block_n_padding_for_smem_d(block_n)
+    smem_d = block_m * (block_n + block_n_padding) * 2
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
@@ -91,10 +101,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
     best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
-    stage_candidates = (8, 7, 6, 5, 4)
+    stage_candidates = (8, 7, 6, 5, 4, 3)
     if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
         # Unrolling both stages and `num_former_iters` will cause large code size
-        stage_candidates = (4, )
+        stage_candidates = (4, 3)
     for num_stages in stage_candidates:
         best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
         if best_smem_size <= sm90_capacity:
@@ -119,7 +129,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
+    num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
     assert num_min_sms <= num_sms
 
     return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
@@ -177,6 +187,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
               'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
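`get_block_n_padding_for_smem_d` pads each BF16 row of `smem_d` so that the row stride, measured in 4-byte shared-memory banks, lands on 4 modulo 8, staggering consecutive rows across bank groups (presumably to avoid bank conflicts in the STSM stores). A standalone check of that invariant; the loop range and the assertion are mine, not the library's:

```python
def get_block_n_padding_for_smem_d(block_n: int) -> int:
    # Copied from the hunk above.
    elem_size, requirement = 2, (4, 8)
    bank_stride = (block_n * elem_size) // 4
    padding = (requirement[0] - bank_stride) % requirement[1]
    return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size

# For any even block_n, the padded BF16 row stride is congruent to 4 banks (mod 8).
for block_n in range(16, 129, 8):
    pad = get_block_n_padding_for_smem_d(block_n)
    banks = (block_n + pad) * 2 // 4
    assert banks % 8 == 4, (block_n, pad, banks)
    print(f'BLOCK_N={block_n:3d}  padding={pad}  row stride={banks} banks')
```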
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index bffe137..908e9e8 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -1,7 +1,7 @@
 import torch
 from typing import Tuple
 
-from .gemm import get_best_configs
+from .gemm import get_best_configs, get_block_n_padding_for_smem_d
 from .tuner import jit_tuner
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
 
@@ -14,22 +14,25 @@ using namespace deep_gemm;
 constexpr auto N = {N}, K = {K};
 constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
+constexpr auto BLOCK_K = 128;
+constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
+constexpr auto kNumGroups = {NUM_GROUPS};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated grouped GEMM
-using GemmType = Gemm;
+using gemm_t = Gemm;
 
 // Launch kernel
-auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
-auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
-auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
-auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
-GemmType::run(out, rhs_scales, grouped_layout,
-              m,
-              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
-              stream, num_sms, smem_size);
+auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
+auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
+auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
+auto tma_d_desc = gemm_t::make_3d_tma_d_desc(out, m);
+gemm_t::run(out, rhs_scales, grouped_layout,
+            m,
+            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
+            stream, num_sms, smem_size);
 """
@@ -91,7 +94,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
         torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
+        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
@@ -172,7 +177,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
         torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
-        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
+        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
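With `BLOCK_N_PADDING` threaded through both grouped variants as well, the D tile in shared memory grows slightly. A quick worked example of the `smem_d` term from `get_smem_size` above (the block shape is just an example):

```python
# BF16 elements are 2 bytes; padding value taken from get_block_n_padding_for_smem_d(128).
block_m, block_n, block_n_padding = 128, 128, 8

smem_d_old = block_m * block_n * 2                        # 32768 bytes before this change
smem_d_new = block_m * (block_n + block_n_padding) * 2    # 34816 bytes with padded rows
print(smem_d_old, smem_d_new, smem_d_new - smem_d_old)    # 32768 34816 2048
```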