chore(setup): properly package the repository as a Python package

This commit is contained in:
Xuehai Pan
2025-02-24 18:18:38 +08:00
parent 18e32770cc
commit 26d3077949
8 changed files with 108 additions and 8 deletions

View File

@@ -1,6 +1,15 @@
# NOTE(review): this span reads like a rendered diff hunk whose +/- markers were
# stripped, so the `__version__` assignment shows up twice (presumably the removed
# line and its re-added copy) — confirm against the real file before editing.
__version__ = "1.0.0"
"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""
# Bring the two interface functions into the package namespace.
from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
)
# Declared public API: exactly the two names imported above.
__all__ = [
"get_mla_metadata",
"flash_mla_with_kvcache",
]
# Package version string.
__version__ = "1.0.0"

View File

@@ -0,0 +1,19 @@
import torch
# Type stub (signature only; the body is `...`).
# Accepts a tensor `cache_seqlens` and two integer arguments, and is declared to
# return a pair of tensors.
# NOTE(review): presumably the returned pair is what `fwd_kvcache_mla` takes as
# `tile_scheduler_metadata` and `num_splits` — confirm against the extension.
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]: ...
# Type stub (signature only; the body is `...`).
# Every parameter is annotated as a tensor except `head_dim_v` (int),
# `softmax_scale` (float) and `causal` (bool); `v_cache` may also be None.
# Declared to return a pair of tensors.
# NOTE(review): presumably the pair is (attention output, softmax log-sum-exp) —
# confirm against the extension's implementation.
def fwd_kvcache_mla(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor | None,
head_dim_v: int,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
softmax_scale: float,
causal: bool,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ...

View File

@@ -2,7 +2,7 @@ from typing import Optional, Tuple
import torch
import flash_mla_cuda
from flash_mla import flash_mla_cuda
def get_mla_metadata(

0
flash_mla/py.typed Normal file
View File