Mirror of https://github.com/deepseek-ai/FlashMLA (synced 2025-06-26 18:15:54 +00:00)
chore(setup): properly package the repository as a Python package
parent 18e32770cc
commit 26d3077949
MANIFEST.in (new file, 7 lines)
@@ -0,0 +1,7 @@
+recursive-include flash_mla *.pyi
+recursive-include flash_mla *.typed
+include LICENSE
+
+# Include source files in sdist
+include .gitmodules
+recursive-include csrc *
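Not part of the diff: a quick way to confirm these MANIFEST.in rules actually land in a source distribution, sketched in Python (the archive path below is illustrative, not produced by this commit):

```python
# Hypothetical check: after building an sdist (e.g. with `python -m build --sdist`),
# list its contents and confirm MANIFEST.in pulled in the CUDA sources, stubs, and LICENSE.
import tarfile

with tarfile.open("dist/flash_mla-1.0.0.tar.gz") as sdist:  # illustrative file name
    names = sdist.getnames()

print(any(n.endswith("csrc/flash_api.cpp") for n in names))           # csrc/* included
print(any(n.endswith("flash_mla/flash_mla_cuda.pyi") for n in names))  # *.pyi included
print(any(n.endswith("LICENSE") for n in names))                       # LICENSE included
```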
@@ -11,7 +11,9 @@ Currently released:
 ### Install
 
 ```bash
-python setup.py install
+python3 -m pip install --upgrade pip setuptools
+python3 -m pip install torch pybind11 --index-url https://download.pytorch.org/whl/cu126
+python3 -m pip install --no-build-isolation --editable .
 ```
 
 ### Benchmark
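After the three pip commands above, a minimal smoke test might look like this (a sketch, not part of the diff; it only assumes the editable install succeeded):

```python
# Quick post-install check (illustrative): the package imports, reports its
# version, and exposes the public decoding API.
import torch
import flash_mla

print(flash_mla.__version__)             # set in flash_mla/__init__.py
print(torch.cuda.is_available())         # the kernels target Hopper GPUs with CUDA 12
print(flash_mla.get_mla_metadata)        # re-exported from flash_mla.flash_mla_interface
print(flash_mla.flash_mla_with_kvcache)
```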
@@ -52,7 +54,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-
 
 ```bibtex
 @misc{flashmla2025,
       title={FlashMLA: Efficient MLA decoding kernel},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},
@@ -1,6 +1,15 @@
-__version__ = "1.0.0"
+"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""
 
 from flash_mla.flash_mla_interface import (
     get_mla_metadata,
     flash_mla_with_kvcache,
 )
+
+
+__all__ = [
+    "get_mla_metadata",
+    "flash_mla_with_kvcache",
+]
+
+
+__version__ = "1.0.0"
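For context, a sketch of how the two re-exported functions are typically driven during decoding. The shapes, dtypes, and paged-KV layout below are assumptions taken from the repository's tests, not something this commit defines:

```python
# Illustrative usage of the public API (values such as d=576, dv=512 and
# 64-token pages mirror the repository's test setup and are assumptions here).
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, h_q, h_kv, d, dv, block_size = 4, 1, 128, 1, 576, 512, 64
cache_seqlens = torch.full((b,), 1024, dtype=torch.int32, device="cuda")
max_blocks = (int(cache_seqlens.max()) + block_size - 1) // block_size
block_table = torch.arange(b * max_blocks, dtype=torch.int32, device="cuda").view(b, max_blocks)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
kvcache = torch.randn(b * max_blocks, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")

# Scheduling metadata is computed once per batch and reused across decoder layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

o, lse = flash_mla_with_kvcache(
    q, kvcache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(o.shape)  # expected (b, s_q, h_q, dv)
```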
flash_mla/flash_mla_cuda.pyi (new file, 19 lines)
@@ -0,0 +1,19 @@
+import torch
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+def fwd_kvcache_mla(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor | None,
+    head_dim_v: int,
+    cache_seqlens: torch.Tensor,
+    block_table: torch.Tensor,
+    softmax_scale: float,
+    causal: bool,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
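The stub, together with the empty flash_mla/py.typed marker added in this commit, lets static type checkers see the compiled extension's signatures. A hypothetical checker-facing snippet (not part of the diff):

```python
# Hypothetical: with the .pyi stub and py.typed marker installed, a checker
# such as mypy can validate calls against the CUDA binding's declared signatures.
import torch
from flash_mla import flash_mla_cuda

cache_seqlens = torch.tensor([1024, 512], dtype=torch.int32, device="cuda")
meta, num_splits = flash_mla_cuda.get_mla_metadata(cache_seqlens, 128, 1)   # type-checks
# flash_mla_cuda.get_mla_metadata(cache_seqlens, "128", 1)  # rejected: str is not int
```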
@@ -2,7 +2,7 @@ from typing import Optional, Tuple
 
 import torch
 
-import flash_mla_cuda
+from flash_mla import flash_mla_cuda
 
 
 def get_mla_metadata(
flash_mla/py.typed (new empty file, 0 lines)
pyproject.toml (new file, 64 lines)
@@ -0,0 +1,64 @@
+# Package ######################################################################
+
+[build-system]
+requires = ["setuptools", "pybind11", "torch ~= 2.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flash-mla"
+description = "FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."
+readme = "README.md"
+requires-python = ">= 3.8"
+authors = [
+    { name = "FlashMLA Contributors" },
+    { name = "Jiashi Li", email = "450993438@qq.com" },
+]
+license = { text = "MIT" }
+keywords = [
+    "Multi-head Latent Attention",
+    "MLA",
+    "Flash MLA",
+    "Flash Attention",
+    "CUDA",
+    "kernel",
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Environment :: GPU :: NVIDIA CUDA :: 12",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+]
+dependencies = ["torch ~= 2.0"]
+dynamic = ["version"]
+
+[project.urls]
+Homepage = "https://github.com/deepseek-ai/FlashMLA"
+Repository = "https://github.com/deepseek-ai/FlashMLA"
+"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"
+
+[project.optional-dependencies]
+test = ["triton"]
+benchmark = ["triton"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+include = ["flash_mla", "flash_mla.*"]
+
+[tool.setuptools.package-data]
+flash_mla = ['*.so', '*.pyd']
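As a quick illustration of what the new [project] table provides (a sketch assuming the package has been installed; the printed values simply follow the fields above):

```python
# Illustrative: read back the distribution metadata recorded from pyproject.toml.
from importlib.metadata import metadata, version

print(version("flash-mla"))             # version is dynamic, supplied by setup.py
info = metadata("flash-mla")
print(info["Summary"])                  # the description string declared above
print(info["Requires-Python"])          # >=3.8
print(info.get_all("Classifier")[:3])   # first few Trove classifiers
```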
setup.py (7 lines changed)
@@ -3,7 +3,7 @@ from pathlib import Path
 from datetime import datetime
 import subprocess
 
-from setuptools import setup, find_packages
+from setuptools import setup
 
 from torch.utils.cpp_extension import (
     BuildExtension,
@@ -33,7 +33,7 @@ else:
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
-        name="flash_mla_cuda",
+        name="flash_mla.flash_mla_cuda",
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
@@ -77,9 +77,8 @@ except Exception as _:
 
 
 setup(
-    name="flash_mla",
+    name="flash-mla",
     version="1.0.0" + rev,
-    packages=find_packages(include=['flash_mla']),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )
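With the extension renamed and package discovery moved to pyproject.toml, the compiled module installs inside the package rather than at the top level. A small sketch (not part of the diff) of what that changes for importers:

```python
# Illustrative: the CUDA binding is now a submodule of flash_mla rather than a
# top-level `flash_mla_cuda` module, matching the new import in
# flash_mla/flash_mla_interface.py.
import importlib

ext = importlib.import_module("flash_mla.flash_mla_cuda")
print(ext.__name__)                      # "flash_mla.flash_mla_cuda"
print(hasattr(ext, "fwd_kvcache_mla"))   # low-level kernel entry point from the stub
```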