feat: topic leaderboard

This commit is contained in:
Timothy J. Baek 2024-10-23 22:35:12 -07:00
parent 0f4b6cdb67
commit cde33002c7
3 changed files with 939 additions and 103 deletions

790
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -52,6 +52,7 @@
"@codemirror/lang-python": "^6.1.6", "@codemirror/lang-python": "^6.1.6",
"@codemirror/language-data": "^6.5.1", "@codemirror/language-data": "^6.5.1",
"@codemirror/theme-one-dark": "^6.1.2", "@codemirror/theme-one-dark": "^6.1.2",
"@huggingface/transformers": "^3.0.0",
"@pyscript/core": "^0.4.32", "@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0", "@sveltejs/adapter-node": "^2.0.0",
"@xyflow/svelte": "^0.1.19", "@xyflow/svelte": "^0.1.19",

View File

@ -1,10 +1,16 @@
<script lang="ts"> <script lang="ts">
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime'; import relativeTime from 'dayjs/plugin/relativeTime';
dayjs.extend(relativeTime); dayjs.extend(relativeTime);
import * as ort from 'onnxruntime-web';
import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
const embedding_model = 'TaylorAI/bge-micro-v2';
let tokenizer = null;
let model = null;
import { models } from '$lib/stores'; import { models } from '$lib/stores';
import { deleteFeedbackById, getAllFeedbacks } from '$lib/apis/evaluations'; import { deleteFeedbackById, getAllFeedbacks } from '$lib/apis/evaluations';
@ -13,49 +19,104 @@
import Tooltip from '../common/Tooltip.svelte'; import Tooltip from '../common/Tooltip.svelte';
import Badge from '../common/Badge.svelte'; import Badge from '../common/Badge.svelte';
import Pagination from '../common/Pagination.svelte'; import Pagination from '../common/Pagination.svelte';
import MagnifyingGlass from '../icons/MagnifyingGlass.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
let rankedModels = []; let rankedModels = [];
let feedbacks = []; let feedbacks = [];
let query = '';
let page = 1; let page = 1;
let tagEmbeddings = new Map();
let loaded = false;
let debounceTimer;
$: paginatedFeedbacks = feedbacks.slice((page - 1) * 10, page * 10); $: paginatedFeedbacks = feedbacks.slice((page - 1) * 10, page * 10);
type Feedback = { type Feedback = {
model_id: string; id: string;
sibling_model_ids?: string[]; data: {
rating: number; rating: number;
model_id: string;
sibling_model_ids: string[] | null;
reason: string;
comment: string;
tags: string[];
};
user: {
name: string;
profile_image_url: string;
};
updated_at: number;
}; };
type ModelStats = { type ModelStats = {
rating: number; rating: number;
won: number; won: number;
draw: number;
lost: number; lost: number;
}; };
function calculateModelStats(feedbacks: Feedback[]): Map<string, ModelStats> { //////////////////////
//
// Rank models by Elo rating
//
//////////////////////
const rankHandler = async (similarities: Map<string, number> = new Map()) => {
const modelStats = calculateModelStats(feedbacks, similarities);
rankedModels = $models
.filter((m) => m?.owned_by !== 'arena' && (m?.info?.meta?.hidden ?? false) !== true)
.map((model) => {
const stats = modelStats.get(model.id);
return {
...model,
rating: stats ? Math.round(stats.rating) : '-',
stats: {
count: stats ? stats.won + stats.lost : 0,
won: stats ? stats.won.toString() : '-',
lost: stats ? stats.lost.toString() : '-'
}
};
})
.sort((a, b) => {
if (a.rating === '-' && b.rating !== '-') return 1;
if (b.rating === '-' && a.rating !== '-') return -1;
if (a.rating !== '-' && b.rating !== '-') return b.rating - a.rating;
return a.name.localeCompare(b.name);
});
};
function calculateModelStats(
feedbacks: Feedback[],
similarities: Map<string, number>
): Map<string, ModelStats> {
const stats = new Map<string, ModelStats>(); const stats = new Map<string, ModelStats>();
const K = 32; const K = 32;
function getOrDefaultStats(modelId: string): ModelStats { function getOrDefaultStats(modelId: string): ModelStats {
return stats.get(modelId) || { rating: 1000, won: 0, draw: 0, lost: 0 }; return stats.get(modelId) || { rating: 1000, won: 0, lost: 0 };
} }
function updateStats(modelId: string, ratingChange: number, outcome: number) { function updateStats(modelId: string, ratingChange: number, outcome: number) {
const currentStats = getOrDefaultStats(modelId); const currentStats = getOrDefaultStats(modelId);
currentStats.rating += ratingChange; currentStats.rating += ratingChange;
if (outcome === 1) currentStats.won++; if (outcome === 1) currentStats.won++;
else if (outcome === 0.5) currentStats.draw++;
else if (outcome === 0) currentStats.lost++; else if (outcome === 0) currentStats.lost++;
stats.set(modelId, currentStats); stats.set(modelId, currentStats);
} }
function calculateEloChange(ratingA: number, ratingB: number, outcome: number): number { function calculateEloChange(
ratingA: number,
ratingB: number,
outcome: number,
similarity: number
): number {
const expectedScore = 1 / (1 + Math.pow(10, (ratingB - ratingA) / 400)); const expectedScore = 1 / (1 + Math.pow(10, (ratingB - ratingA) / 400));
return K * (outcome - expectedScore); return K * (outcome - expectedScore) * similarity;
} }
feedbacks.forEach((feedback) => { feedbacks.forEach((feedback) => {
@ -77,11 +138,13 @@
return; // Skip invalid ratings return; // Skip invalid ratings
} }
const similarity = similarities.get(feedback.id) || 1;
const opponents = feedback.data.sibling_model_ids || []; const opponents = feedback.data.sibling_model_ids || [];
opponents.forEach((modelB) => { opponents.forEach((modelB) => {
const statsB = getOrDefaultStats(modelB); const statsB = getOrDefaultStats(modelB);
const changeA = calculateEloChange(statsA.rating, statsB.rating, outcome); const changeA = calculateEloChange(statsA.rating, statsB.rating, outcome, similarity);
const changeB = calculateEloChange(statsB.rating, statsA.rating, 1 - outcome); const changeB = calculateEloChange(statsB.rating, statsA.rating, 1 - outcome, similarity);
updateStats(modelA, changeA, outcome); updateStats(modelA, changeA, outcome);
updateStats(modelB, changeB, 1 - outcome); updateStats(modelB, changeB, 1 - outcome);
@ -91,6 +154,108 @@
return stats; return stats;
} }
//////////////////////
//
// Calculate cosine similarity
//
//////////////////////
const cosineSimilarity = (vecA, vecB) => {
// Ensure the lengths of the vectors are the same
if (vecA.length !== vecB.length) {
throw new Error('Vectors must be the same length');
}
// Calculate the dot product
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < vecA.length; i++) {
dotProduct += vecA[i] * vecB[i];
normA += vecA[i] ** 2;
normB += vecB[i] ** 2;
}
// Calculate the magnitudes
normA = Math.sqrt(normA);
normB = Math.sqrt(normB);
// Avoid division by zero
if (normA === 0 || normB === 0) {
return 0;
}
// Return the cosine similarity
return dotProduct / (normA * normB);
};
const calculateMaxSimilarity = (queryEmbedding, tagEmbeddings: Map<string, number[]>) => {
let maxSimilarity = 0;
for (const tagEmbedding of tagEmbeddings.values()) {
const similarity = cosineSimilarity(queryEmbedding, tagEmbedding);
maxSimilarity = Math.max(maxSimilarity, similarity);
}
return maxSimilarity;
};
//////////////////////
//
// Embedding functions
//
//////////////////////
const getEmbeddings = async (text: string) => {
const tokens = await tokenizer(text);
const output = await model(tokens);
// Perform mean pooling on the last hidden states
const embeddings = output.last_hidden_state.mean(1);
return embeddings.ort_tensor.data;
};
const getTagEmbeddings = async (tags: string[]) => {
const embeddings = new Map();
for (const tag of tags) {
if (!tagEmbeddings.has(tag)) {
tagEmbeddings.set(tag, await getEmbeddings(tag));
}
embeddings.set(tag, tagEmbeddings.get(tag));
}
return embeddings;
};
const debouncedQueryHandler = async () => {
if (query.trim() === '') {
rankHandler();
return;
}
clearTimeout(debounceTimer);
debounceTimer = setTimeout(async () => {
const queryEmbedding = await getEmbeddings(query);
const similarities = new Map<string, number>();
for (const feedback of feedbacks) {
const feedbackTags = feedback.data.tags || [];
const tagEmbeddings = await getTagEmbeddings(feedbackTags);
const maxSimilarity = calculateMaxSimilarity(queryEmbedding, tagEmbeddings);
similarities.set(feedback.id, maxSimilarity);
}
rankHandler(similarities);
}, 1500); // Debounce for 1.5 seconds
};
$: query, debouncedQueryHandler();
//////////////////////
//
// CRUD operations
//
//////////////////////
const deleteFeedbackHandler = async (feedbackId: string) => { const deleteFeedbackHandler = async (feedbackId: string) => {
const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => { const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => {
toast.error(err); toast.error(err);
@ -101,51 +266,24 @@
} }
}; };
const rankHandler = async () => {
const modelStats = calculateModelStats(feedbacks);
rankedModels = $models
.filter((m) => m?.owned_by !== 'arena' && (m?.info?.meta?.hidden ?? false) !== true)
.map((model) => {
const stats = modelStats.get(model.id);
return {
...model,
rating: stats ? Math.round(stats.rating) : '-',
stats: {
count: stats ? stats.won + stats.draw + stats.lost : 0,
won: stats ? stats.won.toString() : '-',
lost: stats ? stats.lost.toString() : '-'
}
};
})
.sort((a, b) => {
// Handle sorting by rating ('-' goes to the end)
if (a.rating === '-' && b.rating !== '-') return 1;
if (b.rating === '-' && a.rating !== '-') return -1;
// If both have ratings (non '-'), sort by rating numerically (descending)
if (a.rating !== '-' && b.rating !== '-') return b.rating - a.rating;
// If both ratings are '-', sort alphabetically (by 'name')
return a.name.localeCompare(b.name);
});
};
$: if (feedbacks) {
rankHandler();
}
let loaded = false;
onMount(async () => { onMount(async () => {
feedbacks = await getAllFeedbacks(localStorage.token); feedbacks = await getAllFeedbacks(localStorage.token);
loaded = true; loaded = true;
tokenizer = await AutoTokenizer.from_pretrained(embedding_model);
model = await AutoModel.from_pretrained(embedding_model);
// Pre-compute embeddings for all unique tags
const allTags = new Set(feedbacks.flatMap((feedback) => feedback.data.tags || []));
await getTagEmbeddings(Array.from(allTags));
rankHandler();
}); });
</script> </script>
{#if loaded} {#if loaded}
<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between"> <div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
<div class="flex md:self-center text-lg font-medium px-0.5"> <div class="flex md:self-center text-lg font-medium px-0.5 shrink-0">
{$i18n.t('Leaderboard')} {$i18n.t('Leaderboard')}
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" /> <div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
@ -153,6 +291,21 @@
<span class="text-lg font-medium text-gray-500 dark:text-gray-300">{rankedModels.length}</span <span class="text-lg font-medium text-gray-500 dark:text-gray-300">{rankedModels.length}</span
> >
</div> </div>
<div class=" flex space-x-2">
<Tooltip content={$i18n.t('Re-rank models by topic similarity')}>
<div class="flex flex-1">
<div class=" self-center ml-1 mr-3">
<MagnifyingGlass className="size-3" />
</div>
<input
class=" w-full text-sm pr-4 py-1 rounded-r-xl outline-none bg-transparent"
bind:value={query}
placeholder={$i18n.t('Search')}
/>
</div>
</Tooltip>
</div>
</div> </div>
<div <div