diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 2f3aa46..b2922af 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -16,7 +16,7 @@ def get_mla_metadata(
         num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
         num_heads_k: num_heads_k.
 
-    Return:
+    Returns:
         tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
         num_splits: (batch_size + 1), dtype torch.int32.
     """
@@ -40,13 +40,13 @@ def flash_mla_with_kvcache(
         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
         block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
         cache_seqlens: (batch_size), torch.int32.
-        head_dim_v: Head_dim of v.
-        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
-        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
-        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        head_dim_v: Head dimension of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
+        softmax_scale: float. The scale applied to QK^T before the softmax. Defaults to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
 
-    Return:
+    Returns:
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
     """
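
For context, a minimal usage sketch of the two functions whose docstrings this diff touches: get_mla_metadata produces the tile-scheduler metadata and split counts that flash_mla_with_kvcache consumes. The tensor shapes follow the docstrings above; the concrete head dimensions, page block size, sequence lengths, and dtype below are illustrative assumptions, not values fixed by the interface, and running it requires a GPU supported by FlashMLA.

import torch

from flash_mla import flash_mla_with_kvcache, get_mla_metadata

# Assumed sizes for illustration only.
batch_size, seq_len_q = 2, 1
num_heads_q, num_heads_k = 128, 1           # a single latent KV head, as in MLA
head_dim, head_dim_v = 576, 512             # assumed QK / V head dimensions
page_block_size, num_blocks = 64, 128

device, dtype = "cuda", torch.bfloat16

# Inputs shaped as described in the docstrings.
q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim, device=device, dtype=dtype)
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([1024, 2048], device=device, dtype=torch.int32)
max_num_blocks_per_seq = num_blocks // batch_size
block_table = torch.arange(
    batch_size * max_num_blocks_per_seq, device=device, dtype=torch.int32
).view(batch_size, max_num_blocks_per_seq)

# Scheduling metadata depends only on the sequence lengths and head counts,
# so it can be computed once per decoding step and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k
)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True,
)
# out:          (batch_size, seq_len_q, num_heads_q, head_dim_v)
# softmax_lse:  (batch_size, num_heads_q, seq_len_q), torch.float32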