enable fp8 compile

chenhongmin.will 2025-02-25 17:48:07 +08:00
parent fed0499301
commit 7409203f44
2 changed files with 12 additions and 4 deletions

csrc/flash_fwd_mla_kernel.h

@@ -135,6 +135,11 @@ struct Flash_fwd_kernel_traits_mla {
         Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
         GmemLayoutAtomOaccum{},
         Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
+    // for fp8 trans-v
 };

 namespace flash {
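The "// for fp8 trans-v" marker is where the kernel traits gain shared-memory layouts for a transposed view of V (the added definitions themselves are collapsed out of this view). Hopper FP8 WGMMA consumes both operands K-major, so the P*V GEMM must read V through a transposed layout. As a hedged sketch of what such a trait can look like, in the style of FlashAttention's CuTe traits (SmemLayoutV, kBlockN, kHeadDimV stand in for this repo's definitions):

    // Illustrative only, not this commit's code: re-view the (kBlockN, kHeadDimV)
    // smem tile of V as (kHeadDimV, kBlockN) by composing with a row-major shape.
    using SmemLayoutVtransposed = decltype(cute::composition(
        SmemLayoutV{},
        cute::make_layout(cute::Shape<cute::Int<kHeadDimV>, cute::Int<kBlockN>>{},
                          cute::GenRowMajor{})));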
@@ -170,7 +175,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
-    using Element = typename Kernel_traits::Element;
+    using Element = typename Kernel_traits::ElementO;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
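In the store epilogue, the O tensor's element type now comes from ElementO instead of Element, decoupling the output dtype from the MMA compute dtype: an FP8 kernel multiplies e4m3 operands but can still write its output at higher precision. A minimal sketch of the split, assuming hypothetical trait names around the real Element/ElementO pair:

    #include <type_traits>
    #include <cutlass/numeric_types.h>

    // Sketch, not the repo's actual traits: Element types the MMA operands,
    // ElementO types the O tensor that store() writes to global memory.
    template <typename T, typename TOut = T>
    struct Kernel_traits_sketch {
        using Element  = T;
        using ElementO = TOut;
        static constexpr bool Is_FP8 = std::is_same_v<T, cutlass::float_e4m3_t>;
    };

    // e.g. an FP8 compute path that still emits bf16 output:
    using Fp8Traits = Kernel_traits_sketch<cutlass::float_e4m3_t, cutlass::bfloat16_t>;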
@@ -272,7 +277,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+    auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
+            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
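cute::conditional_return<Kernel_traits::Is_FP8>(a, b) is required here because the two make_tensor results have different types (different smem buffers and layouts), so a plain ternary would not compile. The helper dispatches with if constexpr, and only the taken branch is instantiated; its shape is roughly:

    #include <utility>

    // Sketch of the idiom (CUTLASS's cute ships an equivalent helper):
    // unlike ?:, the two if-constexpr branches may return different types.
    template <bool Cond, class T, class F>
    constexpr decltype(auto) conditional_return(T&& t, F&& f) {
        if constexpr (Cond) {
            return std::forward<T>(t);  // FP8: V read from smem_k via SmemLayoutVtMMa
        } else {
            return std::forward<F>(f);  // otherwise: smem_vt via SmemLayoutVtransposed
        }
    }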

setup.py

@@ -37,7 +37,7 @@ ext_modules.append(
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
-            #"csrc/flash_fwd_mla_fp8_sm90.cu",
+            "csrc/flash_fwd_mla_fp8_sm90.cu",
         ],
         extra_compile_args={
             "cxx": cxx_args,
@@ -55,7 +55,8 @@ ext_modules.append(
                 "--expt-relaxed-constexpr",
                 "--expt-extended-lambda",
                 "--use_fast_math",
-                "--ptxas-options=-v,--register-usage-level=10"
+                "--ptxas-options=-v,--register-usage-level=10",
+                "--ftemplate-backtrace-limit=0"
             ]
             + cc_flag
         ),
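Two nvcc changes land in the flag list: a trailing comma after the existing --ptxas-options entry so the list can grow, and --ftemplate-backtrace-limit=0, which removes the cap on the template-instantiation notes nvcc attaches to an error. With CUTLASS/CuTe's deeply nested templates, the full backtrace is what makes FP8 build failures diagnosable.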