diff --git a/README.md b/README.md index ca48284..5b9388a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. -Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function comprising around **~300 lines of code**. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. +Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. @@ -13,7 +13,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## Roadmap - [ ] More correctness tests for grouped-contiguous layout -- [ ] Shared memory swizzling for output instead of padding +- [x] Shared memory swizzling for output - [ ] Larger block size on N (up to 256) - [ ] MoE scheduler with TMA multicast compatibility - [ ] Weight gradient kernels for dense models @@ -29,55 +29,6 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [ ] BF16 kernels - [ ] Split/stream-k optimizations -## Performance - -We test all shapes potentially used in DeepSeek-V3/R1 inference (including both prefilling and decoding, but without tensor parallelism) on H800 SXM5 with NVCC 12.8. All speedup metrics are calculated in comparison to our internally and carefully optimized implementation based on CUTLASS 3.6. - -DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed if you are interested. - -### Normal GEMMs for dense models - -| M | N | K | Computation | Memory bandwidth | Speedup | -|:----:|:-----:|:-----:|:-----------:|:----------------:|:-------:| -| 64 | 2112 | 7168 | 206 TFLOPS | 1688 GB/s | 2.7x | -| 64 | 24576 | 1536 | 289 TFLOPS | 2455 GB/s | 1.7x | -| 64 | 32768 | 512 | 219 TFLOPS | 2143 GB/s | 1.8x | -| 64 | 7168 | 16384 | 336 TFLOPS | 2668 GB/s | 1.4x | -| 64 | 4096 | 7168 | 287 TFLOPS | 2320 GB/s | 1.4x | -| 64 | 7168 | 2048 | 295 TFLOPS | 2470 GB/s | 1.7x | -| 128 | 2112 | 7168 | 352 TFLOPS | 1509 GB/s | 2.4x | -| 128 | 24576 | 1536 | 535 TFLOPS | 2448 GB/s | 1.6x | -| 128 | 32768 | 512 | 358 TFLOPS | 2103 GB/s | 1.5x | -| 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x | -| 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x | -| 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x | -| 4096 | 2112 | 7168 | 1127 TFLOPS | 562 GB/s | 1.2x | -| 4096 | 24576 | 1536 | 1212 TFLOPS | 962 GB/s | 1.2x | -| 4096 | 32768 | 512 | 775 TFLOPS | 1620 GB/s | 1.2x | -| 4096 | 7168 | 16384 | 1520 TFLOPS | 384 GB/s | 1.4x | -| 4096 | 4096 | 7168 | 1410 TFLOPS | 541 GB/s | 1.3x | -| 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x | - -### Grouped GEMMs for MoE models (contiguous layout) - -| #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup | -|:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:| -| 4 | 8192 | 4096 | 7168 | 1346 TFLOPS | 434 GB/s | 1.3x | -| 4 | 8192 | 7168 | 2048 | 1214 TFLOPS | 752 GB/s | 1.3x | -| 8 | 4096 | 4096 | 7168 | 1346 TFLOPS | 516 GB/s | 1.3x | -| 8 | 4096 | 7168 | 2048 | 1214 TFLOPS | 826 GB/s | 1.2x | - -### Grouped GEMMs for MoE models (masked layout) - -| #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup | -|:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:| -| 1 | 1024 | 4096 | 7168 | 1233 TFLOPS | 924 GB/s | 1.2x | -| 1 | 1024 | 7168 | 2048 | 925 TFLOPS | 968 GB/s | 1.2x | -| 2 | 512 | 4096 | 7168 | 1040 TFLOPS | 1288 GB/s | 1.2x | -| 2 | 512 | 7168 | 2048 | 916 TFLOPS | 1405 GB/s | 1.2x | -| 4 | 256 | 4096 | 7168 | 932 TFLOPS | 2064 GB/s | 1.1x | -| 4 | 256 | 7168 | 2048 | 815 TFLOPS | 2047 GB/s | 1.2x | - ## Quick start ### Requirements diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index a43ee2c..cfcb569 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -45,6 +45,7 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it template = 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); @@ -63,6 +65,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Types using WGMMA = typename FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); // Shared memory static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); @@ -86,6 +89,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + + // `tensor_map_d` is only used in swizzling mode + // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode + if constexpr (kSwizzleDMode > 0) + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); } __syncwarp(); @@ -345,6 +353,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, } }, num_former_iters); + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, + "Swizzling and padding are not compatible"); + // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll @@ -352,38 +371,65 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, auto m_offset = local_idx * WAVE_BLOCK_M; auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}), - __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}), - smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr int kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + // NOTES: padding must be zero for BF16 output + DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); + } + + // NOTES: only 16 lanes' addresses are used SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16 + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr ); } - - // Issue TMA store - cute::tma_store_fence(); - if (lane_idx < 16) { - uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); - auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING); - auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N; - auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N; - cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16)); - } - __syncwarp(); } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - // Wait TMA to be finished - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + + // Wait TMA to be finished + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); } } #else @@ -395,6 +441,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, template @@ -410,6 +457,7 @@ public: const CUtensorMap& tma_a_desc, const CUtensorMap& tma_b_desc, const CUtensorMap& tma_scales_a_desc, + const CUtensorMap& tma_d_desc, cudaStream_t stream, int num_sms, uint32_t smem_size) { // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps @@ -418,6 +466,7 @@ public: auto kernel = fp8_gemm_kernel; @@ -442,7 +491,7 @@ public: auto status = cudaLaunchKernelEx(&config, kernel, gmem_d, scales_b, grouped_layout, shape_m, - tma_a_desc, tma_b_desc, tma_scales_a_desc); + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); DG_HOST_ASSERT(status == cudaSuccess); } @@ -458,6 +507,21 @@ public: SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); } + template + static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { + auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B; + if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B; + if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B; + + // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes + // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, + BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T), + swizzle_mode); + } + template static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { // Make TMA aligned to 16 bytes diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 3cf20e3..c17d466 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime: '-gencode=arch=compute_90a,code=sm_90a', '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=177,174,940'] + '--diag-suppress=39,174,177,940'] cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] include_dirs = [get_jit_include_dir()] diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 4b5db70..eab5442 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -17,20 +17,23 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto BLOCK_K = 128; constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; +constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; +constexpr auto kNumGroups = 1; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated GEMM -using gemm_t = Gemm; +using gemm_t = Gemm; // Launch kernel auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); gemm_t::run(out, rhs_scales, nullptr, m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, stream, num_sms, smem_size); """ @@ -41,15 +44,28 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 +def get_swizzle_mode(block_n: int) -> int: + # TODO: remove some candidates if slow + elem_size = 2 + for mode_bytes in (128, 64, 32): + if (block_n * elem_size) % mode_bytes == 0: + return mode_bytes + return 0 + + def get_block_n_padding_for_smem_d(block_n: int) -> int: + # NOTES: padding is for solving bank conflicts, but wastes shared memory space elem_size, requirement = 2, (4, 8) bank_stride = (block_n * elem_size) // 4 padding = (requirement[0] - bank_stride) % requirement[1] return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size -def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: - block_n_padding = get_block_n_padding_for_smem_d(block_n) +def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]: + # Try swizzle first, as it does not waste shared memory + swizzle_mode = get_swizzle_mode(block_n) + block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0 + smem_d = block_m * (block_n + block_n_padding) * 2 smem_a_per_stage = block_m * block_k smem_scales_a_per_stage = block_m * 4 @@ -64,13 +80,17 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: smem_size += num_stages * smem_b_per_stage smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 smem_size += smem_barrier - return smem_size + + # Swizzle and padding are not compatible + assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 + + return smem_size, swizzle_mode, block_n_padding @lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ - Tuple[int, int, int, int, Tuple[int, bool], int]: + Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: if not is_grouped_contiguous: block_ms = (64, 128, 256) else: @@ -109,16 +129,17 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 + best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 stage_candidates = (8, 7, 6, 5, 4, 3) if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3) for num_stages in stage_candidates: - best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) + if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break + assert best_smem_config is not None assert best_num_stages is not None # Decide the number of TMA multicast and whether broadcast on A @@ -142,7 +163,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] assert num_min_sms <= num_sms - return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -192,12 +213,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) + args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0]) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n), + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 4ad321c..3b518c9 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -16,21 +16,23 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto BLOCK_K = 128; constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING}; +constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE}; constexpr auto kNumGroups = {NUM_GROUPS}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated grouped GEMM -using gemm_t = Gemm; +using gemm_t = Gemm; // Launch kernel auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m); auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs); auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m); +auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m); gemm_t::run(out, rhs_scales, grouped_layout, m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, stream, num_sms, smem_size); """ @@ -87,14 +89,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) args = (lhs, lhs_scales, rhs, rhs_scales, out, m_indices, m, num_groups, - torch.cuda.current_stream(), num_sms, smem_size) + torch.cuda.current_stream(), num_sms, smem_config[0]) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n), + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], @@ -165,7 +168,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) # Extra checks for TMA store if num_groups > 1 and m > block_m: @@ -173,11 +176,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] args = (lhs, lhs_scales, rhs, rhs_scales, out, masked_m, m, - torch.cuda.current_stream(), num_sms, smem_size) + torch.cuda.current_stream(), num_sms, smem_config[0]) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n), + 'SWIZZLE_D_MODE': smem_config[1], + 'BLOCK_N_PADDING': smem_config[2], 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': tma_multicast_config[0], diff --git a/indexing/main.cu b/indexing/main.cu index 426f8f5..8e86a30 100644 --- a/indexing/main.cu +++ b/indexing/main.cu @@ -11,18 +11,20 @@ int main() { constexpr int BLOCK_N = 128; constexpr int BLOCK_K = 128; constexpr int BLOCK_N_PADDING = 0; + constexpr int kSwizzleDMode = 0; constexpr int kNumGroups = 1; constexpr int kNumStages = 5; constexpr int kNumTMAMulticast = 1; constexpr bool kIsTMAMulticastOnA = false; - using gemm_t = Gemm; + using gemm_t = Gemm; auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m); auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0)); + auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast(0), m); auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast(0), m); gemm_t::run(nullptr, nullptr, nullptr, m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, nullptr, 132, 0); return 0; }