enable scale

chenhongmin.will 2025-02-28 19:43:24 +08:00
parent 4e055a6142
commit 8b939854d8
3 changed files with 18 additions and 17 deletions

View File

@@ -151,9 +151,6 @@ struct Flash_fwd_kernel_traits_mla {
// ------ for f8 ------
using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
-// using SmemLayoutVtMMa = decltype(tile_to_shape(
-//     getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-//     Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
};
@@ -186,7 +183,7 @@ struct SharedStorageMLA {
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
-SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
+SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
@@ -203,7 +200,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
const int split_offset = __ldg(params.num_splits_ptr + bidb);
-Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
+Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, scale_softmax, descale_k);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
@@ -275,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
-SharedStorage &shared_storage, const float descale_q, const float descale_k) {
+SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -372,10 +369,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// We have key_padding_mask so we'll need to Check_inf
Tensor scale_o = is_first_masking_step
-? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
+? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
: is_masking_step ?
-softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
-: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
+softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
+: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, scale_softmax_log2);
if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
@@ -535,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
}
if (NoSplit)
-store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
else
-store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
@@ -560,10 +557,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
-float descale_q, descale_k;
+float descale_k = 1.f;
+float scale_softmax = params.scale_softmax;
+float scale_softmax_log2 = params.scale_softmax_log2;
if constexpr (Kernel_traits::Is_FP8) {
-descale_q = __ldg(params.descale_q_ptr);
+float descale_q = __ldg(params.descale_q_ptr);
descale_k = __ldg(params.descale_k_ptr);
+scale_softmax = scale_softmax * descale_q * descale_k;
+scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k;
}
#pragma unroll 1
@@ -576,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
if (batch_id > begin_idx) {
__syncthreads(); // Barrier between two tiles.
}
-flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
+flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
}
}
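Net effect in the kernel entry point: the FP8 descale factors for Q and K are folded into scale_softmax and scale_softmax_log2 once per batch element, so the inner attention loop only ever sees already-dequantized logits and never needs descale_q again. descale_k is still passed down because store() forwards it as the value descale (in MLA, V is read from the same compressed FP8 KV cache as K). A minimal standalone sketch of that folding algebra follows; FoldedScales and fold_fp8_scales are illustrative names that do not exist in the kernel.

// Sketch only: shows why multiplying the softmax scales by descale_q * descale_k
// is equivalent to dequantizing the Q*K^T logits, assuming per-tensor FP8 scales.
struct FoldedScales {
    float scale_softmax;       // softmax_scale * descale_q * descale_k
    float scale_softmax_log2;  // log2-domain variant used by the exp2-based softmax
    float descale_k;           // kept for dequantizing V in the epilogue
};

inline FoldedScales fold_fp8_scales(float scale_softmax, float scale_softmax_log2,
                                    float descale_q, float descale_k) {
    // True logits: s = descale_q * descale_k * (q_fp8 . k_fp8), so
    // softmax(scale * s) == softmax((scale * descale_q * descale_k) * s_fp8).
    const float deq = descale_q * descale_k;
    return {scale_softmax * deq, scale_softmax_log2 * deq, descale_k};
}

Folding once up front is also cheaper than multiplying every logit by descale_q * descale_k inside the softmax loop.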

View File

@@ -174,7 +174,7 @@ struct Softmax {
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
@@ -184,7 +184,7 @@ struct Softmax {
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
-float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
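Inside normalize_softmax_lse, the V dequantization piggybacks on the existing row normalization: the accumulator scale becomes descale_v / sum instead of 1 / sum, while the LSE formula is unchanged because softmax_scale already carries descale_q * descale_k in the FP8 path. Below is a scalar sketch of that per-row epilogue, assuming the second GEMM accumulated P · V_fp8; finalize_row is an illustrative name, not the CUTLASS tensor code.

#include <cmath>

// Per-row epilogue sketch: acc holds sum_j p_j * v_fp8_j, so multiplying by
// descale_v / row_sum normalizes the softmax and dequantizes V in one step.
inline void finalize_row(float row_max, float row_sum,
                         float softmax_scale,  // includes descale_q * descale_k for FP8
                         float descale_v, bool split,
                         float &acc_scale, float &lse) {
    const bool bad = (row_sum == 0.f) || (row_sum != row_sum);  // empty row or NaN
    acc_scale = bad ? 1.f : descale_v / row_sum;
    lse = bad ? (split ? -INFINITY : INFINITY)
              : row_max * softmax_scale + std::log(row_sum);
}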

View File

@@ -35,7 +35,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
if use_fp8:
-assert cos_diff < 1e-3
+assert cos_diff < 1e-2
else:
assert cos_diff < 1e-5
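The FP8 accuracy gate on cos_diff is relaxed from 1e-3 to 1e-2, presumably to absorb the extra quantization error once real per-tensor descales are applied end to end. For orientation only, a cosine-difference metric of this kind is commonly one minus the cosine similarity between the kernel output and a higher-precision reference; the sketch below is a generic illustration, not the cal_diff implementation used by the test.

#include <cmath>
#include <cstddef>

// Generic cosine-difference metric (1 - cosine similarity); illustrative only.
inline double cosine_diff(const float *x, const float *y, std::size_t n) {
    double dot = 0.0, nx = 0.0, ny = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        dot += static_cast<double>(x[i]) * y[i];
        nx  += static_cast<double>(x[i]) * x[i];
        ny  += static_cast<double>(y[i]) * y[i];
    }
    return 1.0 - dot / (std::sqrt(nx * ny) + 1e-12);
}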