From 7409203f44dd54b8f51734e05c1fe7789c2d86a9 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Tue, 25 Feb 2025 17:48:07 +0800
Subject: [PATCH] enable fp8 compile

---
 csrc/flash_fwd_mla_kernel.h | 11 +++++++++--
 setup.py                    |  5 +++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 1af3eb7..fb53f79 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -135,6 +135,11 @@ struct Flash_fwd_kernel_traits_mla {
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
             Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+
+
+
+    // for fp8 trans-v
+
 };
 
 namespace flash {
@@ -170,7 +175,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
-    using Element = typename Kernel_traits::Element;
+    using Element = typename Kernel_traits::ElementO;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
 
@@ -272,7 +277,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+    auto sVt = cute::conditional_return(
+        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
 
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
diff --git a/setup.py b/setup.py
index bfe931f..8b11e00 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ ext_modules.append(
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
-            #"csrc/flash_fwd_mla_fp8_sm90.cu",
+            "csrc/flash_fwd_mla_fp8_sm90.cu",
         ],
         extra_compile_args={
             "cxx": cxx_args,
@@ -55,7 +55,8 @@ ext_modules.append(
                 "--expt-relaxed-constexpr",
                 "--expt-extended-lambda",
                 "--use_fast_math",
-                "--ptxas-options=-v,--register-usage-level=10"
+                "--ptxas-options=-v,--register-usage-level=10",
+                "--ftemplate-backtrace-limit=0"
             ] + cc_flag
         ),
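
Reviewer note (not part of the patch): two pieces of context for this change.

The kernel-side hunks switch the store epilogue to Kernel_traits::ElementO,
presumably so the output element type can differ from the fp8 compute type,
reserve a commented spot ("for fp8 trans-v") in the kernel traits, and replace
the single sVt view with a compile-time selection between two shared-memory
views. Note that the extracted hunk above lost the angle-bracketed template
argument of cute::conditional_return: the CuTe utility takes a constexpr bool
as a template parameter and returns its first argument when the flag is true
and its second otherwise, so the two branches may name different buffers and
layout types. Below is a minimal sketch of that technique, assuming a
hypothetical Is_FP8 flag and stand-in layouts; none of these identifiers are
taken from the patch.

    #include <cute/tensor.hpp>

    // Hypothetical illustration: Is_FP8, pick_vt, and both layouts are
    // stand-ins, not identifiers from the patch.
    template <bool Is_FP8, class TK, class TV>
    __device__ auto pick_vt(TK *smem_k, TV *smem_vt) {
        using namespace cute;
        // Two candidate smem views with different layout types. Only the
        // selected branch's type survives; the other is discarded at compile
        // time by the if-constexpr inside cute::conditional_return.
        auto mma_view = make_tensor(make_smem_ptr(smem_k),
                                    Layout<Shape<_64, _64>>{});  // stand-in for SmemLayoutVtMMa
        auto trans_view = make_tensor(make_smem_ptr(smem_vt),
                                      Layout<Shape<_64, _64>, Stride<_64, _1>>{});  // stand-in for SmemLayoutVtransposed
        return cute::conditional_return<Is_FP8>(mma_view, trans_view);
    }

On the build side, the previously commented-out fp8 translation unit
csrc/flash_fwd_mla_fp8_sm90.cu is compiled again, and nvcc gains
--ftemplate-backtrace-limit=0, which removes the cap on template-instantiation
notes so full CUTLASS/CuTe backtraces are printed if the new fp8
instantiations fail to compile.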