add fp8 ut

This commit is contained in:
chenhongmin.will 2025-02-25 23:29:18 +08:00
parent dfe8ffc75a
commit 870418802a
3 changed files with 28 additions and 7 deletions

View File

@ -70,8 +70,7 @@ mha_fwd_kvcache_mla(
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits, // batch_size + 1
c10::optional<const at::Tensor> &descale_q, // batch_size
c10::optional<const at::Tensor> &descale_k, // batch_size
c10::optional<const at::Tensor> &descale_v // batch_size
c10::optional<const at::Tensor> &descale_k // batch_size
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;

View File

@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
@ -45,7 +47,9 @@ def flash_mla_with_kvcache(
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float. dequant scale for query
descale_k: (batch_size), torch.float. dequant scale for key
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
@ -63,6 +67,6 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
None, None, None,
descale_q, descale_k,
)
return out, softmax_lse

View File

@ -37,7 +37,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
@ -59,11 +59,28 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
descale_q, descale_k = None, None
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((b), dtype=torch.float32)
descale_k = torch.ones((b), dtype=torch.float32)
q_fp8 = q.to(fp8_dtype)
blocked_k_fp8 = blocked_k.to(fp8_dtype)
blocked_v_fp8 = blocked_v.to(fp8_dtype)
q = q_fp8.to(q.dtype)
blocked_k = blocked_k_fp8.to(blocked_k.dtype)
blocked_v = blocked_v_fp8.to(blocked_v.dtype)
def flash_mla():
q_ = q_fp8 if use_fp8 else q
blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k
return flash_mla_with_kvcache(
q, blocked_k, block_table, cache_seqlens, dv,
q_, blocked_k_, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=causal,
descale_q=descale_q, descale_k=descale_k,
)
def ref_mla():
@ -107,10 +124,11 @@ if __name__ == "__main__":
h_kv = 1
d, dv = 576, 512
causal = True
use_fp8 = False
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)