chore(setup): properly package the repository as a Python package

Author: Xuehai Pan · 2025-02-24 18:18:38 +08:00
Commit: 26d3077949 (parent: 18e32770cc)
8 changed files with 108 additions and 8 deletions

MANIFEST.in (new file, +7)

@@ -0,0 +1,7 @@
+recursive-include flash_mla *.pyi
+recursive-include flash_mla *.typed
+include LICENSE
+
+# Include source files in sdist
+include .gitmodules
+recursive-include csrc *
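These rules only affect what lands in the source distribution. As a quick, hedged check (not part of the commit), one can build an sdist and confirm the CUDA sources and type stubs were pulled in; the `flash_mla-*.tar.gz` name below assumes the default build-backend naming:

```python
# Sketch: inspect a freshly built sdist (e.g. from `python3 -m build --sdist`)
# and confirm the MANIFEST.in rules captured csrc/, the .pyi stubs, and py.typed.
import glob
import tarfile

sdist = sorted(glob.glob("dist/flash_mla-*.tar.gz"))[-1]  # assumed artifact name
with tarfile.open(sdist) as tar:
    names = tar.getnames()

print(any(name.endswith(".pyi") for name in names))      # type stubs included
print(any(name.endswith("py.typed") for name in names))  # typing marker included
print(any("/csrc/" in name for name in names))           # CUDA/C++ sources included
```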

README.md (modified)

@@ -11,7 +11,9 @@ Currently released:
 ### Install
 
 ```bash
-python setup.py install
+python3 -m pip install --upgrade pip setuptools
+python3 -m pip install torch pybind11 --index-url https://download.pytorch.org/whl/cu126
+python3 -m pip install --no-build-isolation --editable .
 ```
 
 ### Benchmark
@@ -52,7 +54,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-
 ```bibtex
 @misc{flashmla2025,
       title={FlashMLA: Efficient MLA decoding kernel},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},
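A note on the new install flow: `torch` and `pybind11` are installed up front because the final command uses `--no-build-isolation`, so the extension is compiled against the interpreter's existing torch/CUDA stack instead of a throwaway build environment. A small, hedged sanity check of that environment before building (illustrative, not from the commit):

```python
# Sketch: verify the pre-installed build/runtime dependencies that the
# --no-build-isolation build will compile against.
import pybind11
import torch

print(torch.__version__, torch.version.cuda)  # expect a CUDA 12.x build per the cu126 index URL
print(torch.cuda.is_available())              # a Hopper GPU is required to run the kernels
print(pybind11.get_include())                 # pybind11 headers visible to the extension build
```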

flash_mla/__init__.py (modified)

@@ -1,6 +1,15 @@
-__version__ = "1.0.0"
+"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""
 
 from flash_mla.flash_mla_interface import (
     get_mla_metadata,
     flash_mla_with_kvcache,
 )
+
+__all__ = [
+    "get_mla_metadata",
+    "flash_mla_with_kvcache",
+]
+
+__version__ = "1.0.0"
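With these re-exports, downstream code can import the public API from the top-level package instead of reaching into `flash_mla.flash_mla_interface`. A minimal sketch (assumes the extension has been built and installed):

```python
# Sketch: the package now exposes its public API at the top level.
import flash_mla
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

print(flash_mla.__version__)  # "1.0.0"
print(flash_mla.__all__)      # ['get_mla_metadata', 'flash_mla_with_kvcache']
```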

flash_mla/flash_mla_cuda.pyi (new file, +19)

@@ -0,0 +1,19 @@
+import torch
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+def fwd_kvcache_mla(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor | None,
+    head_dim_v: int,
+    cache_seqlens: torch.Tensor,
+    block_table: torch.Tensor,
+    softmax_scale: float,
+    causal: bool,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
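For orientation, a hedged sketch of the decode-time call pattern that reaches these raw extension entry points through the Python wrappers. The shapes, dtypes, and the `flash_mla_with_kvcache` argument order below are assumptions about the wrapper in `flash_mla_interface` (head_dim 576 = 512 value dims + 64 rope dims, page block size 64), not something this commit pins down:

```python
# Illustrative only: placeholder shapes for MLA decoding with a paged KV cache.
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, h_q, h_kv, d, dv, block_size, n_blocks = 4, 1, 128, 1, 576, 512, 64, 32
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
kvcache = torch.randn(b * n_blocks, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(b * n_blocks, dtype=torch.int32, device="cuda").view(b, n_blocks)
cache_seqlens = torch.full((b,), 1024, dtype=torch.int32, device="cuda")

# Plan the split of work across SMs once for this batch of cache lengths ...
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
# ... then run the fused decode kernel (which wraps fwd_kvcache_mla above).
out, lse = flash_mla_with_kvcache(
    q, kvcache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(out.shape, lse.shape)
```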

flash_mla/flash_mla_interface.py (modified)

@@ -2,7 +2,7 @@ from typing import Optional, Tuple
 
 import torch
 
-import flash_mla_cuda
+from flash_mla import flash_mla_cuda
 
 
 def get_mla_metadata(

flash_mla/py.typed (new, empty file)

pyproject.toml (new file, +64)

@@ -0,0 +1,64 @@
+# Package ######################################################################
+
+[build-system]
+requires = ["setuptools", "pybind11", "torch ~= 2.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flash-mla"
+description = "FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."
+readme = "README.md"
+requires-python = ">= 3.8"
+authors = [
+    { name = "FlashMLA Contributors" },
+    { name = "Jiashi Li", email = "450993438@qq.com" },
+]
+license = { text = "MIT" }
+keywords = [
+    "Multi-head Latent Attention",
+    "MLA",
+    "Flash MLA",
+    "Flash Attention",
+    "CUDA",
+    "kernel",
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Environment :: GPU :: NVIDIA CUDA :: 12",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+]
+dependencies = ["torch ~= 2.0"]
+dynamic = ["version"]
+
+[project.urls]
+Homepage = "https://github.com/deepseek-ai/FlashMLA"
+Repository = "https://github.com/deepseek-ai/FlashMLA"
+"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"
+
+[project.optional-dependencies]
+test = ["triton"]
+benchmark = ["triton"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+include = ["flash_mla", "flash_mla.*"]
+
+[tool.setuptools.package-data]
+flash_mla = ['*.so', '*.pyd']
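The distribution name is `flash-mla` while the import package stays `flash_mla`, and `dynamic = ["version"]` defers the version string to setup.py. A small, hedged check of the installed metadata (illustrative only):

```python
# Sketch: inspect the installed distribution's metadata; the exact version
# string depends on the git-revision suffix computed in setup.py.
from importlib.metadata import metadata, version

print(version("flash-mla"))                      # "1.0.0" plus setup.py's rev suffix
print(metadata("flash-mla")["Requires-Python"])  # ">= 3.8"
```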

setup.py (modified)

@@ -3,7 +3,7 @@ from pathlib import Path
 from datetime import datetime
 import subprocess
 
-from setuptools import setup, find_packages
+from setuptools import setup
 from torch.utils.cpp_extension import (
     BuildExtension,
@@ -33,7 +33,7 @@ else:
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
-        name="flash_mla_cuda",
+        name="flash_mla.flash_mla_cuda",
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
@@ -77,9 +77,8 @@ except Exception as _:
 
 setup(
-    name="flash_mla",
+    name="flash-mla",
     version="1.0.0" + rev,
-    packages=find_packages(include=['flash_mla']),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )
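Renaming the extension to `flash_mla.flash_mla_cuda` (and relying on the static `[tool.setuptools.packages.find]` config rather than `find_packages`) places the compiled module inside the package, which is what the new `from flash_mla import flash_mla_cuda` import expects. A quick, hedged check after a build:

```python
# Sketch: after building/installing, the compiled extension should live inside
# the flash_mla package directory rather than as a top-level module.
from flash_mla import flash_mla_cuda

print(flash_mla_cuda.__file__)  # e.g. .../flash_mla/flash_mla_cuda.cpython-*.so
```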