mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00

commit 8b939854d8 (parent 4e055a6142)

    enable scale
@@ -151,9 +151,6 @@ struct Flash_fwd_kernel_traits_mla {
     // ------ for f8 ------
     using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
-    // using SmemLayoutVtMMa = decltype(tile_to_shape(
-    //     getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-    //     Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
     using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };

@@ -186,7 +183,7 @@ struct SharedStorageMLA {

 template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
 __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
-                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
+                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
@@ -203,7 +200,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const

     const int split_offset = __ldg(params.num_splits_ptr + bidb);

-    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, scale_softmax, descale_k);

     using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
@@ -275,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                    const int bidb, const int bidh, const int m_block,
                                                    const int n_split_idx, const int seqlen_k,
                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                   SharedStorage &shared_storage, const float descale_q, const float descale_k) {
+                                                   SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -372,10 +369,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f

     // We have key_padding_mask so we'll need to Check_inf
     Tensor scale_o = is_first_masking_step
-        ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
+        ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
         : is_masking_step ?
-          softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
-          : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
+          softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
+          : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, scale_softmax_log2);

     if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
     Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
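A note on the two scales touched here: kernels in this family evaluate exp(scale * x) as exp2((scale * log2 e) * x), since exp2 lowers to a cheaper GPU instruction, and the name scale_softmax_log2 strongly suggests it equals scale_softmax * log2(e). Any descale folded into one must therefore be folded into the other, which is exactly what the kernel-entry hunk further down does. A minimal PyTorch sketch of the assumed invariant (the head_dim value is hypothetical):

import math
import torch

head_dim = 576                      # hypothetical; only the algebra matters
scale_softmax = 1.0 / math.sqrt(head_dim)
scale_softmax_log2 = scale_softmax * math.log2(math.e)

x = torch.randn(1024)
# exp2(scale_softmax_log2 * x) must reproduce exp(scale_softmax * x);
# multiplying a descale factor into both scales preserves this identity.
assert torch.allclose(torch.exp(scale_softmax * x),
                      torch.exp2(scale_softmax_log2 * x))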
@@ -535,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     }

     if (NoSplit)
-        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
     else
-        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
 }

 template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
@@ -560,10 +557,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);

-    float descale_q, descale_k;
+    float descale_k = 1.f;
+    float scale_softmax = params.scale_softmax;
+    float scale_softmax_log2 = params.scale_softmax_log2;
     if constexpr (Kernel_traits::Is_FP8) {
-        descale_q = __ldg(params.descale_q_ptr);
+        float descale_q = __ldg(params.descale_q_ptr);
         descale_k = __ldg(params.descale_k_ptr);
+        scale_softmax = scale_softmax * descale_q * descale_k;
+        scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k;
     }

 #pragma unroll 1
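This hunk is the heart of "enable scale": instead of threading descale_q into the inner loop, the kernel now folds descale_q * descale_k into the softmax scales once per batch. The algebra that makes this legal: the softmax is applied to scale * (Q_fp8 dq)(K_fp8 dk)^T, and both per-tensor descales commute out of the matmul into the scalar. A small PyTorch sketch of the equivalence (shapes and descale values are illustrative, not the kernel's):

import torch

torch.manual_seed(0)
q8 = torch.randn(4, 64)    # stand-ins for FP8-quantized Q and K, kept in
k8 = torch.randn(8, 64)    # fp32 here since only the scaling algebra is checked
descale_q, descale_k = 0.02, 0.03
scale_softmax = 64 ** -0.5

# Path 1: dequantize Q and K first, then apply the usual softmax scale.
s1 = torch.softmax((q8 * descale_q) @ (k8 * descale_k).T * scale_softmax, dim=-1)

# Path 2 (this commit): keep Q/K quantized, fold both descales into the scale.
s2 = torch.softmax(q8 @ k8.T * (scale_softmax * descale_q * descale_k), dim=-1)

assert torch.allclose(s1, s2, atol=1e-6)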
@@ -576,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
     }
 }

@@ -174,7 +174,7 @@ struct Softmax {
     };

     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) {
         SumOp<float> sum_op;
         quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
@@ -184,7 +184,7 @@ struct Softmax {
         #pragma unroll
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
             float sum = row_sum(mi);
-            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
             lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
             float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
             #pragma unroll
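The two hunks above are from the Softmax struct in a separate header. They handle the third factor: store() passes descale_k into the descale_v slot (MLA serves K and V from the same compressed cache, so one descale covers both), and the epilogue folds it into inv_sum. Since the epilogue already multiplies the P*V accumulator by 1/row_sum, one multiply now performs both the softmax normalization and the V dequantization, while lse is deliberately computed from the un-descaled row sum. A sketch of that algebra (PyTorch, illustrative shapes):

import torch

torch.manual_seed(0)
p = torch.rand(4, 8)          # unnormalized attention weights, exp(...) >= 0
v8 = torch.randn(8, 64)       # stand-in for FP8-quantized V (fp32 here)
descale_v = 0.05

row_sum = p.sum(-1, keepdim=True)
acc_o = p @ v8                # what the MMA accumulates, still "quantized"

# Epilogue: descale_v / row_sum normalizes and dequantizes in one multiply,
# mirroring the inv_sum change above.
o = acc_o * (descale_v / row_sum)

o_ref = (p / row_sum) @ (v8 * descale_v)
assert torch.allclose(o, o_ref, atol=1e-6)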
@@ -35,7 +35,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")

     if use_fp8:
-        assert cos_diff < 1e-3
+        assert cos_diff < 1e-2
     else:
         assert cos_diff < 1e-5

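For context on the loosened FP8 tolerance in this test hunk: cal_diff compares the kernel output against a reference using a cosine-style distance alongside RMSE and max-abs-diff (per the commented print above). The exact formula is not shown in this diff; a plausible, hedged reconstruction of such a metric:

import torch

def cos_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # One common form: 1 - 2<x, y> / (||x||^2 + ||y||^2). It is zero iff
    # x == y and scale-insensitive. Reconstruction, not verbatim repo code.
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum().item()
    return 1.0 - 2.0 * (x * y).sum().item() / max(denom, 1e-12)

ref = torch.randn(128, 64)
out = ref + 1e-3 * torch.randn_like(ref)   # simulated FP8 kernel error
assert cos_diff(out, ref) < 1e-2           # the budget after this commit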