mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
change to use per_tensor
This commit is contained in:
@@ -118,8 +118,8 @@ mha_fwd_kvcache_mla(
|
||||
TORCH_CHECK(descale_k.stride(-1) == 1);
|
||||
TORCH_CHECK(descale_q.dtype() == torch::kFloat);
|
||||
TORCH_CHECK(descale_k.dtype() == torch::kFloat);
|
||||
- CHECK_SHAPE(descale_q, batch_size);
|
||||
- CHECK_SHAPE(descale_k, batch_size);
|
||||
+ CHECK_SHAPE(descale_q, 1);
|
||||
+ CHECK_SHAPE(descale_k, 1);
|
||||
}
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
Reference in New Issue
Block a user