Change descale_q and descale_k to use per-tensor scaling (expected shape 1 instead of batch_size)

This commit is contained in:
chenhongmin.will
2025-02-26 10:17:29 +08:00
parent 4b314cd655
commit f6fab1b915
3 changed files with 12 additions and 6 deletions

View File

@@ -118,8 +118,8 @@ mha_fwd_kvcache_mla(
TORCH_CHECK(descale_k.stride(-1) == 1);
TORCH_CHECK(descale_q.dtype() == torch::kFloat);
TORCH_CHECK(descale_k.dtype() == torch::kFloat);
CHECK_SHAPE(descale_q, batch_size);
CHECK_SHAPE(descale_k, batch_size);
CHECK_SHAPE(descale_q, 1);
CHECK_SHAPE(descale_k, 1);
}
if (seqlen_q_ori == 1) { is_causal = false; }