Initial commit

This commit is contained in:
Chenggang Zhao 2025-02-25 22:52:41 +08:00
commit a6d97a1c1b
26 changed files with 3274 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
cmake-build-*
.idea
.DS_Store
build
dist
*.egg-info
*.pyc

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

213
README.md Normal file
View File

@ -0,0 +1,213 @@
# DeepGEMM
DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.
Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function comprising around **~300 lines of code**. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques.
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
## Performance
We test all shapes potentially used in DeepSeek-V3/R1 inference (including both prefilling and decoding, but without tensor parallelism) on H800 with NVCC 12.8. All speedup metrics are calculated in comparison to our internally and carefully optimized implementation based on CUTLASS 3.6.
DeepGEMM does not behavior very well on some shapes, optimization PRs are welcomed if you are interested.
### Normal GEMMs for dense models
| M | N | K | Computation | Memory bandwidth | Speedup |
|:----:|:-----:|:-----:|:-----------:|:----------------:|:-------:|
| 64 | 2112 | 7168 | 206 TFLOPS | 1688 GB/s | 2.7x |
| 64 | 24576 | 1536 | 289 TFLOPS | 2455 GB/s | 1.7x |
| 64 | 32768 | 512 | 219 TFLOPS | 2143 GB/s | 1.8x |
| 64 | 7168 | 16384 | 336 TFLOPS | 2668 GB/s | 1.4x |
| 64 | 4096 | 7168 | 287 TFLOPS | 2320 GB/s | 1.4x |
| 64 | 7168 | 2048 | 295 TFLOPS | 2470 GB/s | 1.7x |
| 128 | 2112 | 7168 | 352 TFLOPS | 1509 GB/s | 2.4x |
| 128 | 24576 | 1536 | 535 TFLOPS | 2448 GB/s | 1.6x |
| 128 | 32768 | 512 | 358 TFLOPS | 2103 GB/s | 1.5x |
| 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x |
| 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x |
| 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x |
| 4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x |
| 4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x |
| 4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x |
| 4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x |
| 4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x |
| 4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x |
### Grouped GEMMs for MoE models (contiguous layout)
| #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup |
|:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:|
| 4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x |
| 4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x |
| 8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x |
| 8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x |
### Grouped GEMMs for MoE models (masked layout)
| #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup |
|:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:|
| 1 | 1024 | 4096 | 7168 | 1233 TFLOPS | 924 GB/s | 1.2x |
| 1 | 1024 | 7168 | 2048 | 925 TFLOPS | 968 GB/s | 1.2x |
| 2 | 512 | 4096 | 7168 | 1040 TFLOPS | 1288 GB/s | 1.2x |
| 2 | 512 | 7168 | 2048 | 916 TFLOPS | 1405 GB/s | 1.2x |
| 4 | 256 | 4096 | 7168 | 932 TFLOPS | 2064 GB/s | 1.1x |
| 4 | 256 | 7168 | 2048 | 815 TFLOPS | 2047 GB/s | 1.2x |
## Quick start
### Requirements
- Hopper architecture GPUs, `sm_90a` must be supported
- Python 3.8 or above
- CUDA 12.3 or above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.1 or above
- CUTLASS 3.6 or above (could be cloned by Git submodule)
### Development
```bash
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop
# Test JIT compilation
python tests/test_jit.py
# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
```
### Installation
```bash
python setup.py install
```
Then, import `deep_gemm` in your Python project, and enjoy!
## Interfaces
#### Notices
This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves.
#### Normal dense GEMMs (non-grouped)
To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation.
#### Grouped GEMMs (contiguous layout)
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape.
For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`).
For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation.
#### Grouped GEMMs (masked layout)
During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.
Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
#### Utilities
The library provides some utility functions besides the above kernels:
- `deep_gemm.set_num_sms`: set the maximum SM count to use
- `deep_gemm.get_num_sms`: get the current SM maximum count
- `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor
The library also provides some environment variables, which may be useful:
- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default
- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
- `DG_JIT_PRINT_NVCC_COMMAND`: 0 or 1, print NVCC compilation command
- `DG_JIT_DEBUG`: 0 or 1, print more debugging information
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
## Optimizations
We indicate the techniques excluded from CUTLASS with 🐳.
#### Persistent warp-specialization
Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below:
![design](figures/design.png)
#### Hopper TMA features
The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for:
- TMA load for LHS, LHS scaling factors, and RHS matrices
- TMA store for the output matrix
- TMA multicast (exclusive to the LHS matrix)
- TMA descriptor prefetching
#### Common detail optimizations
- Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction
- [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups
- Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳
#### A unified and optimized block scheduler
- [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels
- [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse
#### Fully JIT design 🐳
DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages:
- GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants
- Saving registers
- Compilers may do more optimizations
- Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size
- But without auto-tuning, the optimal one is deterministically selected
- Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities
- Very important for small shapes
- Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details
Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler.
#### Unaligned block sizes 🐳
For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains.
#### FFMA SASS interleaving 🐳
We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern.
After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/master/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work).
To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions.
## Acknowledgement
DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers!
## License
This code repository is released under [the MIT License](LICENSE).
## Citation
```bibtex
@misc{deepgemm2025,
title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling},
author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}},
}
```

13
deep_gemm/__init__.py Normal file
View File

@ -0,0 +1,13 @@
import torch
from . import jit
from .jit_kernels import (
gemm_fp8_fp8_bf16_nt,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
cell_div,
set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout
)
from .utils import bench, bench_kineto, calc_diff

View File

@ -0,0 +1,444 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
namespace deep_gemm {
enum class Layout {
RowMajor,
ColMajor
};
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast,
GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b;
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
}
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
// Fill barriers
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(kNumIterations - 1, NotDivisibleK{});
}
};
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A with broadcasting
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}
};
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(s);
// Promote with scales
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Write back to shared memory using STSM
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast,
GemmType kGemmType>
class Gemm {
private:
using Barrier = cuda::barrier<cuda::thread_scope_block>;
public:
Gemm() = default;
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
// Launch
auto status = cudaLaunchKernelEx(&config, kernel,
gmem_d, scales_b, grouped_layout,
shape_m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
}
template <typename T>
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
return make_2d_tma_desc(global_address, Layout::ColMajor,
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor,
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,
uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
};
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@ -0,0 +1,885 @@
#pragma once
#include <cuda.h>
#include "utils.cuh"
namespace deep_gemm {
struct SM90_64x16x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
" %8,"
" %9,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 16;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x24x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %14, 0;\n"
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11},"
" %12,"
" %13,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 24;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x32x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15},"
" %16,"
" %17,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 32;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x40x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 40;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x48x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23},"
" %24,"
" %25,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 48;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x56x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %30, 0;\n"
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27}, "
" %28,"
" %29,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 56;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x64x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31}, "
" %32,"
" %33,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 64;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x72x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %38, 0;\n"
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35}, "
" %36,"
" %37,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 72;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x80x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %42, 0;\n"
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39}, "
" %40,"
" %41,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 80;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x88x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %46, 0;\n"
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43}, "
" %44,"
" %45,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 88;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x96x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47}, "
" %48,"
" %49,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 96;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x104x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %54, 0;\n"
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51}, "
" %52,"
" %53,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 104;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x112x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %58, 0;\n"
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55}, "
" %56,"
" %57,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 112;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x120x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %62, 0;\n"
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59}, "
" %60,"
" %61,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 120;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x128x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63}, "
" %64,"
" %65,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 128;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x192x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %98, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71, "
" %72, %73, %74, %75, %76, %77, %78, %79, "
" %80, %81, %82, %83, %84, %85, %86, %87, "
" %88, %89, %90, %91, %92, %93, %94, %95}, "
" %96,"
" %97,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 192;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
}
};
template <typename dtype_t>
struct SM90_U32x4_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};
__device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
__device__ void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
__device__ void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
__forceinline__ __device__ uint32_t get_lane_id() {
uint32_t lane_id;
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
uint32_t ret;
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
int4 ret;
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
float ret;
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
template <int N>
__device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
union GmmaDescriptor {
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
uint16_t reg16_[4];
struct {
uint16_t start_address_: 14, : 2;
uint16_t leading_byte_offset_: 14, : 2;
uint16_t stride_byte_offset_: 14, : 2;
uint8_t : 1, base_offset_: 3, : 4;
uint8_t : 6, layout_type_: 2;
} bitfield;
// Decay to an `uint64_t`
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
int leading_byte_offset = 0,
int stride_byte_offset = 1024) {
GmmaDescriptor desc;
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
desc.bitfield.start_address_ = uint_ptr >> 4;
desc.bitfield.layout_type_ = layout_type;
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
desc.bitfield.base_offset_ = 0;
return desc;
}
template <int N>
struct FP8MMASelector {
static constexpr auto select_type() {
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
}
using type = decltype(select_type());
};
} // namespace deep_gemm

