Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00

enable fp8 compile

commit 7409203f44 (parent fed0499301)
@@ -135,6 +135,11 @@ struct Flash_fwd_kernel_traits_mla {
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
             Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
+
+
+    // for fp8 trans-v
+
+
 };

 namespace flash {
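The bodies of the lines added under `// for fp8 trans-v` were not captured in this view, but the `@@ -272` hunk below references `Kernel_traits::Is_FP8`, `Kernel_traits::SmemLayoutVtMMa`, and `shared_storage.smem_vt`, so the traits struct presumably gains an fp8 flag plus an SMEM layout for the transposed V tile. A minimal sketch of what such a trait could look like using CuTe's SM90 GMMA layout atoms; the atom choice, swizzle width, and the stand-in constants below are assumptions, not the commit's actual code:

#include <cute/tensor.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#include <cutlass/numeric_types.h>

using namespace cute;

// Assumed stand-ins for the kernel-traits constants; the real values are
// defined elsewhere in Flash_fwd_kernel_traits_mla.
static constexpr int kBlockN   = 64;
static constexpr int kHeadDimV = 512;

static constexpr bool Is_FP8 = true;  // assumed flag queried by the kernel

// A K-major (seqlen-contiguous) SMEM layout for the V^T tile so the fp8
// WGMMA can consume it directly as the B operand.
using SmemLayoutVtMMa = decltype(tile_to_shape(
    GMMA::Layout_K_SW64_Atom<cutlass::float_e4m3_t>{},
    Shape<Int<kHeadDimV>, Int<kBlockN>>{}));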
@@ -170,7 +175,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
-    using Element = typename Kernel_traits::Element;
+    using Element = typename Kernel_traits::ElementO;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;

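Switching `store()` from `Element` to `ElementO` decouples the output dtype from the compute dtype, which fp8 requires: fp8 products accumulate in fp32 and are typically written back in a wider 16-bit type. A hedged sketch of the distinction, with assumed type choices (the real traits are not shown in this diff):

#include <cutlass/numeric_types.h>

// Hypothetical fp8 kernel traits: compute in e4m3, accumulate in fp32,
// store the attention output as bf16 via a separate ElementO alias.
struct Flash_fwd_kernel_traits_mla_fp8_example {
    using Element      = cutlass::float_e4m3_t;  // Q/K/V compute type
    using ElementO     = cutlass::bfloat16_t;    // output store type
    using ElementAccum = float;                  // softmax/accumulator type
};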
@@ -272,7 +277,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+    auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
+        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));

     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
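`cute::conditional_return<B>(t, f)` is a compile-time select whose two arguments may have different types, which an ordinary ternary cannot express; here it picks between two `make_tensor` results with different layouts, and therefore different tensor types, so `sVt` must be declared `auto`. A standalone sketch of the semantics (not CuTe's actual source):

#include <utility>

// Compile-time select: only the chosen branch is instantiated, so the two
// alternatives may have unrelated types (unlike the runtime ?: operator).
template <bool If, class T, class F>
constexpr decltype(auto) conditional_return(T&& t, F&& f) {
    if constexpr (If) {
        return std::forward<T>(t);
    } else {
        return std::forward<F>(f);
    }
}

// Usage: the result type follows the selected branch.
static_assert(conditional_return<true>(1, 2.0) == 1);     // int branch
static_assert(conditional_return<false>(1, 2.0) == 2.0);  // double branch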
setup.py (5 changed lines)
@@ -37,7 +37,7 @@ ext_modules.append(
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
-            #"csrc/flash_fwd_mla_fp8_sm90.cu",
+            "csrc/flash_fwd_mla_fp8_sm90.cu",
         ],
         extra_compile_args={
             "cxx": cxx_args,
@@ -55,7 +55,8 @@ ext_modules.append(
                 "--expt-relaxed-constexpr",
                 "--expt-extended-lambda",
                 "--use_fast_math",
-                "--ptxas-options=-v,--register-usage-level=10"
+                "--ptxas-options=-v,--register-usage-level=10",
+                "--ftemplate-backtrace-limit=0"
             ]
             + cc_flag
         ),
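Uncommenting `csrc/flash_fwd_mla_fp8_sm90.cu` pulls the fp8 kernel instantiation into the build, and `--ftemplate-backtrace-limit=0` lifts nvcc's cap on template-instantiation backtraces, which helps when diagnosing errors in CuTe-heavy code. If the fp8 translation unit mirrors the bf16 one, it would contain little more than a single explicit instantiation; a sketch under that assumption (the function name and the 576 head dimension are copied from the bf16 pattern, not verified against this commit):

// flash_fwd_mla_fp8_sm90.cu (hypothetical contents, assuming it mirrors
// flash_fwd_mla_bf16_sm90.cu with the element type swapped to fp8 e4m3)
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(
    Flash_fwd_mla_params &params, cudaStream_t stream);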