mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Hi, I think the way q is loaded from gmem to smem can be optimized. In compute_attn_1rowblock_splitkv_mla, q is no longer needed once the last q @ k_T has been computed. At that point we can issue an async copy of the next q into smem, before the next call to compute_attn_1rowblock_splitkv_mla starts. In other words, the next q is prefetched into smem ahead of time. I adjusted the layout of SharedStorageMLA so that this copy cannot overwrite values in smem that are still live.

I tested the change with test_flash_mla.py, and it passes without any numerical errors. For performance, I ran on an H800 and averaged 10 runs for each result, waiting 3 seconds between runs to let the GPU clock stabilize. Because q is loaded only a few times, the gain is necessarily small. A few configurations are slightly slower, but overall throughput improves by roughly 0.5%. In the table below, bw_orig and bw_opt are the measured bandwidths before and after the change, and bw_diff_percentage = (bw_opt - bw_orig) / bw_orig * 100.

batch,seqlen,head,bw_orig,bw_opt,bw_diff_percentage
64,1087,128,1384,1407,1.66%
64,2111,128,1744,1761,0.97%
64,4159,128,2188,2197,0.41%
64,8255,128,2341,2345,0.17%
64,16447,128,2330,2338,0.34%
64,32831,128,2374,2374,0.0%
128,1151,128,1756,1763,0.4%
128,2175,128,2066,2072,0.29%
128,4223,128,2284,2290,0.26%
128,8319,128,2343,2349,0.26%
128,16511,128,2375,2373,-0.08%
128,32895,128,2351,2358,0.3%
256,1279,128,2033,2035,0.1%
256,2303,128,2232,2228,-0.18%
256,4351,128,2322,2340,0.78%
256,8447,128,2371,2367,-0.17%
256,16639,128,2359,2394,1.48%
256,33023,128,2381,2392,0.46%

Thanks!
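The following is a minimal host-side C++ sketch of the ordering the comment relies on. It is not FlashMLA code and rests on several assumptions. std::async stands in for the GPU's asynchronous gmem-to-smem copy, and a single dot product stands in for q @ k_T. Tile, prefetch_q and run_blocks are hypothetical names. The sketch shows why one q buffer is enough: the next q is copied into it only after the current block's last read of q.

```cpp
#include <cassert>
#include <cstddef>
#include <future>
#include <vector>

// Hypothetical stand-in for one row block's q tile; not a FlashMLA type.
using Tile = std::vector<float>;

// Emulates the asynchronous gmem -> smem copy of q with a background task.
// The caller must wait on the returned future before reading dst.
std::future<void> prefetch_q(const Tile& src, Tile& dst) {
    return std::async(std::launch::async, [&src, &dst] { dst = src; });
}

// Processes every row block with a single q buffer ("q_smem").
// The copy of the next q is issued right after the block's last read of q
// (its final q @ k_T), so it overlaps with the rest of the block's work.
std::vector<float> run_blocks(const std::vector<Tile>& q_gmem, const Tile& k) {
    std::vector<float> out;
    if (q_gmem.empty()) return out;

    Tile q_smem = q_gmem[0];  // the first q is loaded up front
    std::future<void> pending;

    for (std::size_t b = 0; b < q_gmem.size(); ++b) {
        // The prefetched q must have landed before this block reads it.
        if (pending.valid()) pending.wait();
        assert(q_smem.size() == k.size());

        // Stand-in for q @ k_T: one dot product, the last read of q in this block.
        float s = 0.0f;
        for (std::size_t i = 0; i < k.size(); ++i) s += q_smem[i] * k[i];

        // q is dead from this point on, so overwriting q_smem is safe.
        if (b + 1 < q_gmem.size()) pending = prefetch_q(q_gmem[b + 1], q_smem);

        // Stand-in for the rest of the block (softmax, P @ V, ...). It only
        // uses s, never q_smem, so it can run while the copy is in flight.
        out.push_back(s);
    }
    return out;
}
```

For example, with q tiles {1, 2}, {3, 4}, {5, 6} and k = {1, 1}, run_blocks returns {3, 7, 11}. This is the same result as loading each q synchronously. In the real kernel, the smem region that receives the next q must also not alias anything still live when the copy is issued. The SharedStorageMLA layout change described above addresses that requirement.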
cutlass@afa1772203
flash_api.cpp
flash_fwd_mla_bf16_sm90.cu
flash_fwd_mla_fp16_sm90.cu
flash_fwd_mla_kernel.h
flash_fwd_mla_metadata.cu
flash_mla.h
named_barrier.h
softmax.h
static_switch.h
utils.h