mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add fp8 ut
This commit is contained in:
@@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
@@ -45,7 +47,9 @@ def flash_mla_with_kvcache(
|
||||
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
descale_q: (batch_size), torch.float. Dequantization scale for the query.
|
||||
descale_k: (batch_size), torch.float. Dequantization scale for the key.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
@@ -63,6 +67,6 @@ def flash_mla_with_kvcache(
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
None, None, None,
|
||||
descale_q, descale_k,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
Reference in New Issue
Block a user