fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them

yangsijia.614 2025-02-25 23:52:54 +08:00
parent 4edea86f9e
commit b67980309b
2 changed files with 24 additions and 8 deletions

View File

@@ -1,15 +1,16 @@
 # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
+import argparse
 import math
 import random
+
+import flashinfer
 import torch
 import triton
 import triton.language as tl
-import argparse

 # pip install flashinfer-python
-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
-import flashinfer
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata


 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
@@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
     bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
     print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
     print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
+    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b


 def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
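The values returned here are the achieved memory bandwidths for baseline and target. Since perf_a and perf_b are kernel times in milliseconds, bytes / 10 ** 6 is megabytes, and MB/ms works out to GB/s, consistent with the GB/s figures printed just above. A quick sanity check with made-up numbers:

# bytes / 10**6 is MB; dividing by a time in ms gives MB/ms, which equals GB/s
bytes_moved = 2 * 10 ** 9               # 2 GB of memory traffic (illustrative)
perf_ms = 1.0                           # 1 ms kernel time (illustrative)
print(bytes_moved / 10 ** 6 / perf_ms)  # 2000.0 -> 2000 GB/s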
@@ -501,7 +503,8 @@ def get_args():

 if __name__ == "__main__":
     args = get_args()
-    with open("all_perf.csv", "w") as fout:
+    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
+    with open(f"{benchmark_type}_perf.csv", "w") as fout:
         fout.write("name,batch,seqlen,head,bw\n")
         for shape in shape_configs:
             if args.all:
@@ -509,6 +512,9 @@ if __name__ == "__main__":
                 perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                 fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
             elif args.compare:
-                compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                perf_a, perf_b = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf_a:.0f}\n')
+                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf_b:.0f}\n')
             elif args.one:
-                compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')

View File

@@ -1,7 +1,17 @@
+import argparse
+
 import matplotlib.pyplot as plt
 import pandas as pd

-file_path = 'all_perf.csv'
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Visualize benchmark results')
+    parser.add_argument('--file', type=str, default='all_perf.csv',
+                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')
+    return parser.parse_args()
+
+
+args = parse_args()
+file_path = args.file

 df = pd.read_csv(file_path)
@@ -16,4 +26,4 @@ plt.xlabel('seqlen')
 plt.ylabel('bw (GB/s)')
 plt.legend()
-plt.savefig('bandwidth_vs_seqlen.png')
+plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')