mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add flag to disable FP16 compile
This commit is contained in:
@@ -77,7 +77,6 @@ mha_fwd_kvcache_mla(
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16);
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
@@ -189,9 +188,15 @@ mha_fwd_kvcache_mla(
|
||||
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
} else {
|
||||
}
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
else if (q_dtype == torch::kHalf) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||
}
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
|
||||
Reference in New Issue
Block a user