mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add fp8 ut
This commit is contained in:
parent
dfe8ffc75a
commit
870418802a
@ -70,8 +70,7 @@ mha_fwd_kvcache_mla(
|
|||||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||||
const at::Tensor &num_splits, // batch_size + 1
|
const at::Tensor &num_splits, // batch_size + 1
|
||||||
c10::optional<const at::Tensor> &descale_q, // batch_size
|
c10::optional<const at::Tensor> &descale_q, // batch_size
|
||||||
c10::optional<const at::Tensor> &descale_k, // batch_size
|
c10::optional<const at::Tensor> &descale_k // batch_size
|
||||||
c10::optional<const at::Tensor> &descale_v // batch_size
|
|
||||||
) {
|
) {
|
||||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||||
|
|||||||
@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
|
|||||||
num_splits: torch.Tensor,
|
num_splits: torch.Tensor,
|
||||||
softmax_scale: Optional[float] = None,
|
softmax_scale: Optional[float] = None,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
|
descale_q: Optional[torch.Tensor] = None,
|
||||||
|
descale_k: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -45,7 +47,9 @@ def flash_mla_with_kvcache(
|
|||||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||||
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||||
causal: bool. Whether to apply causal attention mask.
|
causal: bool. Whether to apply causal attention mask.
|
||||||
|
descale_q: (batch_size), torch.float. dequant scale for query
|
||||||
|
descale_k: (batch_size), torch.float. dequant scale for key
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||||
@ -63,6 +67,6 @@ def flash_mla_with_kvcache(
|
|||||||
causal,
|
causal,
|
||||||
tile_scheduler_metadata,
|
tile_scheduler_metadata,
|
||||||
num_splits,
|
num_splits,
|
||||||
None, None, None,
|
descale_q, descale_k,
|
||||||
)
|
)
|
||||||
return out, softmax_lse
|
return out, softmax_lse
|
||||||
|
|||||||
@ -37,7 +37,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
|
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
|
||||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
|
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
|
||||||
|
|
||||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||||
@ -59,11 +59,28 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
|
|||||||
blocked_v = blocked_k[..., :dv]
|
blocked_v = blocked_k[..., :dv]
|
||||||
|
|
||||||
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
|
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
|
||||||
|
|
||||||
|
|
||||||
|
descale_q, descale_k = None, None
|
||||||
|
if use_fp8:
|
||||||
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
|
descale_q = torch.ones((b), dtype=torch.float32)
|
||||||
|
descale_k = torch.ones((b), dtype=torch.float32)
|
||||||
|
|
||||||
|
q_fp8 = q.to(fp8_dtype)
|
||||||
|
blocked_k_fp8 = blocked_k.to(fp8_dtype)
|
||||||
|
blocked_v_fp8 = blocked_v.to(fp8_dtype)
|
||||||
|
q = q_fp8.to(q.dtype)
|
||||||
|
blocked_k = blocked_k_fp8.to(blocked_k.dtype)
|
||||||
|
blocked_v = blocked_v_fp8.to(blocked_v.dtype)
|
||||||
|
|
||||||
def flash_mla():
|
def flash_mla():
|
||||||
|
q_ = q_fp8 if use_fp8 else q
|
||||||
|
blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k
|
||||||
return flash_mla_with_kvcache(
|
return flash_mla_with_kvcache(
|
||||||
q, blocked_k, block_table, cache_seqlens, dv,
|
q_, blocked_k_, block_table, cache_seqlens, dv,
|
||||||
tile_scheduler_metadata, num_splits, causal=causal,
|
tile_scheduler_metadata, num_splits, causal=causal,
|
||||||
|
descale_q=descale_q, descale_k=descale_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def ref_mla():
|
def ref_mla():
|
||||||
@ -107,10 +124,11 @@ if __name__ == "__main__":
|
|||||||
h_kv = 1
|
h_kv = 1
|
||||||
d, dv = 576, 512
|
d, dv = 576, 512
|
||||||
causal = True
|
causal = True
|
||||||
|
use_fp8 = False
|
||||||
|
|
||||||
for b in [128]:
|
for b in [128]:
|
||||||
for s in [4096, 8192]:
|
for s in [4096, 8192]:
|
||||||
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
||||||
for s_q in [1, 2]: # MTP = 1, 2
|
for s_q in [1, 2]: # MTP = 1, 2
|
||||||
for varlen in [False, True]:
|
for varlen in [False, True]:
|
||||||
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
|
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user