Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
20 lines · 515 B · Python
import torch


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]: ...


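# The two tensors returned by get_mla_metadata (tile_scheduler_metadata and
# num_splits) are consumed directly by fwd_kvcache_mla below. The metadata
# depends only on cache_seqlens and the head counts, so it is typically
# computed once per decode step and reused across layers.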
def fwd_kvcache_mla(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor | None,
    head_dim_v: int,
    cache_seqlens: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    causal: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ...
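The two entry points are designed to be used together during paged-KV decoding: get_mla_metadata turns the per-sequence cache lengths and the query-head/KV-head ratio into scheduling tensors that fwd_kvcache_mla then consumes. The sketch below shows one plausible calling sequence against this stub. The module import path, tensor shapes, dtypes, page size, and head dimensions are illustrative assumptions (modeled on DeepSeek-style MLA decoding), not values fixed by this file, and in the released package these calls are normally reached through a higher-level wrapper such as flash_mla_with_kvcache rather than invoked directly.

import torch

# Assumed module name for this stub; adjust to the actual extension module.
from flash_mla_cuda import get_mla_metadata, fwd_kvcache_mla

# Illustrative problem sizes (assumptions, DeepSeek-style MLA decoding).
batch_size, seq_len_q = 4, 1          # one new query token per sequence
num_heads_q, num_heads_k = 128, 1     # MLA uses a single latent KV head
head_dim, head_dim_v = 576, 512       # K head dim vs. V head dim
page_block_size, num_blocks = 64, 1024
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                device=device, dtype=dtype)
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      device=device, dtype=dtype)

# Per-sequence KV lengths and the paged-KV block table (int32 indices into k_cache).
cache_seqlens = torch.full((batch_size,), 1024, device=device, dtype=torch.int32)
max_blocks_per_seq = 1024 // page_block_size
block_table = torch.arange(batch_size * max_blocks_per_seq,
                           device=device, dtype=torch.int32)
block_table = block_table.view(batch_size, max_blocks_per_seq)

# Step 1: precompute tile-scheduler metadata once for this decode step.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,  # num_heads_per_head_k
    num_heads_k,
)

# Step 2: run the attention kernel. v_cache is None here; this sketch assumes
# the value vectors are read from the leading head_dim_v channels of k_cache,
# i.e. the shared MLA latent cache.
out, softmax_lse = fwd_kvcache_mla(
    q,
    k_cache,
    None,                 # v_cache
    head_dim_v,
    cache_seqlens,
    block_table,
    head_dim ** -0.5,     # softmax_scale: a typical 1/sqrt(head_dim) choice
    True,                 # causal
    tile_scheduler_metadata,
    num_splits,
)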