mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
update fp8 api
This commit is contained in:
parent
ef644a56e0
commit
4b314cd655
@ -69,8 +69,8 @@ mha_fwd_kvcache_mla(
|
||||
bool is_causal,
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits, // batch_size + 1
|
||||
c10::optional<const at::Tensor> &descale_q, // batch_size
|
||||
c10::optional<const at::Tensor> &descale_k // batch_size
|
||||
c10::optional<const at::Tensor> &descale_q_, // batch_size
|
||||
c10::optional<const at::Tensor> &descale_k_ // batch_size
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
@ -81,6 +81,7 @@ mha_fwd_kvcache_mla(
|
||||
auto q_dtype = q.scalar_type();
|
||||
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
|
||||
bool is_fp8 = q_dtype == torch::kFloat8_e4m3fn;
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
|
||||
@ -107,6 +108,20 @@ mha_fwd_kvcache_mla(
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (is_fp8) {
|
||||
TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
|
||||
auto descale_q = descale_q_.value();
|
||||
auto descale_k = descale_k_.value();
|
||||
CHECK_DEVICE(descale_q);
|
||||
CHECK_DEVICE(descale_k);
|
||||
TORCH_CHECK(descale_q.stride(-1) == 1);
|
||||
TORCH_CHECK(descale_k.stride(-1) == 1);
|
||||
TORCH_CHECK(descale_q.dtype() == torch::kFloat);
|
||||
TORCH_CHECK(descale_k.dtype() == torch::kFloat);
|
||||
CHECK_SHAPE(descale_q, batch_size);
|
||||
CHECK_SHAPE(descale_k, batch_size);
|
||||
}
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
@ -130,7 +145,7 @@ mha_fwd_kvcache_mla(
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
|
||||
auto out_type = is_fp8 ? torch::kBFloat16 : q_dtype;
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
@ -171,6 +186,11 @@ mha_fwd_kvcache_mla(
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
if (is_fp8) {
|
||||
params.descale_q_ptr = reinterpret_cast<float*>(descale_q_.value().data_ptr());
|
||||
params.descale_k_ptr = reinterpret_cast<float*>(descale_k_.value().data_ptr());
|
||||
}
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
@ -190,7 +210,7 @@ mha_fwd_kvcache_mla(
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
|
||||
if (q_dtype == torch::kFloat8_e4m3fn) {
|
||||
if (is_fp8) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
|
||||
|
@ -17,6 +17,9 @@ struct Flash_fwd_mla_params {
|
||||
void *__restrict__ o_ptr;
|
||||
void *__restrict__ softmax_lse_ptr;
|
||||
|
||||
float* __restrict__ descale_q_ptr = nullptr;
|
||||
float* __restrict__ descale_k_ptr = nullptr;
|
||||
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
|
Loading…
Reference in New Issue
Block a user