View File

@ -0,0 +1,103 @@
#include "utils.cuh"
namespace deep_gemm {
enum class GemmType {
Normal,
GroupedContiguous,
GroupedMasked
};
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
uint32_t num_aligned_m_blocks;
// For normal GEMM
// Maybe not used in the masked grouped GEMM
uint32_t num_blocks;
// For grouped GEMM
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
int* grouped_layout = nullptr) {
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
} else if (kGemmType == GemmType::GroupedContiguous) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::GroupedMasked) {
curr_group_idx = curr_cumsum = 0;
this->grouped_layout = grouped_layout;
}
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}
template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
return offset * shape_dim + block_idx * block_size;
} else if (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
if constexpr (kGemmType == GemmType::GroupedMasked) {
uint32_t num_m_blocks;
while (true) {
// End of the task
if (curr_group_idx == kNumGroups)
return false;
// Within current group
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
break;
// Move to check the next group
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@ -0,0 +1,96 @@
#pragma once
#include <cassert>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>
#include "utils.cuh"
namespace deep_gemm {
template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
if constexpr (std::is_same<T, uint8_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, uint16_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
} else if constexpr (std::is_same<T, uint32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
} else if constexpr (std::is_same<T, uint64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
} else if constexpr (std::is_same<T, int32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT32;
} else if constexpr (std::is_same<T, int64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT64;
} else if constexpr (std::is_same<T, __half>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if constexpr (std::is_same<T, float>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if constexpr (std::is_same<T, double>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
}
}
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;
#if CUDA_VERSION >= 12050
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
cudaEnableDefault, &driver_status);
#else
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
cudaEnableDefault, &driver_status);
#endif
if (driver_status != cudaDriverEntryPointSuccess)
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}
template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
uint64_t global_stride[rank - 1] = {stride_in_bytes};
uint32_t elem_strides[rank] = {1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}
template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
int32_t const& crd_0, int32_t const& crd_1) {
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
if constexpr (kNumTMAMulticast == 1) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
} else if (cute::block_rank_in_cluster() == 0) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
}
}
} // namespace deep_gemm

