mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-04-16 21:41:40 +00:00
Update docstring
This commit is contained in:
parent
bcb90f2afd
commit
c4c5912b05
@ -16,7 +16,7 @@ def get_mla_metadata(
|
|||||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||||
num_heads_k: num_heads_k.
|
num_heads_k: num_heads_k.
|
||||||
|
|
||||||
Return:
|
Returns:
|
||||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||||
num_splits: (batch_size + 1), dtype torch.int32.
|
num_splits: (batch_size + 1), dtype torch.int32.
|
||||||
"""
|
"""
|
||||||
@ -40,13 +40,13 @@ def flash_mla_with_kvcache(
|
|||||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||||
cache_seqlens: (batch_size), torch.int32.
|
cache_seqlens: (batch_size), torch.int32.
|
||||||
head_dim_v: Head_dim of v.
|
head_dim_v: Head dimension of v.
|
||||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
|
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
|
||||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||||
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||||
causal: bool. Whether to apply causal attention mask.
|
causal: bool. Whether to apply causal attention mask.
|
||||||
|
|
||||||
Return:
|
Returns:
|
||||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user