add fp8 ut

This commit is contained in:
chenhongmin.will
2025-02-25 23:29:18 +08:00
parent dfe8ffc75a
commit 870418802a
3 changed files with 28 additions and 7 deletions

View File

@@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
@@ -45,7 +47,9 @@ def flash_mla_with_kvcache(
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float. Dequantization scale for the query. Optional; defaults to None.
descale_k: (batch_size), torch.float. Dequantization scale for the key. Optional; defaults to None.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
@@ -63,6 +67,6 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
None, None, None,
descale_q, descale_k,
)
return out, softmax_lse