View File

@ -0,0 +1,48 @@
#pragma once
#include <exception>
#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
class AssertionException : public std::exception {
private:
std::string message{};
public:
explicit AssertionException(const std::string& message) : message(message) {}
const char *what() const noexcept override { return message.c_str(); }
};
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", \
__FILE__, __LINE__, #cond); \
throw AssertionException("Assertion failed: " #cond); \
} \
} while (0)
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
template <typename T>
__device__ __host__ constexpr T cell_div(T a, T b) {
return (a + b - 1) / b;
}

View File

@ -0,0 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .runtime import Runtime

150
deep_gemm/jit/compiler.py Normal file
View File

@ -0,0 +1,150 @@
import hashlib
import functools
import os
import re
import subprocess
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
def hash_to_hex(s: str) -> str:
md5 = hashlib.md5()
md5.update(s.encode('utf-8'))
return md5.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
# Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm'
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f:
md5.update(f.read())
# Update `interleave_ffma.py`
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
md5.update(f.read())
return md5.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
paths = []
if os.getenv('DG_NVCC_COMPILER'):
paths.append(os.getenv('DG_NVCC_COMPILER'))
paths.append(f'{CUDA_HOME}/bin/nvcc')
# Try to find the first available NVCC compiler
least_version_required = '12.3'
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(os.popen(f'{path} --version').read())
version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
return path, version
raise RuntimeError('Cannot find any available NVCC compiler')
@functools.lru_cache(maxsize=None)
def get_default_user_dir():
if 'DG_CACHE_DIR' in os.environ:
path = os.getenv('DG_CACHE_DIR')
os.makedirs(path, exist_ok=True)
return path
return os.path.expanduser('~') + '/.deep_gemm'
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
return f'{get_default_user_dir()}/tmp'
@functools.lru_cache(maxsize=None)
def get_cache_dir():
return f'{get_default_user_dir()}/cache'
def make_tmp_dir():
tmp_dir = get_tmp_dir()
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
def put(path, data, is_binary=False):
# Write and do POSIX atomic replace
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
f.write(data)
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Compiler flags
nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=177,174,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
# Compile into a temporary SO file
so_path = f'{path}/kernel.so'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}'
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]

View File

