mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00

chore(setup): properly package the repository as a Python package

This commit is contained in:
parent 18e32770cc
commit 26d3077949
MANIFEST.in (new file, 7 additions)

@@ -0,0 +1,7 @@
+recursive-include flash_mla *.pyi
+recursive-include flash_mla *.typed
+include LICENSE
+
+# Include source files in sdist
+include .gitmodules
+recursive-include csrc *
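MANIFEST.in only matters for source distributions, so a quick way to confirm that the CUDA sources in csrc/ and the typing marker files get picked up is to build an sdist and list its contents. A minimal sketch, assuming the PyPA `build` frontend is available (it is not declared as a dependency anywhere in this commit):

```bash
# Install the PyPA build frontend (assumption: not part of this repo's requirements)
python3 -m pip install build

# Build a source distribution; MANIFEST.in controls what gets included
python3 -m build --sdist

# csrc/ sources, the .pyi stub, and py.typed should all show up in the archive
tar -tzf dist/*.tar.gz | grep -E 'csrc/|\.pyi$|py\.typed'
```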
README.md

@@ -11,7 +11,9 @@ Currently released:
 ### Install

 ```bash
-python setup.py install
+python3 -m pip install --upgrade pip setuptools
+python3 -m pip install torch pybind11 --index-url https://download.pytorch.org/whl/cu126
+python3 -m pip install --no-build-isolation --editable .
 ```

 ### Benchmark
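After the editable install above, a quick sanity check is to import the package and print the version it reports; the `__version__` attribute is set in `flash_mla/__init__.py` by this same commit:

```bash
# Should print 1.0.0 if the package installed and imports cleanly
python3 -c "import flash_mla; print(flash_mla.__version__)"
```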
flash_mla/__init__.py

@@ -1,6 +1,15 @@
-__version__ = "1.0.0"
+"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""

 from flash_mla.flash_mla_interface import (
     get_mla_metadata,
     flash_mla_with_kvcache,
 )
+
+
+__all__ = [
+    "get_mla_metadata",
+    "flash_mla_with_kvcache",
+]
+
+
+__version__ = "1.0.0"
flash_mla/flash_mla_cuda.pyi (new file, 19 additions)

@@ -0,0 +1,19 @@
+import torch
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+def fwd_kvcache_mla(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor | None,
+    head_dim_v: int,
+    cache_seqlens: torch.Tensor,
+    block_table: torch.Tensor,
+    softmax_scale: float,
+    causal: bool,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
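The stub only types the raw extension; user code goes through the wrappers re-exported from `flash_mla`. A minimal sketch of the metadata call on a Hopper GPU, where the batch size, head counts, and cache lengths are illustrative assumptions (the `flash_mla_with_kvcache` wrapper's signature is not shown in this diff, so only `get_mla_metadata` is exercised):

```python
import torch

from flash_mla import get_mla_metadata

# Illustrative sizes only: 4 requests, 128 query heads sharing 1 latent KV head.
batch_size, num_heads_q, num_heads_k = 4, 128, 1
cache_seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")

# Returns tile-scheduler metadata and per-request split counts for the decode kernel.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_heads_q // num_heads_k,  # num_heads_per_head_k
    num_heads_k,
)
print(tile_scheduler_metadata.shape, num_splits.shape)
```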
flash_mla/flash_mla_interface.py

@@ -2,7 +2,7 @@ from typing import Optional, Tuple

 import torch

-import flash_mla_cuda
+from flash_mla import flash_mla_cuda


 def get_mla_metadata(
flash_mla/py.typed (new file, empty)
pyproject.toml (new file, 64 additions)

@@ -0,0 +1,64 @@
+# Package ######################################################################
+
+[build-system]
+requires = ["setuptools", "pybind11", "torch ~= 2.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flash-mla"
+description = "FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."
+readme = "README.md"
+requires-python = ">= 3.8"
+authors = [
+    { name = "FlashMLA Contributors" },
+    { name = "Jiashi Li", email = "450993438@qq.com" },
+]
+license = { text = "MIT" }
+keywords = [
+    "Multi-head Latent Attention",
+    "MLA",
+    "Flash MLA",
+    "Flash Attention",
+    "CUDA",
+    "kernel",
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Environment :: GPU :: NVIDIA CUDA :: 12",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+]
+dependencies = ["torch ~= 2.0"]
+dynamic = ["version"]
+
+[project.urls]
+Homepage = "https://github.com/deepseek-ai/FlashMLA"
+Repository = "https://github.com/deepseek-ai/FlashMLA"
+"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"
+
+[project.optional-dependencies]
+test = ["triton"]
+benchmark = ["triton"]
+
+[tool.setuptools]
+include-package-data = true

+[tool.setuptools.packages.find]
+include = ["flash_mla", "flash_mla.*"]
+
+[tool.setuptools.package-data]
+flash_mla = ['*.so', '*.pyd']
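With the test and benchmark extras declared under [project.optional-dependencies], they can be pulled in alongside the editable install from the README. A minimal sketch using standard pip extras syntax (the extras only add triton on top of the torch dependency):

```bash
# Editable install plus the optional test/benchmark extras declared in pyproject.toml
python3 -m pip install --no-build-isolation --editable ".[test,benchmark]"
```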
setup.py (7 changes)

@@ -3,7 +3,7 @@ from pathlib import Path
 from datetime import datetime
 import subprocess

-from setuptools import setup, find_packages
+from setuptools import setup

 from torch.utils.cpp_extension import (
     BuildExtension,
@@ -33,7 +33,7 @@ else:
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
-        name="flash_mla_cuda",
+        name="flash_mla.flash_mla_cuda",
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
@@ -77,9 +77,8 @@ except Exception as _:


 setup(
-    name="flash_mla",
+    name="flash-mla",
     version="1.0.0" + rev,
-    packages=find_packages(include=['flash_mla']),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )
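Because the extension target is now flash_mla.flash_mla_cuda, the compiled module is expected to live inside the flash_mla package (which is also why the interface now does `from flash_mla import flash_mla_cuda`). A quick check after installing:

```bash
# The printed path should point at a flash_mla_cuda shared object inside the flash_mla/ package
python3 -c "from flash_mla import flash_mla_cuda; print(flash_mla_cuda.__file__)"
```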