mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
chore(setup): properly package the repository as a Python package
This commit is contained in:
@@ -1,6 +1,15 @@
|
||||
__version__ = "1.0.0"
|
||||
"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""
|
||||
|
||||
from flash_mla.flash_mla_interface import (
|
||||
get_mla_metadata,
|
||||
flash_mla_with_kvcache,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_mla_metadata",
|
||||
"flash_mla_with_kvcache",
|
||||
]
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
19
flash_mla/flash_mla_cuda.pyi
Normal file
19
flash_mla/flash_mla_cuda.pyi
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]: ...
|
||||
def fwd_kvcache_mla(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor | None,
|
||||
head_dim_v: int,
|
||||
cache_seqlens: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
causal: bool,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]: ...
|
||||
@@ -2,7 +2,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import flash_mla_cuda
|
||||
from flash_mla import flash_mla_cuda
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
|
||||
0
flash_mla/py.typed
Normal file
0
flash_mla/py.typed
Normal file
Reference in New Issue
Block a user