@ -0,0 +1,138 @@
import argparse
import mmap
import os
import re
import subprocess
from torch.utils.cpp_extension import CUDA_HOME
def run_cuobjdump(file_path):
command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
return result.stdout
def extract_ffma(sass):
lines = sass.splitlines()
collected = []
current = []
arch_name, func_name = 'N/A', 'N/A'
skip_next_line = False
for line in lines:
if 'code for' in line:
arch_name = line.lstrip().lstrip('code for ').rstrip()
elif 'Function :' in line:
func_name = line.lstrip().lstrip('Function :').rstrip()
elif 'FFMA' in line:
current.append(line)
skip_next_line = True
elif skip_next_line:
current.append(line)
skip_next_line = False
else:
if len(current) >= 16:
assert len(current) % 2 == 0
collected.append((f'{arch_name}::{func_name}', current))
current = []
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f"Found {len(collected)} FFMA segments")
return collected
def extract_hex_from_line(line):
match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
assert match
return int(match.group(1), 16)
def validate(m, offset, le_bytes, num_lines):
assert len(le_bytes) == num_lines // 2
assert m[offset:offset + 16] == le_bytes[0]
for i in range(1, num_lines // 2):
if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]:
return False
return True
def parse_registers(line):
import re
line = re.sub(r'/\*.*?\*/', '', line)
line = line.replace(';', '')
tokens = line.strip().split(',')
registers = []
for token in tokens:
token = token.strip()
words = token.split()
for word in words:
if word.startswith('R'):
reg = word.split('.')[0]
registers.append(reg)
return registers
def modify_segment(m, name, ffma_lines):
num_lines = len(ffma_lines)
assert num_lines % 2 == 0
le_bytes, new_le_bytes = [], []
reused_list = []
dst_reg_set = set()
last_reused, last_dst_reg = False, ''
num_changed = 0
for i in range(num_lines // 2):
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
reused = (high_hex & 0x0800000000000000) != 0
if reused:
is_first_occurred = dst_reg not in dst_reg_set
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
# Modify the `reuse` and `yield` bits
assert high_hex & 0x0800200000000000, f"{hex(high_hex)}"
high_hex ^= 0x0800200000000000
reused = False
num_changed += 1
else:
reused_list.append(i)
dst_reg_set.add(dst_reg)
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}")
# Find the offset
offsets = []
offset = m.find(le_bytes[0])
while offset != -1:
offsets.append(offset)
offset = m.find(le_bytes[0], offset + 1)
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
# Replace with `new_le_bytes`
for offset in offsets:
for i in range(num_lines // 2):
m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
def process(path):
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f'Processing {path}')
output = run_cuobjdump(path)
segments = extract_ffma(output)
with open(path, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
for segment in segments:
modify_segment(mm, *segment)
mm.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
parser.add_argument('--so', help='Path to the SO file')
args = parser.parse_args()
process(args.so)

66
deep_gemm/jit/runtime.py Normal file
View File

@ -0,0 +1,66 @@
import ctypes
import os
import torch
from typing import Optional
from .template import map_ctype
class Runtime:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.args = None
assert self.is_path_valid(self.path)
@staticmethod
def is_path_valid(path: str) -> bool:
# Exists and is a directory
if not os.path.exists(path) or not os.path.isdir(path):
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.args', 'kernel.so']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
self.args = eval(f.read())
# Check args and launch
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
cargs.append(map_ctype(arg))
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
class RuntimeCache:
def __init__(self) -> None:
self.cache = {}
def __getitem__(self, path: str) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path)
self.cache[path] = runtime
return runtime
return None
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime

93
deep_gemm/jit/template.py Normal file
View File

@ -0,0 +1,93 @@
import copy
import ctypes
import os
import torch
from typing import Any, Iterable, Dict, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
if isinstance(value, torch.Tensor):
return ctype(value.data_ptr())
if isinstance(value, torch.cuda.Stream):
return ctype(value.cuda_stream)
return ctype(value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
new_template = new_template.replace(f'{{{key}}}', f'{value}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
# Debug print
if os.getenv('DG_JIT_DEBUG', None):
print(f'Generated code:\n{code}')
return code

View File

@ -0,0 +1,10 @@
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .utils import (
cell_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout
)

View File

@ -0,0 +1,171 @@
import torch
from typing import Tuple
from .tuner import jit_tuner
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
return True
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
smem_d = block_m * block_n * 2
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b = cell_div(k, block_k) * 4
smem_barrier = num_stages * 8 * 2
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
smem_size += smem_barrier
return smem_size
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
if not is_grouped_contiguous:
# TODO: for some cases, smaller M block is better, add them into tuning space
block_ms = (64 if m <= 64 else 128, )
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8))
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
# Decide block sizes by waves
best_block_m, best_block_n = None, None
for block_m in block_ms:
for block_n in block_ns:
success = False
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
success = True
elif num_waves == best_num_waves:
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
best_util = get_last_wave_util(best_block_m, best_block_n)
success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m))
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
assert best_block_m is not None and best_block_n is not None
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
if best_smem_size <= sm90_capacity:
best_num_stages = num_stages
break
assert best_num_stages is not None
# Decide the number of TMA multicast
best_num_tma_multicast = 1
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
best_num_tma_multicast = 2
return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, k / 128]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[n / 128, k / 128]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
n, k_ = rhs.shape
m_, n_ = out.shape
assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
# Run the kernel
runtime(*args)

