This commit is contained in:
chenhongmin.will 2025-02-28 16:52:30 +08:00
parent 061af5fc56
commit fd1e662deb
2 changed files with 22 additions and 9 deletions

View File

@ -328,8 +328,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (Kernel_traits::Is_FP8) {
tSrK.data() = tSrK.data() + sK_offset / 16;
} else {
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
}
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@ -392,8 +397,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (Kernel_traits::Is_FP8) {
tSrK.data() = tSrK.data() + sK_offset / 16;
} else {
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
}
cute::copy(softmax.row_max, tRow_maxsRow_max);
@ -513,9 +522,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
if constexpr (Kernel_traits::Is_FP8) __syncthreads();
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
if constexpr (!Kernel_traits::Is_FP8) {
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
}

View File

@ -27,13 +27,17 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False) -> None:
    """Assert that two tensors are numerically close via a cosine-difference metric.

    Args:
        x: first tensor to compare (e.g. kernel output).
        y: second tensor to compare (e.g. reference output).
        name: label for this comparison; only used in the (commented-out)
            debug print.
        use_fp8: when True, relax the cosine-difference tolerance to 1e-3 to
            allow for FP8 quantization error; otherwise require 1e-5.

    Raises:
        AssertionError: if the cosine difference exceeds the tolerance.
    """
    # Compare in float64 so the comparison itself adds no rounding error.
    x, y = x.double(), y.double()
    # RMSE and amax_diff are diagnostics for the debug print below.
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    # 1 - (2*<x,y> / (|x|^2 + |y|^2)): 0 when x == y. The 1e-12 floor guards
    # against division by zero when both tensors are (near-)zero.
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    amax_diff = (x - y).abs().max().item()
    # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
    if use_fp8:
        assert cos_diff < 1e-3
    else:
        assert cos_diff < 1e-5
@torch.inference_mode()
@ -111,7 +115,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla)