mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
Fix LaTeX render error

This commit is contained in:
parent 6cff5a73f5
commit 69d6df34e5

The change wraps inline math in $`...`$ code-span delimiters, presumably so that underscores in expressions such as $m\_new_0$ survive Markdown rendering.
@@ -25,13 +25,13 @@ To fully utilize GPU compute resources, we need to overlap CUDA Core operations
 Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows:
 
 0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
-1. [0] Compute $\vec p_0 = \vec q K_0^\intercal / qk\_scale$.
-2. [1] Compute $\vec p_1 = \vec q K_1^\intercal / qk\_scale$.
-3. [0] Compute $mp_0 = \max(\vec p_0)$, $m\_new_0 = \max(m, mp_0)$, and $scale_0 = \exp(m\_new_0 - m)$. Update $m \gets m\_new_0$.
-4. [0] Perform softmax on $\vec p_0$: $\vec p_0 \gets \exp(\vec p_0 - m\_new_0)$.
+1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$.
+2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$.
+3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m\_new_0 - m)`$. Update $`m \gets m\_new_0`$.
+4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$.
 5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
-6. [1] Compute $mp_1 = \max(\vec p_1)$, $m\_new_1 = \max(m, mp_1)$, and $scale_1 = \exp(m\_new_1 - m)$. Update $m \gets m\_new_1$.
-7. [1] Perform softmax on $\vec p_1$: $\vec p_1 \gets \exp(\vec p_1 - m\_new_1)$.
+6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m\_new_1 - m)`$. Update $`m \gets m\_new_1`$.
+7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$.
 8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
 9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
 10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$.
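To make the transformation concrete, below is a minimal NumPy sketch of steps 0–10 for a single query vector with toy dimensions. Everything not in the excerpt is an assumption, flagged in the comments: the hunk ends at step 10, so the symmetric update of $\vec o_L$ with block 1's contribution and a running denominator $l$ are filled in here as standard FlashAttention-style bookkeeping; `qk_scale` is taken to be $\sqrt{d}$; and the rescale factors are written in the down-scaling direction $\exp(m_{old} - m_{new})$ that a numerically stable online softmax requires (the excerpt's $scale_0 = \exp(m\_new_0 - m)$ reads as the reciprocal of this).

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_v, blk, n_pairs = 16, 8, 4, 3   # toy sizes, not FlashMLA's real shapes
half = d_v // 2
q = rng.standard_normal(d)
K = rng.standard_normal((2 * n_pairs * blk, d))
V = rng.standard_normal((2 * n_pairs * blk, d_v))
qk_scale = np.sqrt(d)                # assumed definition of qk_scale

m, l = -np.inf, 0.0                  # running max; l is the assumed denominator
o_L = np.zeros(half)                 # left output half, owned by warpgroup 0
o_R = np.zeros(half)                 # right output half, owned by warpgroup 1

for i in range(n_pairs):
    a, b = 2 * i * blk, (2 * i + 1) * blk
    K0, V0L, V0R = K[a:a + blk], V[a:a + blk, :half], V[a:a + blk, half:]
    K1, V1L, V1R = K[b:b + blk], V[b:b + blk, :half], V[b:b + blk, half:]

    p0 = q @ K0.T / qk_scale                       # step 1  [wg 0]
    p1 = q @ K1.T / qk_scale                       # step 2  [wg 1]

    m_new0 = max(m, p0.max())                      # step 3  [wg 0]
    scale0 = np.exp(m - m_new0)                    # down-scale old accumulators
    m = m_new0
    p0 = np.exp(p0 - m_new0)                       # step 4  [wg 0]
    o_L = o_L * scale0 + p0 @ V0L                  # step 5  [wg 0]

    m_new1 = max(m, p1.max())                      # step 6  [wg 1]
    scale1 = np.exp(m - m_new1)
    m = m_new1
    p1 = np.exp(p1 - m_new1)                       # step 7  [wg 1]
    o_R = o_R * (scale0 * scale1) + p1 @ V1R       # step 8  [wg 1]

    p0 = p0 * scale1                               # step 9  [wg 0]
    o_R = o_R + p0 @ V0R                           # step 10 [wg 1]

    # Counterpart steps that fall outside the quoted hunk (assumed):
    o_L = o_L * scale1 + p1 @ V1L                  # o_L gets block 1's share
    l = l * scale0 * scale1 + p0.sum() + p1.sum()  # denominator bookkeeping

out = np.concatenate([o_L, o_R]) / l

# Check against a plain softmax-attention reference.
s = q @ K.T / qk_scale
w = np.exp(s - s.max())
w /= w.sum()
assert np.allclose(out, w @ V)
```

The shape of the transformation shows in steps 8–10: block 0's contribution to $\vec o_R$ is deferred and the two rescales are folded into the single factor $scale_0 \cdot scale_1$, so each warpgroup only ever writes its own half of the output, which is what lets the two warpgroups' work proceed in the overlapped fashion the surrounding text describes.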