View File

@ -0,0 +1,182 @@
import torch
from typing import Tuple
from .gemm import get_best_configs
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
// Make a templated grouped GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, m_indices: torch.Tensor) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, k / 128]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, n / 128, k / 128]`.
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
`m_indices[i]` records the group which the j-th row of the LHS belong to,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` in every-m-alignment-block must also be the same.
`-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
num_groups, n, k_ = rhs.shape
m_, n_ = out.shape
m__ = m_indices.numel()
# Type and shape checks
assert m == m_ == m__ and k == k_ and n == n_
assert lhs_scales.shape == (m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert m_indices.dtype == torch.int32
assert lhs.is_contiguous() and rhs.is_contiguous()
assert out.is_contiguous() and m_indices.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
# Run the kernel
runtime(*args)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
this function will do a transposing with a set of slow PyTorch operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, k / 128]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, n / 128, k / 128]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
in the i-th group.
expected_m: a value hint (which is a value on CPU) for the M expectation of each batch,
correctly setting this value may lead to better performance.
"""
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
num_groups, m, k = lhs.shape
num_groups_, n, k_ = rhs.shape
num_groups__, m_, n_ = out.shape
num_groups___ = masked_m.numel()
# Type and shape checks
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16
assert masked_m.dtype == torch.int32
assert lhs.is_contiguous() and rhs.is_contiguous()
assert out.is_contiguous() and masked_m.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_size)
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
# Run the kernel
runtime(*args)

View File

@ -0,0 +1,81 @@
import copy
import os
import torch
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f'{keys}')
if signature in self.tuned:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT kernel {name} with keys {keys}')
return self.tuned[signature]
if os.getenv('DG_JIT_DEBUG', None):
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
assert args is not None
space = (dict(), ) if len(space) == 0 else space
kernels = []
for tuned_keys in space:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))
# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(*args)
if return_code != 0:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(*args) == 0
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
else:
elapsed_time = 0
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = best_runtime
return best_runtime
jit_tuner = JITTuner()

View File

@ -0,0 +1,105 @@
import torch
_num_sms = None
def set_num_sms(num_sms: int) -> None:
"""
Set the maximum SM count for all GEMM kernels to use.
Arguments:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
_num_sms = num_sms
def get_num_sms() -> int:
"""
Get the current maximum limit of SM count for all GEMM kernels to use.
If the count is never specified, the function will return the number of device SMs.
Returns:
Current maximum limit of SM count for all GEMM kernels to use.
"""
global _num_sms
if _num_sms is None:
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
return _num_sms
def cell_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def get_m_alignment_for_contiguous_layout():
"""
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
with GEMM block shape.
Returns:
Group-level alignment requirement for grouped contiguous layout, which is always 128.
"""
return 128
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return cell_div(x, alignment) * alignment
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b, m, n = x.shape
aligned_m = get_tma_aligned_size(m, x.element_size())
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
return aligned_x.squeeze(0) if remove_dim else aligned_x

154
deep_gemm/utils.py Normal file
View File

@ -0,0 +1,154 @@
import os
import sys
import torch
import torch.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
cache.zero_()
# Warmup
for _ in range(num_warmups):
fn()
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
x @ y
# Testing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_tests):
fn()
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_tests
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
# Conflict with Nsight Systems
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
# For some auto-tuning kernels with prints
fn()
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
if flush_l2:
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
fn()
if not using_nsys:
profiler.step()
# Return 1 if using Nsight Systems
if using_nsys:
return 1
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
profiler.export_chrome_trace(trace_path)
# Return average kernel times
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def count_bytes(tensors):
total = 0
for t in tensors:
if isinstance(t, tuple):
total += count_bytes(t)
else:
total += t.numel() * t.element_size()
return total

