diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 4be3c1c..a20f408 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -70,8 +70,7 @@ mha_fwd_kvcache_mla(
     const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
     const at::Tensor &num_splits,                // batch_size + 1
     c10::optional<const at::Tensor> &descale_q,  // batch_size
-    c10::optional<const at::Tensor> &descale_k,  // batch_size
-    c10::optional<const at::Tensor> &descale_v   // batch_size
+    c10::optional<const at::Tensor> &descale_k   // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 33c0657..f249315 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
     num_splits: torch.Tensor,
     softmax_scale: Optional[float] = None,
     causal: bool = False,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
@@ -45,7 +47,9 @@ def flash_mla_with_kvcache(
         num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
-
+        descale_q: (batch_size), torch.float. dequant scale for query
+        descale_k: (batch_size), torch.float. dequant scale for key
+
     Return:
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
@@ -63,6 +67,6 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
-        None, None, None,
+        descale_q, descale_k,
     )
     return out, softmax_lse
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 8db5db0..37cbb10 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -37,7 +37,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
 
 
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
 
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
@@ -59,11 +59,28 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
     blocked_v = blocked_k[..., :dv]
 
     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
+
+
+    descale_q, descale_k = None, None
+    if use_fp8:
+        fp8_dtype = torch.float8_e4m3fn
+        descale_q = torch.ones((b), dtype=torch.float32)
+        descale_k = torch.ones((b), dtype=torch.float32)
+
+        q_fp8 = q.to(fp8_dtype)
+        blocked_k_fp8 = blocked_k.to(fp8_dtype)
+        blocked_v_fp8 = blocked_v.to(fp8_dtype)
+        q = q_fp8.to(q.dtype)
+        blocked_k = blocked_k_fp8.to(blocked_k.dtype)
+        blocked_v = blocked_v_fp8.to(blocked_v.dtype)
 
     def flash_mla():
+        q_ = q_fp8 if use_fp8 else q
+        blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k
         return flash_mla_with_kvcache(
-            q, blocked_k, block_table, cache_seqlens, dv,
+            q_, blocked_k_, block_table, cache_seqlens, dv,
             tile_scheduler_metadata, num_splits, causal=causal,
+            descale_q=descale_q, descale_k=descale_k,
         )
 
     def ref_mla():
@@ -107,10 +124,11 @@ if __name__ == "__main__":
     h_kv = 1
     d, dv = 576, 512
     causal = True
+    use_fp8 = False
 
     for b in [128]:
         for s in [4096, 8192]:
             for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
                 for s_q in [1, 2]:  # MTP = 1, 2
                     for varlen in [False, True]:
-                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
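A minimal caller-side sketch of the new FP8 path, mirroring tests/test_flash_mla.py above. The import path, block size 64, block_table layout, and the all-ones descale values are assumptions for illustration, not part of this patch.

# Sketch: quantize q and the paged KV cache to FP8 and pass per-batch descale factors.
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache  # assumed package entry points

torch.set_default_device("cuda")
b, s_q, s_k, h_q, h_kv, d, dv = 2, 1, 4096, 16, 1, 576, 512
block_size = 64                                # assumed paged-KV block size
max_blocks = s_k // block_size

cache_seqlens = torch.full((b,), s_k, dtype=torch.int32)
block_table = torch.arange(b * max_blocks, dtype=torch.int32).view(b, max_blocks)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16)
blocked_k = torch.randn(b * max_blocks, block_size, h_kv, d, dtype=torch.bfloat16)

# Quantize to FP8 (e4m3) and keep per-batch dequantization scales; ones are placeholders.
q_fp8 = q.to(torch.float8_e4m3fn)
blocked_k_fp8 = blocked_k.to(torch.float8_e4m3fn)
descale_q = torch.ones(b, dtype=torch.float32)
descale_k = torch.ones(b, dtype=torch.float32)

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
out, softmax_lse = flash_mla_with_kvcache(
    q_fp8, blocked_k_fp8, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
    descale_q=descale_q, descale_k=descale_k,
)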