mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00

Commit 7409203f44: enable fp8 compile
parent fed0499301
@@ -135,6 +135,11 @@ struct Flash_fwd_kernel_traits_mla {
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
             Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
+
+
+
+    // for fp8 trans-v
+
 };

 namespace flash {
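The added traits behind the `// for fp8 trans-v` comment set up an explicitly transposed copy of V in shared memory. The likely motivation (my assumption, not stated in the commit) is that the FP8 path cannot reuse the 16-bit trick of reading V transposed straight out of the K tile: Hopper's WGMMA can consume 16-bit operands from shared memory in a transposed mode, but FP8 operands must be K-major, so V^T has to be materialized. A toy standalone CUDA kernel showing the basic staging idea; the tile size, flat row-major layout, and `transpose_v_tile` name are hypothetical, and the real kernel uses swizzled CuTe layouts rather than a padded array:

    #include <cuda_fp8.h>

    // Toy tile transpose for fp8 V: stage one 64x64 tile through shared
    // memory, then write it back transposed. blockDim is assumed (64, 16).
    __global__ void transpose_v_tile(const __nv_fp8_e4m3* __restrict__ V,
                                     __nv_fp8_e4m3* __restrict__ Vt) {
        constexpr int kTile = 64;
        __shared__ __nv_fp8_e4m3 tile[kTile][kTile + 1];  // +1 pad to reduce bank conflicts

        int x = threadIdx.x;                         // column within the tile
        for (int y = threadIdx.y; y < kTile; y += blockDim.y)
            tile[y][x] = V[y * kTile + x];           // coalesced load of V (row-major)
        __syncthreads();
        for (int y = threadIdx.y; y < kTile; y += blockDim.y)
            Vt[y * kTile + x] = tile[x][y];          // coalesced store of V^T
    }

The padding column in `tile` staggers the column-wise reads in the second loop across shared-memory banks.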
@@ -170,7 +175,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
-    using Element = typename Kernel_traits::Element;
+    using Element = typename Kernel_traits::ElementO;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;

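The `Element` to `ElementO` switch in `store()` decouples the type written back to global memory from the input element type. This matters for FP8: the Q/K/V inputs are 8-bit, but the attention output is presumably still produced in a 16-bit type. A minimal sketch of that traits split, assuming bf16 output for the fp8 path; `Traits_sketch` and its default parameter are mine, only the `Element`/`ElementO`/`ElementAccum` names come from the diff:

    #include <type_traits>
    #include <cuda_bf16.h>
    #include <cuda_fp8.h>

    // Sketch: separate input element type from output element type.
    template <class Input, class Output = Input>
    struct Traits_sketch {
        using Element      = Input;   // Q/K/V element type in global memory
        using ElementO     = Output;  // type the store() epilogue writes
        using ElementAccum = float;   // accumulator stays fp32 either way
    };

    using Bf16Traits = Traits_sketch<__nv_bfloat16>;                // Element == ElementO
    using Fp8Traits  = Traits_sketch<__nv_fp8_e4m3, __nv_bfloat16>; // fp8 in, bf16 out

    static_assert(std::is_same_v<Bf16Traits::Element, Bf16Traits::ElementO>, "bf16 path unchanged");
    static_assert(!std::is_same_v<Fp8Traits::Element, Fp8Traits::ElementO>, "fp8 path diverges");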
@@ -272,7 +277,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+    auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
+            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
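`cute::conditional_return<Is_FP8>(a, b)` selects one of two expressions at compile time even when their types differ, which a runtime ternary cannot do; here it picks between two tensors with different smem layouts and different backing buffers. A standalone analog built on `if constexpr`; `my_conditional_return` and the layout structs are hypothetical stand-ins, not CuTe's implementation:

    #include <cstdio>

    // Compile-time two-way select: only the chosen branch participates in
    // return-type deduction, so the two arguments may have unrelated types.
    template <bool Cond, class T, class F>
    constexpr auto my_conditional_return(T t, F f) {
        if constexpr (Cond) { (void)f; return t; }
        else                { (void)t; return f; }
    }

    struct VtMMaLayout        { static constexpr const char* name = "SmemLayoutVtMMa"; };
    struct VtTransposedLayout { static constexpr const char* name = "SmemLayoutVtransposed"; };

    template <bool IsFp8>
    void pick_layout() {
        auto layout = my_conditional_return<IsFp8>(VtMMaLayout{}, VtTransposedLayout{});
        std::printf("%s\n", decltype(layout)::name);
    }

    int main() {
        pick_layout<true>();   // prints SmemLayoutVtMMa
        pick_layout<false>();  // prints SmemLayoutVtransposed
    }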
setup.py (5 lines changed)
@@ -37,7 +37,7 @@ ext_modules.append(
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
-            #"csrc/flash_fwd_mla_fp8_sm90.cu",
+            "csrc/flash_fwd_mla_fp8_sm90.cu",
         ],
         extra_compile_args={
             "cxx": cxx_args,
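Re-enabling `csrc/flash_fwd_mla_fp8_sm90.cu` in `sources` is what actually compiles the FP8 kernel. Per-dtype `.cu` files like these usually exist so each element type is explicitly instantiated in its own translation unit, letting dtypes build in parallel and letting one be toggled by adding or removing a single source file. A hedged sketch of that pattern, with everything in one file for readability; `run_fwd_sketch` and `Params` are hypothetical stand-ins, not FlashMLA's actual launcher API:

    #include <cuda_runtime.h>
    #include <cuda_bf16.h>
    #include <cuda_fp8.h>

    struct Params { /* kernel arguments, elided */ };

    // In the shared header: one templated launcher.
    template <typename Element>
    void run_fwd_sketch(Params& params, cudaStream_t stream) {
        (void)params; (void)stream;  // pick tiles/layouts for Element and launch (elided)
    }

    // In flash_fwd_mla_bf16_sm90.cu: instantiate for bf16 only.
    template void run_fwd_sketch<__nv_bfloat16>(Params&, cudaStream_t);

    // In flash_fwd_mla_fp8_sm90.cu (the file this diff re-enables):
    template void run_fwd_sketch<__nv_fp8_e4m3>(Params&, cudaStream_t);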
@@ -55,7 +55,8 @@ ext_modules.append(
                 "--expt-relaxed-constexpr",
                 "--expt-extended-lambda",
                 "--use_fast_math",
-                "--ptxas-options=-v,--register-usage-level=10"
+                "--ptxas-options=-v,--register-usage-level=10",
+                "--ftemplate-backtrace-limit=0"
             ]
             + cc_flag
         ),