Mirror of https://github.com/deepseek-ai/FlashMLA (synced 2025-06-26 18:15:54 +00:00)
commit 8b939854d8
parent 4e055a6142

    enable scale
@@ -151,9 +151,6 @@ struct Flash_fwd_kernel_traits_mla {
     // ------ for f8 ------
     using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
-    // using SmemLayoutVtMMa = decltype(tile_to_shape(
-    //         getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-    //         Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
     using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };
 
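Context for this hunk: the commented-out `SmemLayoutVtMMa` alternative is deleted in favor of the layout produced by `SmemTransposeFp8_64x64`. In the FP8 path V shares its shared-memory tiles with K, so it presumably has to be re-laid-out (transposed in 64x64 tiles) before it can serve as the second MMA operand; that appears to be what `SmemFp8Tranpose::SmemLayoutVt` describes. A conceptual host-side C++ sketch of such a tile transpose (illustrative only, not the CuTe implementation; names and the uint8-as-fp8 storage are assumptions):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

constexpr int kTile = 64;  // matches the 64x64 granularity in the trait's name

// fp8 values held as raw bytes, row-major within the tile.
using Fp8Tile = std::array<std::uint8_t, kTile * kTile>;

// out(r, c) = in(c, r): V stored K-major becomes V^T for the second MMA operand.
Fp8Tile transpose_tile(const Fp8Tile &in) {
    Fp8Tile out{};
    for (int r = 0; r < kTile; ++r)
        for (int c = 0; c < kTile; ++c)
            out[static_cast<std::size_t>(r) * kTile + c] =
                in[static_cast<std::size_t>(c) * kTile + r];
    return out;
}
```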
@@ -186,7 +183,7 @@ struct SharedStorageMLA {
 
 template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
 __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
-                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
+                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
@@ -203,7 +200,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
 
     const int split_offset = __ldg(params.num_splits_ptr + bidb);
 
-    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, scale_softmax, descale_k);
 
     using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
@@ -275,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                                    const int bidb, const int bidh, const int m_block,
                                                                    const int n_split_idx, const int seqlen_k,
                                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                                   SharedStorage &shared_storage, const float descale_q, const float descale_k) {
+                                                                   SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -372,10 +369,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
         // We have key_padding_mask so we'll need to Check_inf
         Tensor scale_o = is_first_masking_step
-                         ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
+                         ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
                          : is_masking_step ?
-                           softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
-                           : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
+                           softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
+                           : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, scale_softmax_log2);
 
         if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
         Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
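A note on `scale_softmax_log2`: flash-attention kernels evaluate the softmax exponentials with `exp2f` rather than `expf`, since the base-2 exponential maps to a faster hardware instruction, so the scale is pre-multiplied by log2(e) once. The change in this hunk only swaps the `params.scale_softmax_log2` load for the local copy computed in the kernel prologue (see the `-560,10` hunk below). A self-contained C++ check of the underlying identity exp(x·s) = exp2(x·s·log2(e)), with illustrative values:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
    constexpr float kLog2e = 1.4426950408889634f;  // log2(e)
    const float scale = 0.1352f;                   // e.g. 1/sqrt(head_dim)
    const float scale_log2 = scale * kLog2e;       // folded once, like scale_softmax_log2
    for (float x : {-3.0f, 0.0f, 1.5f, 7.25f}) {
        float a = std::exp(x * scale);       // reference base-e evaluation
        float b = std::exp2(x * scale_log2); // what the kernel effectively computes
        assert(std::fabs(a - b) < 1e-5f * std::fmax(1.0f, a));
    }
    std::puts("exp/exp2 scale-folding identity holds");
}
```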
@@ -535,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     }
 
     if (NoSplit)
-        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
     else
-        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
 }
 
 template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
@@ -560,10 +557,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
 
-    float descale_q, descale_k;
+    float descale_k = 1.f;
+    float scale_softmax = params.scale_softmax;
+    float scale_softmax_log2 = params.scale_softmax_log2;
     if constexpr (Kernel_traits::Is_FP8) {
-        descale_q = __ldg(params.descale_q_ptr);
+        float descale_q = __ldg(params.descale_q_ptr);
         descale_k = __ldg(params.descale_k_ptr);
+        scale_softmax = scale_softmax * descale_q * descale_k;
+        scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k;
     }
 
 #pragma unroll 1
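This prologue is the heart of the commit: rather than dequantizing Q and K element-wise, the two FP8 descale factors are folded into the softmax scale once per kernel launch. That is valid because the attention scores are bilinear in Q and K: softmax(s·(d_q·Q)(d_k·K)^T) = softmax((s·d_q·d_k)·QK^T). On the non-FP8 path `descale_k` stays at 1.f, so the same call sites work unchanged, and `descale_q` is now scoped inside the FP8 branch since it is never needed past the folding. A toy C++ sketch of the equivalence (names and values are illustrative):

```cpp
#include <cassert>
#include <cmath>

int main() {
    const float descale_q = 0.35f, descale_k = 0.50f, scale = 0.125f;
    const float q_f8[4] = {2.f, -1.f, 4.f, 3.f};  // stand-ins for fp8-quantized Q
    const float k_f8[4] = {1.f, 5.f, -2.f, 0.f};  // stand-ins for fp8-quantized K

    float dot_raw = 0.f, dot_deq = 0.f;
    for (int i = 0; i < 4; ++i) {
        dot_raw += q_f8[i] * k_f8[i];                              // MMA on raw fp8
        dot_deq += (descale_q * q_f8[i]) * (descale_k * k_f8[i]);  // dequantize first
    }
    const float folded    = (scale * descale_q * descale_k) * dot_raw;  // this commit
    const float reference = scale * dot_deq;
    assert(std::fabs(folded - reference) < 1e-6f);
}
```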
@@ -576,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
     }
 }
 
@@ -174,7 +174,7 @@ struct Softmax {
     };
 
    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) {
         SumOp<float> sum_op;
         quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
@@ -184,7 +184,7 @@ struct Softmax {
 #pragma unroll
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
             float sum = row_sum(mi);
-            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
             lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
             float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
 #pragma unroll
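The V descale is deferred in the same spirit: instead of scaling every FP8 value before the P·V matmul, `normalize_softmax_lse` folds `descale_v` into the per-row normalizer, since O_row = Σ_j p_j·(d_v·v_j) / Σ_j p_j = (Σ_j p_j·v_j)·(d_v / Σ_j p_j). The caller passes `descale_k` for this argument (see the `-203,7` hunk above), which fits MLA, where K and V are read from the same quantized latent cache and share one descale factor; the LSE line is untouched because it depends only on the scores. A toy C++ sketch (illustrative values):

```cpp
#include <cassert>
#include <cmath>

int main() {
    const float descale_v = 0.4f;
    const float p[3]    = {0.5f, 1.25f, 0.25f};  // unnormalized exp-scores, one row
    const float v_f8[3] = {3.f, -2.f, 8.f};      // stand-ins for fp8-quantized V

    float acc = 0.f, sum = 0.f, ref = 0.f;
    for (int j = 0; j < 3; ++j) {
        acc += p[j] * v_f8[j];                // P·V accumulated on raw fp8 values
        sum += p[j];                          // softmax row sum
        ref += p[j] * (descale_v * v_f8[j]);  // dequantize-V-first reference
    }
    const float deferred = acc * (descale_v / sum);  // what inv_sum now applies
    assert(std::fabs(deferred - ref / sum) < 1e-6f);
}
```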
@@ -35,7 +35,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
 
     if use_fp8:
-        assert cos_diff < 1e-3
+        assert cos_diff < 1e-2
     else:
         assert cos_diff < 1e-5
 
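Relaxing the FP8 `cos_diff` bound from 1e-3 to 1e-2 accompanies the rescaling rework: with the descales folded into the softmax scale and the row normalizer, FP8 rounding (2-3 mantissa bits) accumulates differently than in a dequantize-first pipeline, and 1e-2 is the looser bound the test now accepts. For reference, a cosine-style difference of the kind `cal_diff` asserts on can be computed as below; this is a sketch of an assumed formula, not necessarily the test's exact one:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// 0 for identical tensors, growing toward 1 as x and y decorrelate.
float cos_diff(const std::vector<float> &x, const std::vector<float> &y) {
    double dot = 0.0, nx = 0.0, ny = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        dot += static_cast<double>(x[i]) * y[i];
        nx  += static_cast<double>(x[i]) * x[i];
        ny  += static_cast<double>(y[i]) * y[i];
    }
    // Accumulate in double, clamp the denominator to avoid division by zero.
    return static_cast<float>(1.0 - 2.0 * dot / std::fmax(nx + ny, 1e-12));
}
```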