diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index a20f408..5a0caa1 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -69,8 +69,8 @@ mha_fwd_kvcache_mla(
     bool is_causal,
     const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
     const at::Tensor &num_splits,                 // batch_size + 1
-    c10::optional<at::Tensor> &descale_q,         // batch_size
-    c10::optional<at::Tensor> &descale_k          // batch_size
+    c10::optional<at::Tensor> &descale_q_,        // batch_size
+    c10::optional<at::Tensor> &descale_k_         // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -81,6 +81,7 @@ mha_fwd_kvcache_mla(
     auto q_dtype = q.scalar_type();
     TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
     TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
+    bool is_fp8 = q_dtype == torch::kFloat8_e4m3fn;
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
 
@@ -107,6 +108,20 @@ mha_fwd_kvcache_mla(
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (is_fp8) {
+        TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
+        auto descale_q = descale_q_.value();
+        auto descale_k = descale_k_.value();
+        CHECK_DEVICE(descale_q);
+        CHECK_DEVICE(descale_k);
+        TORCH_CHECK(descale_q.stride(-1) == 1);
+        TORCH_CHECK(descale_k.stride(-1) == 1);
+        TORCH_CHECK(descale_q.dtype() == torch::kFloat);
+        TORCH_CHECK(descale_k.dtype() == torch::kFloat);
+        CHECK_SHAPE(descale_q, batch_size);
+        CHECK_SHAPE(descale_k, batch_size);
+    }
+
     if (seqlen_q_ori == 1) { is_causal = false; }
 
     const int ngroups = num_heads_ori / num_heads_k;
@@ -130,7 +145,7 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    auto out_type = is_fp8 ? torch::kBFloat16 : q_dtype;
     at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
 
@@ -171,6 +186,11 @@ mha_fwd_kvcache_mla(
     params.block_table_batch_stride = block_table.stride(0);
     params.page_block_size = page_block_size;
 
+    if (is_fp8) {
+        params.descale_q_ptr = reinterpret_cast<float*>(descale_q_.value().data_ptr());
+        params.descale_k_ptr = reinterpret_cast<float*>(descale_k_.value().data_ptr());
+    }
+
     TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
     TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
     CHECK_DEVICE(tile_scheduler_metadata);
@@ -190,7 +210,7 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     TORCH_CHECK(head_size == 576);
-    if (q_dtype == torch::kFloat8_e4m3fn) {
+    if (is_fp8) {
        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(params, stream);
    } else {
        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h
index a2ef414..b7e2fed 100644
--- a/csrc/flash_mla.h
+++ b/csrc/flash_mla.h
@@ -17,6 +17,9 @@ struct Flash_fwd_mla_params {
     void *__restrict__ o_ptr;
     void *__restrict__ softmax_lse_ptr;
 
+    float* __restrict__ descale_q_ptr = nullptr;
+    float* __restrict__ descale_k_ptr = nullptr;
+
     index_t q_batch_stride;
     index_t k_batch_stride;
     index_t v_batch_stride;
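Not part of the patch: a minimal host-side sketch of how a caller might prepare the two new optional arguments, mirroring the checks added above (float32 dtype, CUDA device, stride(-1) == 1, shape [batch_size]). The helper name make_descale and the constant fill value are purely illustrative; for bfloat16 inputs the optionals can simply be left empty (c10::nullopt), since the added checks only run on the fp8 path.

// Illustrative caller-side helper (hypothetical), not part of the diff above.
#include <torch/torch.h>

c10::optional<at::Tensor> make_descale(int64_t batch_size, float scale_inv) {
    // One float32 descale factor per batch element, allocated on the CUDA
    // device so CHECK_DEVICE passes; torch::full returns a contiguous tensor,
    // so stride(-1) == 1 and CHECK_SHAPE(descale, batch_size) both hold.
    auto opts = torch::dtype(torch::kFloat).device(torch::kCUDA);
    return torch::full({batch_size}, scale_inv, opts);
}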