change to use per_tensor

chenhongmin.will 2025-02-26 10:17:29 +08:00
parent 4b314cd655
commit f6fab1b915
3 changed files with 12 additions and 6 deletions


@@ -118,8 +118,8 @@ mha_fwd_kvcache_mla(
         TORCH_CHECK(descale_k.stride(-1) == 1);
         TORCH_CHECK(descale_q.dtype() == torch::kFloat);
         TORCH_CHECK(descale_k.dtype() == torch::kFloat);
-        CHECK_SHAPE(descale_q, batch_size);
-        CHECK_SHAPE(descale_k, batch_size);
+        CHECK_SHAPE(descale_q, 1);
+        CHECK_SHAPE(descale_k, 1);
     }
     if (seqlen_q_ori == 1) { is_causal = false; }
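The checks above now require descale_q and descale_k to be single-element float32 tensors, independent of batch_size. Below is a minimal caller-side sketch of that per-tensor contract in PyTorch; the helper name and the amax-based scale choice are assumptions for illustration, not part of this commit.

# Hypothetical helper: derive one FP8 scale/descale pair for a whole tensor
# (per-tensor), rather than one entry per batch element.
import torch

def per_tensor_descale(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    fp8_max = torch.finfo(fp8_dtype).max                   # 448.0 for float8_e4m3fn
    amax = x.abs().amax().clamp(min=1e-12)                 # single amax over all elements
    scale = fp8_max / amax                                 # multiply by this before casting to FP8
    descale = (1.0 / scale).reshape(1).to(torch.float32)   # shape (1,), matches CHECK_SHAPE(descale_q, 1)
    return scale, descale

# e.g. scale_q, descale_q = per_tensor_descale(q); q_fp8 = (q * scale_q).to(torch.float8_e4m3fn)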


@@ -260,7 +260,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                                    const int bidb, const int bidh, const int m_block,
                                                                    const int n_split_idx, const int seqlen_k,
                                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                                   SharedStorage &shared_storage) {
+                                                                   SharedStorage &shared_storage, const float descale_q, const float descale_k) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -494,6 +494,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);

+    float descale_q, descale_k;
+    if constexpr (Kernel_traits::Is_FP8) {
+        descale_q = __ldg(params.descale_q_ptr);
+        descale_k = __ldg(params.descale_k_ptr);
+    }
+
     #pragma unroll 1
     for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
         const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
@@ -504,7 +510,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
            __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
     }
 }
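The kernel now reads the two per-tensor descale scalars once with __ldg and passes them into compute_attn_1rowblock_splitkv_mla; how they are consumed inside that routine is outside this diff. As reference math only, here is a plain-PyTorch sketch of the usual per-tensor dequantization of an FP8 QK^T product (the function name and the placement of the softmax scale are assumptions):

# Reference only: dequantize an FP8 QK^T product with one scalar per tensor.
import torch

def dequant_scores(q_fp8, k_fp8, descale_q, descale_k, softmax_scale):
    # Emulate the FP8 matmul in fp32, then fold in both descales and the softmax scale.
    s = q_fp8.float() @ k_fp8.float().transpose(-1, -2)
    return s * (descale_q.item() * descale_k.item() * softmax_scale)

Because the scales are per-tensor, the product descale_q * descale_k is a single constant for the whole batch, which is why one __ldg per kernel launch suffices instead of a per-batch lookup.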


@@ -67,8 +67,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
    if use_fp8:
        nonlocal q, blocked_k, blocked_v
        fp8_dtype = torch.float8_e4m3fn
-        descale_q = torch.ones((b), dtype=torch.float32)
-        descale_k = torch.ones((b), dtype=torch.float32)
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)
        q_fp8 = q.to(fp8_dtype)
        blocked_k_fp8 = blocked_k.to(fp8_dtype)
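With unit descales, dequantization in the test is effectively the identity and the FP8 cast is the only source of error. A sketch of how non-trivial per-tensor scales could be fed in instead, reusing the variables above (an assumption for illustration, not part of this commit):

        fp8_max = torch.finfo(fp8_dtype).max                      # 448.0 for float8_e4m3fn
        scale_q = fp8_max / q.abs().amax().clamp(min=1e-12)       # one scale for all of q
        scale_k = fp8_max / blocked_k.abs().amax().clamp(min=1e-12)
        descale_q = (1.0 / scale_q).reshape(1).to(torch.float32)  # shape (1,), per tensor
        descale_k = (1.0 / scale_k).reshape(1).to(torch.float32)
        q_fp8 = (q * scale_q).to(fp8_dtype)
        blocked_k_fp8 = (blocked_k * scale_k).to(fp8_dtype)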