support fp16

Sijia Chen
2025-02-24 01:58:53 -08:00
parent 15a82b81b8
commit 65fb7732fc
7 changed files with 139 additions and 91 deletions


@@ -77,7 +77,7 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16);
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -186,7 +186,12 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    if (q_dtype == torch::kBFloat16) {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    } else {
+        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
+    }
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
               .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
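
The fp16 path reuses the existing split-KV MLA kernel by mapping the runtime dtype to a compile-time template parameter (cutlass::half_t vs. cutlass::bfloat16_t), so each dtype gets its own fully specialized instantiation and the kernel body stays dtype-check free. Below is a minimal standalone sketch of that dispatch pattern; DType, bfloat16_t, and half_t here are placeholders, not the real torch/cutlass types.

// Standalone sketch (hypothetical types) of the runtime-dtype ->
// compile-time-template dispatch used in this commit.
#include <cstdio>
#include <type_traits>

enum class DType { kBFloat16, kFloat16 };  // stand-in for torch::ScalarType
struct bfloat16_t {};                      // stand-in for cutlass::bfloat16_t
struct half_t {};                          // stand-in for cutlass::half_t

// Templated runner, mirroring run_mha_fwd_splitkv_mla<Elem, 576>(params, stream):
// one instantiation per element type, specialized at compile time.
template <typename Elem, int kHeadDim>
void run_mha_fwd_splitkv_mla() {
    std::printf("launch: head_dim=%d elem=%s\n", kHeadDim,
                std::is_same_v<Elem, half_t> ? "fp16" : "bf16");
}

// A single runtime branch on the tensor dtype picks the instantiation,
// exactly as the host code above branches on q_dtype.
void dispatch(DType q_dtype) {
    if (q_dtype == DType::kBFloat16) {
        run_mha_fwd_splitkv_mla<bfloat16_t, 576>();
    } else {
        run_mha_fwd_splitkv_mla<half_t, 576>();
    }
}

int main() {
    dispatch(DType::kBFloat16);  // -> bf16 kernel
    dispatch(DType::kFloat16);   // -> fp16 kernel (the path this commit adds)
    return 0;
}

Each dtype added this way costs one more template instantiation (and binary size), but keeps the launch decision to a single branch on the host side.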