mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-04-09 15:04:00 +00:00

68 lines · 2.2 KiB · Python

from typing import Optional, Tuple

import torch

import flash_mla_cuda  # compiled CUDA extension built from this repository's kernels
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: The number of key/value heads.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
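
# Usage sketch (illustrative; the concrete numbers are assumptions, not values
# fixed by this file). In a typical MLA decoding step with seq_len_q = 1,
# num_heads_q = 128, and num_heads_k = 1, num_heads_per_head_k is
# 1 * 128 // 1 = 128:
#
#     cache_seqlens = torch.tensor([1024, 2048], dtype=torch.int32, device="cuda")
#     metadata, num_splits = get_mla_metadata(cache_seqlens, 128, 1)
#
# Both outputs depend only on the sequence lengths and head counts, so they can
# be computed once per decoding step and reused across layers.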


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply a causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,  # optional separate v_cache; when None, values are read from k_cache
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
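

# A minimal, self-contained smoke test (a sketch, not part of the library).
# The concrete sizes (head_dim = 576, head_dim_v = 512, page_block_size = 64,
# bfloat16 inputs, one KV head) are assumptions about a typical MLA decoding
# configuration; running it also assumes a CUDA GPU with the flash_mla_cuda
# extension built.
if __name__ == "__main__":
    torch.manual_seed(0)
    device, dtype = "cuda", torch.bfloat16
    batch_size, seq_len_q, num_heads_q, num_heads_k = 2, 1, 128, 1
    head_dim, head_dim_v, page_block_size = 576, 512, 64
    max_seqlen = 1024

    # Every sequence in the batch has max_seqlen cached tokens.
    cache_seqlens = torch.full((batch_size,), max_seqlen, dtype=torch.int32, device=device)

    # A trivial paged layout: consecutive pool blocks are assigned to each sequence.
    blocks_per_seq = max_seqlen // page_block_size
    num_blocks = batch_size * blocks_per_seq
    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).view(
        batch_size, blocks_per_seq
    )

    q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim, dtype=dtype, device=device)
    k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim, dtype=dtype, device=device)

    # Scheduling metadata is computed once and reused for every call that
    # shares these shapes and sequence lengths.
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k
    )
    out, softmax_lse = flash_mla_with_kvcache(
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        causal=True,
    )
    print(out.shape)          # torch.Size([2, 1, 128, 512])
    print(softmax_lse.shape)  # torch.Size([2, 128, 1])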