# Type stubs for the compiled flash_mla_cuda extension module
# (FlashMLA/flash_mla/flash_mla_cuda.pyi).

import torch
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute scheduling metadata for the MLA decode kernel.

    Type stub only — the behavior is provided by the compiled CUDA
    extension this file annotates.

    Args:
        cache_seqlens: Per-sequence KV-cache lengths (presumably an
            integer tensor of shape ``(batch,)`` — confirm against the
            extension).
        num_heads_per_head_k: Number of query heads served by each KV
            head.
        num_heads_k: Number of KV heads.

    Returns:
        A pair of tensors — presumably consumed as the
        ``tile_scheduler_metadata`` and ``num_splits`` arguments of
        ``fwd_kvcache_mla`` (verify against the extension).
    """
    ...
def fwd_kvcache_mla(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor | None,
head_dim_v: int,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
softmax_scale: float,
causal: bool,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ...