From fd1e662debb640812123111c8e1f1396bbcf94a0 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Fri, 28 Feb 2025 16:52:30 +0800
Subject: [PATCH] fix mma0

---
 csrc/flash_fwd_mla_kernel.h | 21 +++++++++++++++------
 tests/test_flash_mla.py     | 10 +++++++---
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index b4f3ed7..ad52b3c 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -328,8 +328,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     if (n_block % 2 == 1) {
         // Double buffer for sK
         constexpr int sK_offset = size(sK);
-        tSrK.data() = tSrK.data() + sK_offset / 8;
-        if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+
+        if constexpr (Kernel_traits::Is_FP8) {
+            tSrK.data() = tSrK.data() + sK_offset / 16;
+        } else {
+            tSrK.data() = tSrK.data() + sK_offset / 8;
+            tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        }
     }
 
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@@ -392,8 +397,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
         // Double buffer for sK
        const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
-        tSrK.data() = tSrK.data() + sK_offset / 8;
-        if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        if constexpr (Kernel_traits::Is_FP8) {
+            tSrK.data() = tSrK.data() + sK_offset / 16;
+        } else {
+            tSrK.data() = tSrK.data() + sK_offset / 8;
+            tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        }
     }
 
     cute::copy(softmax.row_max, tRow_maxsRow_max);
@@ -513,9 +522,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         if constexpr (Kernel_traits::Is_FP8) __syncthreads();
         flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO);
 
-        // Double buffer for sK
-        const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
         if constexpr (!Kernel_traits::Is_FP8) {
+            // Double buffer for sK
+            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
             tOrVt.data() = tOrVt.data() + sK_offset / 8;
         }
     }
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 6cfd466..03c9037 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -27,13 +27,17 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     return attn_weight @ value, lse
 
 
-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     RMSE = ((x - y) * (x - y)).mean().sqrt().item()
     cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
     amax_diff = (x - y).abs().max().item()
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
-    assert cos_diff < 1e-5
+
+    if use_fp8:
+        assert cos_diff < 1e-3
+    else:
+        assert cos_diff < 1e-5
 
 
 @torch.inference_mode()
@@ -111,7 +115,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
 
     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")
 
     t = triton.testing.do_bench(flash_mla)
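
Why the FP8 branch divides by 16 where the bf16 branch divides by 8: size(sK) is an element count, while the warpgroup MMA's shared-memory descriptor steps through smem in fixed byte-sized units, so the per-buffer offset has to be rescaled by the element width (2-byte bf16 -> /8, 1-byte fp8 -> /16). The standalone sketch below illustrates only that arithmetic; it assumes a 16-byte descriptor granularity, and the descriptor_offset helper and tile size are illustrative, not taken from the kernel traits.

// Illustrative sketch, not part of the patch: convert an element count into
// shared-memory-descriptor units, assuming the descriptor addresses smem in
// 16-byte units. bf16 (2 B) elements then scale by /8 and fp8 (1 B) elements
// by /16, matching the two branches introduced in the hunks above.
#include <cassert>
#include <cstdio>

constexpr int descriptor_offset(int num_elements, int element_bytes) {
    // element count * bytes per element, expressed in 16-byte units
    return num_elements * element_bytes / 16;
}

int main() {
    const int sK_elems = 64 * 576;  // illustrative stand-in for size(sK)
    assert(descriptor_offset(sK_elems, /*bf16*/ 2) == sK_elems / 8);
    assert(descriptor_offset(sK_elems, /*fp8*/ 1) == sK_elems / 16);
    std::printf("bf16 offset: %d, fp8 offset: %d\n",
                descriptor_offset(sK_elems, 2), descriptor_offset(sK_elems, 1));
    return 0;
}

On the test side, the FP8 path compares a quantized kernel result against a higher-precision reference, which is why cal_diff loosens the cosine-difference bound from 1e-5 to 1e-3 when use_fp8 is passed through.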