mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-05-05 04:24:54 +00:00
fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them
This commit is contained in:
parent
4edea86f9e
commit
b67980309b
@ -1,15 +1,16 @@
|
||||
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
|
||||
import argparse
|
||||
import math
|
||||
import random
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import argparse
|
||||
|
||||
# pip install flashinfer-python
|
||||
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
|
||||
import flashinfer
|
||||
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
|
||||
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
|
||||
query = query.float()
|
||||
@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
|
||||
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
|
||||
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
|
||||
return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b
|
||||
|
||||
|
||||
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
|
||||
@ -501,7 +503,8 @@ def get_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
with open("all_perf.csv", "w") as fout:
|
||||
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
|
||||
with open(f"{benchmark_type}_perf.csv", "w") as fout:
|
||||
fout.write("name,batch,seqlen,head,bw\n")
|
||||
for shape in shape_configs:
|
||||
if args.all:
|
||||
@ -509,6 +512,9 @@ if __name__ == "__main__":
|
||||
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
|
||||
fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
|
||||
elif args.compare:
|
||||
compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
|
||||
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
|
||||
fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
|
||||
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
|
||||
elif args.one:
|
||||
compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
|
||||
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
|
||||
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
|
@ -1,7 +1,17 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
file_path = 'all_perf.csv'
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the benchmark visualizer.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which
            is what the zero-argument call has always done.

    Returns:
        argparse.Namespace whose ``file`` attribute is the path to the CSV
        file with benchmark results (``all_perf.csv`` when ``--file`` is
        not given).
    """
    parser = argparse.ArgumentParser(description='Visualize benchmark results')
    parser.add_argument('--file', type=str, default='all_perf.csv',
                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')
    return parser.parse_args(argv)
|
||||
|
||||
args = parse_args()
|
||||
file_path = args.file
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
@ -16,4 +26,4 @@ plt.xlabel('seqlen')
|
||||
plt.ylabel('bw (GB/s)')
|
||||
plt.legend()
|
||||
|
||||
plt.savefig('bandwidth_vs_seqlen.png')
|
||||
# Name the figure after the input CSV's base name (directory and final
# extension stripped). Splitting the whole path on the first "." breaks for
# inputs such as "./results/all_perf.csv" (empty stem) or "v1.0/all_perf.csv".
import os  # imported here because this statement is the only user in the script

plt.savefig(f'{os.path.splitext(os.path.basename(file_path))[0]}_bandwidth_vs_seqlen.png')
|
Loading…
Reference in New Issue
Block a user