BIN
figures/design.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 558 KiB

65
setup.py Normal file
View File

@ -0,0 +1,65 @@
import os
import setuptools
import shutil
import subprocess
from setuptools.command.develop import develop
from setuptools.command.install import install
current_dir = os.path.dirname(os.path.realpath(__file__))
jit_include_dirs = ('deep_gemm/include/deep_gemm', )
third_party_include_dirs = ('third-party/cutlass/include/cute', 'third-party/cutlass/include/cutlass')
class PostDevelopCommand(develop):
def run(self):
develop.run(self)
self.make_jit_include_symlinks()
@staticmethod
def make_jit_include_symlinks():
# Make symbolic links of third-party include directories
for d in third_party_include_dirs:
dirname = d.split('/')[-1]
src_dir = f'{current_dir}/{d}'
dst_dir = f'{current_dir}/deep_gemm/include/{dirname}'
assert os.path.exists(src_dir)
if os.path.exists(dst_dir):
assert os.path.islink(dst_dir)
os.unlink(dst_dir)
os.symlink(src_dir, dst_dir, target_is_directory=True)
class PostInstallCommand(install):
def run(self):
install.run(self)
self.copy_jit_includes()
def copy_jit_includes(self):
# Copy include directories needed by JIT
shutil.rmtree(f'{self.build_lib}/deep_gemm/include', ignore_errors=True)
os.makedirs(f'{self.build_lib}/deep_gemm/include', exist_ok=False)
for d in jit_include_dirs + third_party_include_dirs:
src_dir = f'{current_dir}/{d}'
dst_dir = f'{self.build_lib}/deep_gemm/include/{d.split("/")[-1]}'
assert os.path.exists(src_dir)
shutil.copytree(src_dir, dst_dir)
if __name__ == '__main__':
# noinspection PyBroadException
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except:
revision = ''
# noinspection PyTypeChecker
setuptools.setup(
name='deep_gemm',
version='1.0.0' + revision,
packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'],
cmdclass={
'develop': PostDevelopCommand,
'install': PostInstallCommand
}
)

158
tests/test_core.py Normal file
View File

@ -0,0 +1,158 @@
import random
import torch
from typing import Tuple
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
def construct(m: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
# For non-masked input, we must merge the group and M dims
if not is_masked:
x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def test_gemm() -> None:
print('Testing GEMM:')
for m in (64, 128, 4096):
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
# Construct new tensors every time to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_contiguous() -> None:
print('Testing grouped contiguous GEMM:')
for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
# TODO: make a stronger test
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
# Construct new tensors every time to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_masked() -> None:
print('Testing grouped masked GEMM:')
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
for k, n in ((7168, 4096), (2048, 7168), ):
# Test correctness
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
for i in range(10):
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = random.choice(masked_m_candidates)
expected_m = int(masked_m.float().mean()) + 1
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
for j in range(num_groups):
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
# Construct new tensors every time to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
# Test performance with fixed shapes
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.manual_seed(0)
random.seed(0)
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()

64
tests/test_jit.py Normal file
View File

@ -0,0 +1,64 @@
import os
import torch
from typing import Any
from deep_gemm import jit
class Capture:
def __init__(self) -> None:
self.read_fd = None
self.write_fd = None
self.saved_stdout = None
self.captured = None
def __enter__(self) -> Any:
self.read_fd, self.write_fd = os.pipe()
self.saved_stdout = os.dup(1)
os.dup2(self.write_fd, 1)
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
os.dup2(self.saved_stdout, 1)
os.close(self.write_fd)
with os.fdopen(self.read_fd, 'r') as f:
self.captured = f.read()
def capture(self) -> str:
return self.captured
if __name__ == '__main__':
# Runtime
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
# Templates
print('Generated code:')
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
body = "\n"
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
body += 'std::cout << enable_double_streams << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
code = jit.generate((), args, body)
print(code)
# Build
print('Building ...')
func = jit.build('test_func', args, code)
# Test correctness
print('Running ...')
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
with Capture() as capture:
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
output = capture.capture()
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
assert output == ref_output, f'{output=}, {ref_output=}'
print('JIT test passed')

1
third-party/cutlass vendored Submodule

@ -0,0 +1 @@
Subproject commit eefa171318b79cbe2e78514d4cce5cd0fe919d0c