Mirror of https://github.com/deepseek-ai/FlashMLA (synced 2025-06-26 18:15:54 +00:00)

Commit dfe8ffc75a (parent c50d29d170): enable fp8 api
@@ -68,7 +68,10 @@ mha_fwd_kvcache_mla(
     const float softmax_scale,
     bool is_causal,
     const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
-    const at::Tensor &num_splits // batch_size + 1
+    const at::Tensor &num_splits, // batch_size + 1
+    c10::optional<const at::Tensor> &descale_q, // batch_size
+    c10::optional<const at::Tensor> &descale_k, // batch_size
+    c10::optional<const at::Tensor> &descale_v // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
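The three new optional arguments carry per-batch FP8 dequantization scales for q, k, and v (shape batch_size, per the inline comments). A minimal sketch of how a caller might construct them; the tensor names, the float32 dtype, and the unit scales are illustrative assumptions, not something this commit fixes:

    import torch

    batch_size = 4
    # Assumed: one dequantization scale per batch element, float32, on the same
    # device as q/kcache. Unit scales shown only as a placeholder.
    descale_q = torch.ones(batch_size, dtype=torch.float32, device="cuda")
    descale_k = torch.ones(batch_size, dtype=torch.float32, device="cuda")
    descale_v = torch.ones(batch_size, dtype=torch.float32, device="cuda")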
@@ -76,9 +79,9 @@ mha_fwd_kvcache_mla(

     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

-    auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
-    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");

     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

@@ -128,7 +131,8 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

     Flash_fwd_mla_params params = {};
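Note that with FP8 inputs the output is still allocated in bf16: the FP8 path changes the compute dtype, not the result dtype. A tiny Python mirror of the out_type selection above, purely illustrative:

    import torch

    def expected_out_dtype(q_dtype: torch.dtype) -> torch.dtype:
        # Mirrors the C++ out_type selection: FP8 e4m3 inputs still yield a bf16 output.
        return torch.bfloat16 if q_dtype == torch.float8_e4m3fn else q_dtype

    assert expected_out_dtype(torch.float8_e4m3fn) == torch.bfloat16
    assert expected_out_dtype(torch.bfloat16) == torch.bfloat16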
@@ -186,7 +190,12 @@ mha_fwd_kvcache_mla(

     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
+    } else {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+    }

     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});

@@ -278,8 +278,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
-        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));

     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);

@@ -63,5 +63,6 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        None, None, None,
     )
     return out, softmax_lse
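At the Python level the wrapper forwards None for all three descale tensors, so the FP8 path is selected purely by the dtype of q and the KV cache, and the output comes back in bf16. A hedged end-to-end sketch following the repository's README call pattern; the shapes, block size, and scale-free fp8 casts below are illustrative assumptions, not values fixed by this commit:

    import torch
    from flash_mla import get_mla_metadata, flash_mla_with_kvcache

    b, s_q, h_q, h_kv, d, dv = 2, 1, 128, 1, 576, 512   # MLA head dims: 576 in, 512 out
    block_size, blocks_per_seq = 64, 64
    device = "cuda"

    cache_seqlens = torch.full((b,), 4096, dtype=torch.int32, device=device)
    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

    # Scale-free casts to float8_e4m3fn, since the wrapper does not yet expose descale_q/k/v.
    q = torch.randn(b, s_q, h_q, d, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    kvcache = torch.randn(b * blocks_per_seq, block_size, h_kv, d,
                          device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                               device=device).view(b, blocks_per_seq)

    out, lse = flash_mla_with_kvcache(
        q, kvcache, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    # FP8 inputs dispatch to the e4m3 kernel instantiation but produce bf16 output.
    assert out.dtype == torch.bfloat16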