mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Minor fix to the docs to correct FlashAttention-3's paper link and typos (#73)
Thank you for open source FlashMLA! Just read the write up and very amazing work! Found some very minor mistakes regarding to typos, and the link to the FlashAttention-3 paper is wrong as that is the original FlashAttention paper, so I just send the PR here. Thanks again! Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
parent
a9444cd67d
commit
6cff5a73f5
@ -18,7 +18,7 @@ According to [the overview of DeepSeek's Online Inference System](https://github
|
||||
|
||||
To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule."
|
||||
|
||||
[FlashAttention-3's paper](https://arxiv.org/abs/2205.14135) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possiblility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation.
|
||||
[FlashAttention-3's paper](https://arxiv.org/abs/2407.08608) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possibility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation.
|
||||
|
||||
(You might pause here to ponder - perhaps you can find a better solution than ours!)
|
||||
|
||||
@ -62,7 +62,7 @@ Other performance improvements include:
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
FlashMLA's algorithm and scheduling is inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work.
|
||||
FlashMLA's algorithm and scheduling are inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work.
|
||||
|
||||
## Citation
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user