Hi, I think we can optimize the process that load q from gmem to smem.
In function compute_attn_1rowblock_splitkv_mla, after the last calculation
of q @ k_T, we no longer need q. So we can use async copy before the
next function compute_attn_1rowblock_splitkv_mla run, This means that load
the next q to the smem in advance.
In order to prevent the valid values from being overwritten in smem,
I adjusted the layout of SharedStorageMLA, and use test_flash_mla.py
to test. The test can pass normally without any calculation errors.
I tested it on H800, and I use the average of 10 tests as the final result,
each test interval is 3 seconds to stabilize the GPU frequency.
The number of times to load q is very small, so this does not bring much
performance improvement. Under some parameters, there is a slight decrease
in performance, but it is gratifying that there is a roughly 0.5%
performance improvement overall.
batch,seqlen,head,bw_orig,bw_opt,bw_diff_percentage
64,1087,128,1384,1407,1.66%
64,2111,128,1744,1761,0.97%
64,4159,128,2188,2197,0.41%
64,8255,128,2341,2345,0.17%
64,16447,128,2330,2338,0.34%
64,32831,128,2374,2374,0.0%
128,1151,128,1756,1763,0.4%
128,2175,128,2066,2072,0.29%
128,4223,128,2284,2290,0.26%
128,8319,128,2343,2349,0.26%
128,16511,128,2375,2373,-0.08%
128,32895,128,2351,2358,0.3%
256,1279,128,2033,2035,0.1%
256,2303,128,2232,2228,-0.18%
256,4351,128,2322,2340,0.78%
256,8447,128,2371,2367,-0.17%
256,16639,128,2359,2394,1.48%
256,33023,128,2381,2392,0.46%
Thanks!