Enable FP8 API

chenhongmin.will 2025-02-25 22:34:01 +08:00
parent c50d29d170
commit dfe8ffc75a
3 changed files with 18 additions and 8 deletions


@@ -68,7 +68,10 @@ mha_fwd_kvcache_mla(
         const float softmax_scale,
         bool is_causal,
         const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
-        const at::Tensor &num_splits                // batch_size + 1
+        const at::Tensor &num_splits,               // batch_size + 1
+        c10::optional<const at::Tensor> &descale_q, // batch_size
+        c10::optional<const at::Tensor> &descale_k, // batch_size
+        c10::optional<const at::Tensor> &descale_v  // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -76,9 +79,9 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
-    auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
-    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -128,7 +131,8 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
     Flash_fwd_mla_params params = {};
@@ -186,7 +190,12 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
+    } else {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+    }
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
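
For context: after this change q and kcache may be either bf16 or float8_e4m3fn, the output stays bf16, and the optional descale_q / descale_k / descale_v tensors carry one dequantization factor per batch element (real value ≈ fp8 value * descale). Below is a minimal PyTorch sketch of that convention, assuming per-batch absmax scaling; quantize_e4m3 is an illustrative helper, not part of this repository.

import torch

def quantize_e4m3(x: torch.Tensor):
    # Illustrative per-batch absmax quantization to float8_e4m3fn.
    # x has shape (batch, ...); returns the fp8 tensor plus a (batch,)
    # descale such that x ≈ x_fp8.float() * descale (broadcast over batch).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    absmax = x.abs().amax(dim=tuple(range(1, x.dim())))          # (batch,)
    scale = fp8_max / absmax.clamp(min=1e-12)
    view = (-1,) + (1,) * (x.dim() - 1)
    x_fp8 = (x * scale.view(view)).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, (1.0 / scale).float()

q_bf16 = torch.randn(2, 1, 128, 576, dtype=torch.bfloat16)
q_fp8, descale_q = quantize_e4m3(q_bf16)                         # descale_q: shape (2,)
# The fp8 kernel path is expected to fold descale_q/descale_k into the attention
# logits and descale_v into the PV product, still writing bf16 output.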


@@ -278,8 +278,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
-            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);


@@ -63,5 +63,6 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        None, None, None,
     )
     return out, softmax_lse
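
For orientation, a usage sketch in the spirit of the repository README. The Python wrapper still forwards None, None, None for descale_q / descale_k / descale_v, so the bf16 decode path below is unchanged by this commit; all shapes, block sizes, and sequence lengths here are illustrative assumptions.

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Assumed decode-style shapes: batch 2, one query token, 128 query heads,
# 1 latent KV head, head_dim 576, value head_dim 512, page block size 64.
b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 128, 1, 576, 512, 64
cache_seqlens = torch.tensor([130, 70], dtype=torch.int32, device="cuda")
max_blocks = (int(cache_seqlens.max()) + block_size - 1) // block_size
block_table = torch.arange(b * max_blocks, dtype=torch.int32, device="cuda").view(b, max_blocks)
kvcache = torch.randn(b * max_blocks, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
o, lse = flash_mla_with_kvcache(
    q, kvcache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
# The three trailing Nones in the wrapper map positionally onto the new
# descale_q / descale_k / descale_v arguments of the extended C++ binding.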